• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/ad/dfunctor.h"
18 
19 #include <map>
20 #include <memory>
21 #include <string>
22 
23 #include "mindspore/core/ops/structure_ops.h"
24 #include "mindspore/core/ops/sequence_ops.h"
25 #include "mindspore/core/ops/math_ops.h"
26 #include "mindspore/core/ops/array_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "ir/anf.h"
29 #include "utils/info.h"
30 #include "ir/func_graph_cloner.h"
31 #include "ir/manager.h"
32 #include "pipeline/jit/ps/resource.h"
33 #include "frontend/optimizer/ad/adjoint.h"
34 #include "frontend/operator/ops.h"
35 #include "utils/symbolic.h"
36 #include "utils/ms_context.h"
37 #include "pipeline/jit/ps/action.h"
38 #include "pipeline/jit/ps/parse/resolve.h"
39 #include "pipeline/pynative/pynative_execute.h"
40 #include "include/common/debug/anf_ir_dump.h"
41 
namespace mindspore {
namespace ad {
// Cache mapping a primal FuncGraph to its DFunctor; shared by all functors of one grad transform.
mindspore::HashMap<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
// Cache mapping a primal node to its adjoint definition, used to resolve fv adjoints across graphs.
mindspore::HashMap<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;

// When true, free variables are lifted to parameters before grad, so the indirect-fv
// backprop paths in this file are expected to be unreachable (they raise if hit).
bool lift_fv_before_grad = true;
48 
// Construct the functor for `primal_graph`: create the forward K graph (k_graph_)
// and the backward tape graph (tape_), propagating stage/segment info from the primal.
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources, bool is_top)
    : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(is_top) {
  {
    // Scope the trace guard so only k_graph_ creation inherits the fprop trace info.
    TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
    k_graph_ = std::make_shared<FuncGraph>();
  }
  // To keep switch or switch_layer's inputs from being inlined
  k_graph_->set_indirect(primal_graph->indirect());
  k_graph_->set_stage(primal_graph->stage());
  k_graph_->set_segment(primal_graph->segment());

  {
    TraceGuard guard(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
    tape_ = std::make_shared<FuncGraph>();
  }
  tape_->set_stage(primal_graph->stage());
  tape_->set_segment(primal_graph->segment());

  // The tape's first parameter is the incoming sensitivity (dout).
  dout_ = tape_->add_parameter();
  const auto &info = primal_graph->GetEffectInfo();
  if (is_top_ && info.back_mem) {
    // Add Umonad arg for top graph.
    (void)tape_->add_parameter();
  }
}
74 
Init(bool is_top)75 void DFunctor::Init(bool is_top) {
76   func_graph_to_functor_[primal_graph_] = shared_from_this();
77   is_top_ = is_top;
78 }
79 
// Finalize after morphism mapping. Order matters: dout holes on the tape must be
// resolved before primal user sites are rewritten to the K forward result.
void DFunctor::Finish() {
  CallDoutHoleOnTape();
  EliminatePrimalGraph();
}
84 
Clear()85 void DFunctor::Clear() {
86   func_graph_to_functor_.clear();
87   anfnode_to_adjoin_definition_.clear();
88 }
89 
// Accumulate the gradient of free variable `fv` out of the environment value `din`.
// Looks up (or lazily creates) the fv's adjoint, builds/reuses an Embed key and a
// zeros_like default, then reads the fv's sens from `din` via EnvironGet and
// accumulates it into the adjoint's dout.
// NOTE(review): unreachable when lift_fv_before_grad is true — fvs are lifted to
// parameters beforehand, so reaching here in that mode is an internal error.
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  MS_EXCEPTION_IF_NULL(fv);
  if (lift_fv_before_grad) {
    MS_EXCEPTION_IF_NULL(fv->func_graph());
    MS_LOG(INTERNAL_EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv:" << fv->func_graph()->ToString()
                               << " " << fv->ToString() << ".";
  }
  auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  if (fv_adjoint == anfnode_to_adjoin_.end()) {
    MS_LOG(DEBUG) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
                  << fv->ToString() << ".";

    if (fv->func_graph() == primal_graph_) {
      // If this fv is not mapped by MapMorphism because of cnode order, then map it now.
      (void)MapMorphism(fv);
      fv_adjoint = anfnode_to_adjoin_.find(fv);
      if (fv_adjoint == anfnode_to_adjoin_.end()) {
        MS_LOG(INTERNAL_EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
                                   << " " << fv->ToString() << ".";
      }
    } else {
      // The fv belongs to another graph: fall back to the indirect-fv cache,
      // creating an adjoint (possibly a k hole) on first sight.
      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
      if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
        MS_LOG(DEBUG) << "Can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv " << fv->func_graph()->ToString()
                      << " " << fv->ToString() << ".";
        auto parent_adjoint = FindAdjoint(fv);
        AdjointPtr adjoint = nullptr;
        if (parent_adjoint != nullptr) {
          adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
        } else {
          MS_LOG(DEBUG) << "Can not find adjoint definition fv, add a k hole " << fv->func_graph()->ToString() << " "
                        << fv->ToString() << ".";
          adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
        }
        anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
        fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
      }
    }
  }
  auto fv_node = fv_adjoint->second->k();
  // Reuse the (Embed, zeros_like) pair per k node so repeated backprops of the
  // same fv share one environment key and one default value node.
  auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
  CNodePtr embed_node;
  CNodePtr default_val_node;
  if (cached_envitem_iter != anfnode_to_envitem_.end()) {
    embed_node = cached_envitem_iter->second.first;
    default_val_node = cached_envitem_iter->second.second;
  } else {
    embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node});
    default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node});
    fv_adjoint->second->RegisterKUser(embed_node, 1);
    fv_adjoint->second->RegisterKUser(default_val_node, 1);
    anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
  }
  auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvironGet), din, embed_node, default_val_node});
  MS_LOG(DEBUG) << "Find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
                << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  MS_LOG(DEBUG) << "Get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
  fv_adjoint->second->AccumulateDout(dfv);
}
149 
// Back-propagate fv gradients for a switch_layer node: every graph in its candidate
// tuple may be taken at runtime, so fvs of every candidate (direct and indirect)
// receive sens from `env`. `node_to_fg` deduplicates fvs shared between candidates.
void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
  // Take switch_layer as a set of candidate functions.
  constexpr size_t input_tuple_index = 2;
  auto input = cnode_morph->input(input_tuple_index);
  if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
    MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
  }
  mindspore::HashMap<AnfNodePtr, FuncGraphPtr> node_to_fg;
  auto tuple_graphs = input->cast_ptr<CNode>();
  // Input 0 is the MakeTuple primitive itself; candidates start at index 1.
  for (size_t i = 1; i < tuple_graphs->size(); ++i) {
    auto graph = tuple_graphs->input(i);
    if (!IsValueNode<FuncGraph>(graph)) {
      MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString()
                        << " as the " << i << "th element.";
    }
    auto func_graph = GetValueNode<FuncGraphPtr>(graph);
    auto functor = func_graph_to_functor_.find(func_graph);
    if (functor == func_graph_to_functor_.end()) {
      MS_LOG(EXCEPTION) << "Failed functor for subgraph does not exist input[" << i << "] " << func_graph->ToString()
                        << ".";
    }
    // Consider direct and indirect fvs.
    for (auto fv : func_graph->free_variables_nodes()) {
      if (node_to_fg.find(fv) != node_to_fg.end()) {
        continue;
      }
      node_to_fg[fv] = func_graph;
      BackPropagateFv(fv, env);
    }
    for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
      MS_LOG(DEBUG) << "Backprop indirect fv " << func_graph->ToString() << " " << indirect_fv.first->ToString() << ".";
      if (node_to_fg.find(indirect_fv.first) != node_to_fg.end()) {
        continue;
      }
      node_to_fg[indirect_fv.first] = func_graph;
      BackPropagateFv(indirect_fv.first, env);
    }
  }
}
189 
HasSideEffectBackProp(const CNodePtr & cnode)190 static bool HasSideEffectBackProp(const CNodePtr &cnode) {
191   if (IsPrimitiveCNode(cnode)) {
192     const auto &prim = GetCNodePrimitive(cnode);
193     MS_EXCEPTION_IF_NULL(prim);
194     auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
195     return bprop_flag;
196   }
197   return false;
198 }
199 
HasSideEffectBackPropMem(const CNodePtr & cnode)200 static bool HasSideEffectBackPropMem(const CNodePtr &cnode) {
201   if (IsPrimitiveCNode(cnode)) {
202     const auto &prim = GetCNodePrimitive(cnode);
203     MS_EXCEPTION_IF_NULL(prim);
204     auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP_MEM);
205     return bprop_flag;
206   }
207   return false;
208 }
209 
// Hook ops are no-ops in graph mode; bypass them for backprop purposes.
// - A hook node with one real input is replaced by that input.
// - A hook node with several inputs is replaced (in the graph, via the manager)
//   by a MakeTuple of its inputs.
// - A TupleGetItem over a hook node is redirected to the hook's corresponding input.
// Any other node is returned unchanged.
static AnfNodePtr SkipHookNodeInBackProp(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsPrimitiveCNode(node, prim::kPrimHookBackward) || IsPrimitiveCNode(node, prim::kPrimCellBackwardHook)) {
    MS_LOG(WARNING) << "Hook operation does not work in graph mode or functions decorated with 'jit', it will be "
                       "eliminated during compilation.";
    auto output_cnode = node->cast_ptr<CNode>();
    MS_EXCEPTION_IF_NULL(output_cnode);
    // Single real input (size includes the primitive at index 0): forward it directly.
    if (output_cnode->size() - 1 == 1) {
      return output_cnode->input(1);
    }
    // Replace hook node with make tuple node.
    abstract::AbstractBasePtrList multi_output_abs;
    std::vector<AnfNodePtr> multi_output_nodes{NewValueNode(prim::kPrimMakeTuple)};
    (void)std::for_each(output_cnode->weak_inputs().cbegin() + 1, output_cnode->weak_inputs().cend(),
                        [&multi_output_nodes, &multi_output_abs](const AnfNodeWeakPtr &weak_inp) {
                          AnfNodePtr inp = weak_inp.lock();
                          MS_EXCEPTION_IF_NULL(inp);
                          (void)multi_output_nodes.emplace_back(inp);
                          (void)multi_output_abs.emplace_back(inp->abstract());
                        });
    auto primal_graph = node->func_graph();
    MS_EXCEPTION_IF_NULL(primal_graph);
    auto make_tuple = primal_graph->NewCNode(std::move(multi_output_nodes));
    make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(multi_output_abs));
    auto mng = primal_graph->manager();
    MS_EXCEPTION_IF_NULL(mng);
    // Side effect: rewrites the primal graph so later passes see the MakeTuple too.
    if (!mng->Replace(node, make_tuple)) {
      MS_LOG(INTERNAL_EXCEPTION) << "Failed to replace old node: " << node->DebugString()
                                 << " with new node: " << make_tuple->DebugString();
    }
    return make_tuple;
  }
  if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
    auto tuple_get_item = node->cast_ptr<CNode>();
    MS_EXCEPTION_IF_NULL(tuple_get_item);
    auto inp = tuple_get_item->input(1);
    if (IsPrimitiveCNode(inp, prim::kPrimHookBackward) || IsPrimitiveCNode(inp, prim::kPrimCellBackwardHook)) {
      MS_LOG(WARNING) << "Hook operation does not work in graph mode or functions decorated with 'jit', it will be "
                         "eliminated during compilation.";
      constexpr size_t idx = 2;
      auto v_node = dyn_cast_ptr<ValueNode>(tuple_get_item->input(idx));
      MS_EXCEPTION_IF_NULL(v_node);
      auto out_idx = GetValue<int64_t>(v_node->value());
      auto cnode = inp->cast_ptr<CNode>();
      MS_EXCEPTION_IF_NULL(cnode);
      // Tuple element i maps to hook input i + 1 (input 0 is the primitive).
      return cnode->input(LongToSize(out_idx) + 1);
    }
  }
  return node;
}
260 
// If a real-valued input receives a complex-valued gradient `din`, wrap `din`
// in a Real op so the accumulated gradient matches the input's dtype.
// Returns `din` unchanged when no conversion is needed (non-tensor input,
// complex input, or already-real din).
AnfNodePtr HandleRealToComplex(const AnfNodePtr &input, const CNodePtr &din, const FuncGraphPtr &fg) {
  MS_EXCEPTION_IF_NULL(input);
  TypePtr input_type = input->Type();
  if (input_type == nullptr || !input_type->isa<TensorType>()) {
    return din;
  }
  input_type = input_type->cast_ptr<TensorType>()->element();
  MS_EXCEPTION_IF_NULL(input_type);
  // Complex inputs keep complex gradients as-is.
  if (input_type->type_id() == kNumberTypeComplex64 || input_type->type_id() == kNumberTypeComplex128) {
    return din;
  }

  MS_EXCEPTION_IF_NULL(din);
  // If we can not get the dtype of din, we insert real op ignoring din's dtype,
  // and eliminate it in "real_op_elimiate" pass.
  MS_EXCEPTION_IF_NULL(fg);
  if (din->abstract() == nullptr) {
    return fg->NewCNode({NewValueNode(prim::kPrimRealInner), din});
  }

  TypePtr din_type = din->Type();
  if (din_type == nullptr || !din_type->isa<TensorType>()) {
    return din;
  }
  din_type = din_type->cast_ptr<TensorType>()->element();
  MS_EXCEPTION_IF_NULL(din_type);
  if (din_type->type_id() != kNumberTypeComplex64 && din_type->type_id() != kNumberTypeComplex128) {
    return din;
  }
  // din is complex but the input is real: take the real part, keeping the input's
  // element type and the input's shape in the new abstract.
  AnfNodePtr new_din = fg->NewCNode({NewValueNode(prim::kPrimReal), din});
  AbstractBasePtr abs = std::make_shared<abstract::AbstractTensor>(
    abstract::AbstractTensor(input_type, input->abstract()->GetShapeTrack()));
  new_din->set_abstract(abs);
  return new_din;
}
296 
// Build the backprop application for one morphism: extract the bprop closure from
// k_app (element 1), call it with the node's accumulated dout on the tape, and
// distribute the resulting din tuple to each input's adjoint (including fvs of
// FuncGraph-valued inputs, and the special switch_layer candidate handling).
void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint,
                             bool side_effect_bprop_app_propagate) {
  auto bprop =
    k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(1))});
  // Call with delimited continuation dout.
  CNodePtr bprop_app;
  if (HasSideEffectBackProp(cnode_morph)) {
    // as MapMorphism is called recursively, so the order of bprop_app should reversed as visited order.
    bprop_app = tape_->NewCNodeInFront({bprop, node_adjoint->dout()});
    tape_->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  } else {
    bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
  }

  // Mark memory-side-effect bprops so monad passes can order them later.
  if (HasSideEffectBackPropMem(cnode_morph)) {
    bprop_app->AddAttr(kAttrSideEffectBpropApp, MakeValue(true));
    k_graph_->set_flag(kAttrSideEffectBpropAppPropagate, true);
  }
  if (side_effect_bprop_app_propagate) {
    bprop_app->AddAttr(kAttrSideEffectBpropAppPropagate, MakeValue(true));
    k_graph_->set_flag(kAttrSideEffectBpropAppPropagate, true);
  }
  node_adjoint->RegisterDoutUser(bprop_app, 1);
  // Special case for switch_layer
  if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
    auto din =
      tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(static_cast<int64_t>(0))});
    BackPropagateSwitchLayer(cnode_morph, din);
    return;
  }
  // i == 0 is the called function itself; its "gradient" flows to fvs below.
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
    auto input = SkipHookNodeInBackProp(cnode_morph->input(i));
    auto din_with_real = HandleRealToComplex(input, din, tape_);
    MS_EXCEPTION_IF_NULL(din_with_real);
    din = din_with_real->cast<CNodePtr>();
    // Backprop sens wrt fvs.
    if (IsValueNode<FuncGraph>(input)) {
      auto func_graph = GetValueNode<FuncGraphPtr>(input);
      auto functor = func_graph_to_functor_.find(func_graph);
      if (functor == func_graph_to_functor_.end()) {
        MS_LOG(INTERNAL_EXCEPTION) << "Failed functor for subgraph does not exist input[" << i << "] "
                                   << func_graph->ToString() << ".";
      }
      // Consider direct and indirect fvs.
      for (auto fv : func_graph->free_variables_nodes()) {
        BackPropagateFv(fv, din);
      }
      for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
        MS_LOG(DEBUG) << "Backprop indirect fv " << func_graph->ToString() << ", " << indirect_fv.first->ToString()
                      << ".";
        BackPropagateFv(indirect_fv.first, din);
      }
      continue;
    }
    // Backprop sens wrt inputs.
    auto input_adjoint = anfnode_to_adjoin_.find(input);
    if (input_adjoint == anfnode_to_adjoin_.end()) {
      MS_LOG(INTERNAL_EXCEPTION) << "The adjoint does not exist input[" << i << "] " << input->ToString()
                                 << ". primal_graph_: " << primal_graph_->ToString();
    }
    input_adjoint->second->AccumulateDout(din);
  }
}
361 
362 // Map a morphism.
// Map a morphism.
// Map one cnode of the primal graph to its K form: K-transform each input
// (recursing for inputs not yet mapped), apply them in k_graph_, create the
// node's Adjoint from the forward result, and (unless stop_gradient) build the
// backprop application. Returns the node's adjoint, or nullptr for non-cnodes.
AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  constexpr int recursive_level = 4;
  MS_LOG(DEBUG) << "Start: " << morph->DebugString(recursive_level);
  // MapMorphism All type except CNode should already be mapped by MapObject.
  if (!morph->isa<CNode>()) {
    return nullptr;
  }
  // for free variable, which may be handled in MapValueObject, just return it
  auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
  if (node_adjoint_found != anfnode_to_adjoin_.end()) {
    return node_adjoint_found->second;
  }
  ScopeGuard scope_guard(morph->scope());
  auto cnode_morph = morph->cast<CNodePtr>();

  std::vector<AnfNodePtr> inputs;
  std::vector<AdjointPtr> param_adjoints;
  bool side_effect_bprop_app_propagate = false;
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto node = SkipHookNodeInBackProp(cnode_morph->input(i));
    AdjointPtr node_adjoint = nullptr;
    auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
    if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
      node_adjoint = node_adjoint_iter->second;
    } else {
      // Input might be a CNode that needs to be handled previously.
      node_adjoint = MapMorphism(node);
    }
    if (node_adjoint == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "The node adjoint is null, " << node->DebugString();
    }
    AnfNodePtr k = node_adjoint->k();
    if (k == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "The adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
    }
    if (i == 0) {
      // Input 0 is the called function; link its K graph back to this primal call
      // site and inherit its side-effect propagation flag.
      auto k_fg = GetValueNode<FuncGraphPtr>(k);
      if (k_fg != nullptr) {
        (void)k_fg->transforms().emplace("primal_cnode", FuncGraphTransform(cnode_morph));
        side_effect_bprop_app_propagate = k_fg->has_flag(kAttrSideEffectBpropAppPropagate);
      }
    }
    inputs.push_back(k);
    param_adjoints.push_back(node_adjoint);
  }
  CNodePtr k_app = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
    k_app = k_graph_->NewCNode(inputs);
  }
  // Run in pynative mode, when @jit is used.
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    pynative::PyNativeExecutor::GetInstance()->grad_executor()->jit()->ProcessCnodeFromAdGrad(k_app, cnode_morph);
  }

  for (size_t i = 0; i < param_adjoints.size(); ++i) {
    param_adjoints[i]->RegisterKUser(k_app, i);
  }
  // Do forward computation
  auto forward_app =
    k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(0))});
  // K:: cnode -> forward_app
  auto node_adjoint = std::make_shared<Adjoint>(morph, forward_app, tape_);
  UpdateAdjoint(node_adjoint);
  anfnode_to_adjoin_[morph] = node_adjoint;
  if (cnode_morph->stop_gradient()) {
    MS_LOG(DEBUG) << "The node " << morph->ToString() << " is stopped.";
    return node_adjoint;
  }

  // Do sens backpropagation
  BackPropagate(cnode_morph, k_app, node_adjoint, side_effect_bprop_app_propagate);
  MS_LOG(DEBUG) << "End, node: " << morph->DebugString(recursive_level);
  return node_adjoint;
}
438 
IsFreeMorphism(const AnfNodePtr & node)439 bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
440   // Do not care about non-CNode
441   if (!node->isa<CNode>()) {
442     return false;
443   }
444   // Do not care about kPrimReturn
445   if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
446     return false;
447   }
448   MS_EXCEPTION_IF_NULL(primal_graph_->manager());
449   auto &node_users = primal_graph_->manager()->node_users();
450   auto iter = node_users.find(node);
451   if (iter == node_users.end()) {
452     return false;
453   }
454   auto &users = iter->second;
455   // Do not care about isolated morphisms
456   if (users.empty()) {
457     return false;
458   }
459   // Not free if it's used by some node in primal_graph
460   bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
461     auto &user = kv.first;
462     return user->func_graph() == primal_graph_;
463   });
464   return !nonfree;
465 }
466 
MapFreeMorphism()467 void DFunctor::MapFreeMorphism() {
468   // Handle cnode not attached to output, that might be referred in other functions.
469   for (auto &node : primal_graph_->nodes()) {
470     if (!IsFreeMorphism(node)) {
471       continue;
472     }
473     MS_LOG(DEBUG) << "Map nonoutput cnode after MapMorphism " << node->ToString() << ".";
474     (void)MapMorphism(node);
475   }
476 }
477 
// Fold the gradient of each direct free variable into the environment value
// `grad_fv` via EnvironSet, keyed by Embed of the fv's K node. Returns the
// updated environment node.
AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
  AnfNodePtr new_grad_fv = grad_fv;
  // Add grads wrt fv.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  if (!is_top_ && free_variables_nodes.size() != 0) {
    // With fv lifting enabled, non-top graphs must not have direct fvs any more.
    if (lift_fv_before_grad) {
      MS_LOG(INTERNAL_EXCEPTION) << "The direct fv size is: " << free_variables_nodes.size() << " in "
                                 << primal_graph_->ToString() << ".";
    }
  }

  for (auto &fv : free_variables_nodes) {
    if (IsPrimitiveCNode(fv, prim::kPrimJ)) {  // Ignore if FV is a J CNode.
      continue;
    }
    auto fv_adjoint = anfnode_to_adjoin_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_.end()) {
      MS_LOG(INTERNAL_EXCEPTION) << "The fv adjoint does not exist " << fv->ToString() << ".";
    }
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
    fv_adjoint->second->RegisterKUser(node, 1);
    auto sens = fv_adjoint->second->dout();
    new_grad_fv = tape_->NewCNode({NewValueNode(prim::kPrimEnvironSet), new_grad_fv, node, sens});
    // EnvironSet inputs: [prim, environ, key, value]; sens sits at index 3.
    constexpr size_t sens_index = 3;
    fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), sens_index);
    MS_LOG(DEBUG) << "Add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " " << fv->ToString()
                  << " " << primal_graph_->ToString() << ".";
  }
  return new_grad_fv;
}
508 
// Fold the gradient of each indirect free variable into `grad_fv`, same scheme
// as AttachFvDoutToTape. Only reachable when lift_fv_before_grad is false;
// otherwise raises immediately.
AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
  if (lift_fv_before_grad) {
    MS_LOG(INTERNAL_EXCEPTION) << "Lift free variable case: backprop indirect fv " << grad_fv->ToString() << " "
                               << primal_graph_->ToString() << ".";
  }
  AnfNodePtr new_grad_fv = grad_fv;
  // Add indirect fv bprop.
  for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
    MS_LOG(DEBUG) << "Backprop indirect fv " << fv_adjoint.first->ToString() << " " << primal_graph_->ToString() << ".";
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()});
    fv_adjoint.second->RegisterKUser(node, 1);
    auto sens = fv_adjoint.second->dout();
    new_grad_fv = tape_->NewCNode({NewValueNode(prim::kPrimEnvironSet), new_grad_fv, node, sens});
    // EnvironSet inputs: [prim, environ, key, value]; sens sits at index 3.
    constexpr size_t sens_index = 3;
    fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), sens_index);
    MS_LOG(DEBUG) << "Add indirect fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << ".";
  }
  return new_grad_fv;
}
528 
MapMorphism()529 void DFunctor::MapMorphism() {
530   // Set stop_gradient before MapMorphism.
531   BroadCastStopFlag();
532 
533   // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
534   MapFreeMorphism();
535   // Skip HookBackward op and CellBackwardHook op when it is the output node.
536   auto output_node = primal_graph_->output();
537   output_node = SkipHookNodeInBackProp(output_node);
538   // Handle morphism from output.
539   // Topo sort all nodes firstly in case of stack overflow fault.
540   auto nodes = TopoSort(output_node, SuccIncoming, [this](const AnfNodePtr &node) -> IncludeType {
541     MS_EXCEPTION_IF_NULL(node);
542     if (node->func_graph() == nullptr || node->func_graph() != primal_graph_ || node->isa<Parameter>()) {
543       return EXCLUDE;
544     }
545     return FOLLOW;
546   });
547   for (const auto &node : nodes) {
548     (void)MapMorphism(SkipHookNodeInBackProp(node));
549   }
550 
551   // Construct K for primal_graph_.
552   auto output_adjoint = anfnode_to_adjoin_.find(output_node);
553   // Attach dout_ parameter to output_adjoint.
554   output_adjoint->second->AccumulateDout(dout_);
555 
556   // Set output for tape closure.
557   AnfNodePtr grad_fv;
558   if (lift_fv_before_grad) {
559     grad_fv = AttachFvDoutToTape(NewEnviron(tape_));
560   } else {
561     grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewEnviron(tape_)));
562   }
563 
564   std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
565   // Add grads wrt inputs.
566   std::vector<AdjointPtr> param_adjoints;
567   for (auto &param : primal_graph_->parameters()) {
568     auto param_adjoint = anfnode_to_adjoin_.find(param);
569     inputs.push_back(param_adjoint->second->dout());
570     param_adjoints.push_back(param_adjoint->second);
571   }
572   auto tape_output = tape_->NewCNode(inputs);
573   constexpr size_t offset_num = 2;
574   for (size_t i = 0; i < param_adjoints.size(); ++i) {
575     param_adjoints[i]->RegisterDoutUser(tape_output, i + offset_num);
576   }
577   tape_->set_output(tape_output);
578   // Set output for k_graph_, K:: cnode->forward_app.
579   auto forward_app = output_adjoint->second->k();
580   auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
581   output_adjoint->second->RegisterKUser(output, 1);
582   k_graph_->set_output(output);
583   (void)primal_graph_->transforms().emplace("grad", FuncGraphTransform(k_graph_));
584   (void)k_graph_->transforms().emplace("primal", FuncGraphTransform(primal_graph_));
585 }
586 
// If `primal` carries a user-defined "bprop" transform (a Cell's custom bprop),
// resolve and expand it into a K graph, cache it as primal's "grad" transform,
// and return it. Returns nullptr when no user bprop exists.
FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
  // K user defined cell bprop.
  auto bprop = primal->transforms().find("bprop");
  if (bprop != primal->transforms().end()) {
    FuncGraphPtr bprop_graph = bprop->second.func_graph();
    resources_->manager()->AddFuncGraph(bprop_graph);

    (void)parse::ResolveFuncGraph(bprop_graph, resources_);
    // User bprops must be closed: capturing a Parameter as fv is not supported.
    if (!bprop_graph->free_variables_nodes().empty()) {
      MS_LOG(EXCEPTION) << "The user defined 'bprop' function in scope " << primal->output()->scope()->name()
                        << " does not support using Parameter.\n"
                        << trace::GetDebugInfoStr(bprop_graph->debug_info());
    }
    // Check the func decorated by @custom_vjp.
    if (g_k_prims.CheckCustomVjp(bprop_graph)) {
      bprop_graph = g_k_prims.GetCustomVjpBprop(bprop_graph);
      bprop->second = FuncGraphTransform(bprop_graph);
    }

    bprop_graph->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true);
    bprop_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);

    auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph, primal);
    if (fg == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
                                 << primal->output()->scope()->name() << ".";
    }

    // Cache the grad func
    (void)primal->transforms().emplace("grad", FuncGraphTransform(fg));
    (void)fg->transforms().emplace("primal", FuncGraphTransform(primal));
    // Reset defer_inline to enable successive inlining
    primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);

    // Register a functor for primal so MapFuncGraphToK finds it in the cache.
    auto functor = std::make_shared<DFunctor>(primal, resources_, false);
    functor->Init();
    functor->k_graph_ = fg;

    return fg;
  }
  return nullptr;
}
629 
StopGradientForScalar(const CNodePtr & cnode)630 bool StopGradientForScalar(const CNodePtr &cnode) {
631   auto grad_for_scalar = MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR);
632   if (grad_for_scalar) {
633     return false;
634   }
635   auto abs = cnode->abstract();
636   return abs != nullptr && abs->isa<abstract::AbstractScalar>();
637 }
638 
639 // Construct representation graph for {CNode, Index} of Primitive.
// Construct representation graph for {CNode, Index} of Primitive.
// Map the Primitive at input `index` of `primitive_user` to its K form: prefer
// the registered bprop-based k_prim, falling back to a meta func graph. Sets
// need_cut_ for primitives where gradient propagation must stop.
AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) {
  auto primal = primitive_user->input(index);
  if (!IsValueNode<Primitive>(primal)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive.";
  }
  ScopeGuard scope_guard(primal->scope());
  // Map Primitive to K
  auto value_node = primal->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  // Hash is compared first as a cheap filter; name confirms the exact primitive.
  if ((prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) ||
      (prim->Hash() == prim::kPrimUpdateState->Hash() && prim->name() == prim::kPrimUpdateState->name()) ||
      StopGradientForScalar(primitive_user)) {
    MS_LOG(DEBUG) << "Should stop gradient for " << prim->ToString();
    need_cut_ = true;
  }
  if (prim->Hash() == prim::kPrimPyExecute->Hash() && prim->name() == prim::kPrimPyExecute->name()) {
    MS_LOG(WARNING) << "The gradient will be stopped from propagating at the PyExecute node created at the location: "
                    << trace::GetDebugInfoStr(primitive_user->debug_info());
    need_cut_ = true;
  }

  auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_);
  if (k_prim != nullptr) {
    // Forward the primitive's recompute preference onto its K graph.
    auto prim_recompute_attr = prim->GetAttr(kAttrRecompute);
    if (prim_recompute_attr != nullptr && prim_recompute_attr->isa<BoolImm>()) {
      auto recomputed = GetValue<bool>(prim_recompute_attr);
      if (recomputed) {
        k_prim->set_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH, true);
      } else {
        k_prim->set_flag(FUNC_GRAPH_NOT_RECOMPUTE_K_GRAPH, true);
      }
    }
    return NewValueNode(k_prim);
  }
  // When failed to find k_prim, try k_meta.
  auto k_meta = g_k_prims.KMetaFuncGraph(prim);
  if (k_meta != nullptr) {
    return NewValueNode(k_meta);
  }
  MS_LOG(INTERNAL_EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K.";
}
681 
682 // Construct representation graph for ValueNode of FuncGraph.
MapFuncGraphToK(const AnfNodePtr & primal)683 AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) {
684   if (!IsValueNode<FuncGraph>(primal)) {
685     MS_LOG(INTERNAL_EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph.";
686   }
687   ScopeGuard scope_guard(primal->scope());
688   // Map func graph to K
689   auto func_graph = GetValueNode<FuncGraphPtr>(primal);
690   auto f = func_graph_to_functor_.find(func_graph);
691   if (f != func_graph_to_functor_.end()) {
692     MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << ".";
693     return NewValueNode(f->second->k_graph_);
694   }
695   auto k_user_defined = KUserDefined(func_graph);
696   if (k_user_defined != nullptr) {
697     MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << ".";
698     return NewValueNode(k_user_defined);
699   }
700   auto functor = std::make_shared<DFunctor>(func_graph, resources_, false);
701   functor->Init();
702   functor->MapObject();
703   functor->MapMorphism();
704 
705   if (func_graph->has_flag(FUNC_GRAPH_FLAG_NO_INLINE)) {
706     functor->k_graph_->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
707   }
708   if (func_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
709     functor->k_graph_->set_flag(FUNC_GRAPH_FLAG_CELL_REUSE, true);
710   }
711   if (func_graph->has_flag(GRAPH_FLAG_IS_WHILE_HEADER)) {
712     functor->k_graph_->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
713     functor->tape_->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
714   }
715   if (func_graph->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
716     functor->k_graph_->set_flag(FUNC_GRAPH_RECOMPUTE_K_GRAPH, true);
717   }
718 
719   MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\"";
720   return NewValueNode(functor->k_graph_);
721 }
722 
723 // Construct for ValueNode of Parameter.
MapParameterToK(const AnfNodePtr & primal)724 AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) {
725   if (!primal->isa<Parameter>()) {
726     MS_LOG(INTERNAL_EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter.";
727   }
728   ScopeGuard scope_guard(primal->scope());
729   // Map Parameter to K
730   TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
731   auto ret = k_graph_->add_parameter();
732   ret->cast_ptr<Parameter>()->set_name(primal->cast_ptr<Parameter>()->name());
733   return ret;
734 }
735 
MapFvObject()736 void DFunctor::MapFvObject() {
737   // Map free variable.
738   const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
739   for (auto &node : free_variables_nodes) {
740     ScopeGuard scope_guard(node->scope());
741     MS_LOG(DEBUG) << "The free variable " << node->ToString() << ".";
742     // Find fv's K from parent.
743     AdjointPtr adjoint = nullptr;
744     auto parent_adjoint = FindAdjoint(node);
745     if (parent_adjoint != nullptr) {
746       adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
747     } else {
748       if (is_top_ || node->isa<Parameter>()) {
749         // Out of ad scope, add adjoint for free variables.
750         adjoint = std::make_shared<Adjoint>(node, node, tape_);
751         UpdateAdjoint(adjoint);
752       } else {
753         MS_LOG(DEBUG) << "Fail to find parent adjoint for nontop fv " << node->ToString() << ".";
754         adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
755       }
756     }
757     if (adjoint == nullptr) {
758       MS_LOG(INTERNAL_EXCEPTION) << "Failed for free variable " << node->ToString() << ".";
759     }
760     anfnode_to_adjoin_[node] = adjoint;
761   }
762 }
763 
MapParamObject()764 void DFunctor::MapParamObject() {
765   // Map parameter.
766   for (auto &p : primal_graph_->parameters()) {
767     ScopeGuard scope_guard(p->scope());
768     MS_LOG(DEBUG) << "The parameter " << p->ToString() << ".";
769     auto adjoint = std::make_shared<Adjoint>(p, MapParameterToK(p), tape_);
770     UpdateAdjoint(adjoint);
771     anfnode_to_adjoin_[p] = adjoint;
772   }
773 }
774 
// Map every ValueNode of the primal graph to an adjoint.
// Primitives are mapped through MapPrimitiveToK (driven by their first user
// CNode), FuncGraph values through MapFuncGraphToK, and any other value maps
// to itself. Return and backward-hook primitives are deliberately skipped.
void DFunctor::MapValueObject() {
  // Map ValueNode.
  auto manager = resources_->manager();
  auto &value_nodes = primal_graph_->value_nodes();
  for (const auto &value_pair : value_nodes) {
    auto node = value_pair.first;
    // Reuse the K already known to an enclosing functor, if any.
    auto parent_adjoint = FindAdjoint(node);
    if (parent_adjoint != nullptr) {
      auto adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
      anfnode_to_adjoin_[node] = adjoint;
      continue;
    }

    AdjointPtr adjoint = nullptr;
    if (IsValueNode<Primitive>(node)) {  // Primitive.
      auto prim = GetValuePtr<Primitive>(node);
      MS_EXCEPTION_IF_NULL(prim);
      // Return and the backward-hook primitives get no adjoint at all.
      if ((prim->Hash() == prim::kPrimReturn->hash() && prim->name() == prim::kPrimReturn->name()) ||
          (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) ||
          (prim->Hash() == prim::kPrimCellBackwardHook->Hash() &&
           prim->name() == prim::kPrimCellBackwardHook->name())) {
        continue;
      }
      MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
      auto &users = manager->node_users()[node];
      if (users.size() == 0) {
        MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user.";
        continue;
      } else if (users.size() > 1) {
        MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size();
      }
      auto cnode = users.begin()->first->cast<CNodePtr>();  // We just use the first user.
      auto index = users.begin()->second;
      adjoint = std::make_shared<Adjoint>(node, MapPrimitiveToK(cnode, index), tape_);
    } else if (IsValueNode<FuncGraph>(node)) {  // FuncGraph
      MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << ".";
      adjoint = std::make_shared<Adjoint>(node, MapFuncGraphToK(node), tape_);
    } else if (node->isa<Parameter>()) {  // Parameter, hardly reach here.
      MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << ".";
      adjoint = std::make_shared<Adjoint>(node, MapParameterToK(node), tape_);
    } else {
      // Plain values (tensors, scalars, ...) act as their own K representation.
      adjoint = std::make_shared<Adjoint>(node, node, tape_);
    }
    UpdateAdjoint(adjoint);
    anfnode_to_adjoin_[node] = adjoint;
  }
}
822 
823 // Skip morphism.
MapObject()824 void DFunctor::MapObject() {
825   // The order does not matter
826   MapFvObject();
827   MapParamObject();
828   MapValueObject();
829 }
830 
UpdateAdjoint(const AdjointPtr & adjoint_definition)831 void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
832   auto primal = adjoint_definition->primal();
833   if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) {
834     MS_LOG(INTERNAL_EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
835                                << primal->ToString() << ".";
836   }
837   anfnode_to_adjoin_definition_[primal] = adjoint_definition;
838   // Update k hole for primal.
839   for (auto &f : func_graph_to_functor_) {
840     auto adjoint = f.second->anfnode_to_adjoin_.find(primal);
841     if (adjoint != f.second->anfnode_to_adjoin_.end()) {
842       adjoint->second->UpdateK(adjoint_definition->k());
843     }
844     adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal);
845     if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) {
846       adjoint->second->UpdateK(adjoint_definition->k());
847     }
848   }
849 }
850 
FindAdjoint(const AnfNodePtr & primal) const851 AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) const {
852   auto adjoint = anfnode_to_adjoin_definition_.find(primal);
853   if (adjoint != anfnode_to_adjoin_definition_.end()) {
854     MS_LOG(DEBUG) << "Found adjoint definition for free variable " << primal->ToString() << ".";
855     return adjoint->second;
856   }
857   MS_LOG(DEBUG) << "The adjoint definition for free variable not defined yet " << primal->ToString() << ".";
858   return nullptr;
859 }
860 
CallDoutHoleOnTape() const861 void DFunctor::CallDoutHoleOnTape() const {
862   if (!is_top_) {
863     return;
864   }
865 
866   // Call dout hole of all adjoint.
867   for (auto &f : func_graph_to_functor_) {
868     for (auto &adjoint : f.second->anfnode_to_adjoin_) {
869       adjoint.second->CallDoutHole();
870     }
871     for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) {
872       adjoint.second->CallDoutHole();
873     }
874   }
875 }
876 
// Returns the K (fprop) graph built for the primal graph.
FuncGraphPtr DFunctor::k_graph() { return k_graph_; }
878 
// Returns the tape graph that accumulates the backward computation.
FuncGraphPtr DFunctor::tape() { return tape_; }
880 
BroadCastStopFlag()881 void DFunctor::BroadCastStopFlag() {
882   // As stop set expanding, all directly or indirectly stopped CNode will be cut off
883   while (need_cut_) {
884     need_cut_ = false;
885     for (auto &node : primal_graph_->nodes()) {
886       auto cnode = dyn_cast<CNode>(node);
887       if (cnode != nullptr && !cnode->stop_gradient()) {
888         // Cut off the cnode only when it's not referred any more
889         if (cnode->IsApply(prim::kPrimStopGradient) || cnode->IsApply(prim::kPrimUpdateState) ||
890             AllReferencesStopped(cnode) || StopGradientForScalar(cnode) || cnode->IsApply(prim::kPrimPyExecute)) {
891           MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
892           cnode->set_stop_gradient(true);
893           // The stop set changed, more cut required
894           need_cut_ = true;
895         }
896       }
897     }
898   }
899 }
900 
AllReferencesStopped(const CNodePtr & node)901 bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
902   auto &users = primal_graph_->manager()->node_users()[node];
903   // Only care about stop_gradient caused cutting
904   if (users.empty()) {
905     return false;
906   }
907   for (auto &kv : users) {
908     auto &user = kv.first;
909     if (!user->isa<CNode>()) {
910       return false;
911     } else {
912       auto cnode = user->cast_ptr<CNode>();
913       MS_EXCEPTION_IF_NULL(cnode);
914       if (!cnode->stop_gradient()) {
915         return false;
916       }
917     }
918   }
919   return true;
920 }
921 
// Returns the CNode that consumes the result of the J application
// {cnode, index}. A J CNode normally has exactly one user; additional users
// are tolerated only when they reference the J CNode as a free variable
// (input index != 0). More than one *call* user (index == 0) is an internal
// error and aborts after dumping the graph IR.
// NOTE(review): when multiple users exist but none is a call user, this
// returns nullptr — callers are expected to null-check (FindPrimalJPair does
// via MS_EXCEPTION_IF_NULL); confirm before relying on a non-null result.
CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int index) {
  constexpr auto recursive_level = 2;
  auto it = node_user_map.find(cnode);
  if (it == node_user_map.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "J CNode not used {" << cnode->DebugString(recursive_level) << "/" << index << "}";
  }
  auto &j_users = it->second;
  auto size = j_users.size();
  if (size != 1) {
    bool has_multiple_j_call_user = false;
    CNodePtr j_call_user = nullptr;
    for (auto &user : j_users) {
      // If J CNode is used as a FV, the j_users.size may exceed 1 user. It is allowed.
      if (user.second == 0) {
        // Real J CNode call user.
        if (j_call_user == nullptr) {  // First user.
          j_call_user = user.first->cast<CNodePtr>();
        } else {  // More than 1 call user. Not allowed.
          has_multiple_j_call_user = true;
        }
      }
    }
    if (has_multiple_j_call_user) {  // Has multiple J CNode call user.
      std::ostringstream user_info;
      for (auto &user : j_users) {
        user_info << "    user: " << user.first->DebugString() << ", index: " << user.second << "\n";
      }
#ifdef ENABLE_DUMP_IR
      DumpIR("J_User_Ex_" + cnode->func_graph()->ToString() + ".ir", cnode->func_graph());
#endif
      MS_LOG(INTERNAL_EXCEPTION) << "Incorrect J CNode user size: " << size << ", of {"
                                 << cnode->DebugString(recursive_level) << "/" << index << "}\nUser Info:\n"
                                 << user_info.str();
    } else {
      return j_call_user;
    }
  }
  return j_users.begin()->first->cast<CNodePtr>();
}
961 
// Returns the single primal-graph call CNode that pairs with `j_user`, or
// nullptr when the primal-elimination optimization cannot apply: the primal
// call and the J user live in different graphs, the primal is called more
// than once in that graph, or the two calls have mismatched input sizes.
CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std::vector<CNodePtr>> &primal_map) {
  // Check if the forward network and the gradient of it are called in the same graph.
  auto graph = j_user->func_graph();
  auto iter = primal_map.find(graph);
  if (iter == primal_map.end()) {
    // The CNode using the forward graph result and the gradient of the forward graph are not in the same graph.
    // The EliminatePrimalGraph optimization can not be done. If the code use the forward network and its gradient,
    // the forward network can not be eliminated. This may cause the decrease of the compilation and running efficiency.
    MS_LOG(DEBUG) << "The gradient operation of forward network and the forward network are not called in the same"
                  << " graph. The CNode to use the gradient result is: " << j_user->DebugString()
                  << " This CNode is in graph: " << graph->ToString();
    return nullptr;
  }

  // Check if there is only one primal call corresponding to the specified j user.
  // Bind by const reference: the previous by-value binding copied the whole
  // vector (one atomic shared_ptr ref-count bump per element) for no benefit.
  const auto &primal_users = iter->second;
  if (primal_users.size() != 1) {
    MS_LOG(WARNING) << "It is recommended to call the forward network only once.";
    MS_LOG(INFO) << "There is " << primal_users.size()
                 << " primal calls for same J operation in the same graph. Func graph: " << graph->ToString()
                 << ", J operation: " << j_user->DebugString() << ", Primal call: ";
    size_t count = 0;
    for (const auto &user : primal_users) {
      MS_LOG(INFO) << "[ " << ++count << " ] : " << user->DebugString(2) << trace::DumpSourceLines(user, false);
    }
    return nullptr;
  }

  // Check input size.
  auto primal_user = primal_users[0];
  if (primal_user->size() != j_user->size()) {
    MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal_user->DebugString() << " is "
                    << primal_user->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
    return nullptr;
  }
  return primal_user;
}
999 
// Collect {primal call -> J users} pairs for `primal_graph`.
// A "primal call" is a CNode that calls the primal graph directly (the graph
// is input 0); a "J user" is the CNode consuming the result of J applied to
// the primal graph. Pairs are only formed when GetPrimalUser accepts them.
static mindspore::HashMap<CNodePtr, std::vector<CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
                                                                           const FuncGraphPtr &primal_graph) {
  std::vector<CNodePtr> j_users;
  std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
  const auto &node_user_map = manager->node_users();
  // Search primal graph user cnodes.
  for (auto &entry : primal_graph->func_graph_cnodes_index()) {
    auto cnode = entry.first->first->cast<CNodePtr>();
    auto index = entry.first->second;
    if (index == 0) {
      // To find real calling: group direct calls by the graph that owns them.
      auto fg = cnode->func_graph();
      MS_EXCEPTION_IF_NULL(fg);
      const auto &iter = primal_map.find(fg);
      if (iter != primal_map.end()) {
        iter->second.push_back(cnode);
        continue;
      }
      primal_map[fg] = {cnode};
    } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
      // To find J user: the primal graph appears as an argument of J.
      j_users.emplace_back(GetJUser(node_user_map, cnode, index));
    }
  }

  // Pair every accepted J user with its unique primal call.
  mindspore::HashMap<CNodePtr, std::vector<CNodePtr>> primal_user_to_j_users;
  for (const auto &j_user : j_users) {
    MS_EXCEPTION_IF_NULL(j_user);
    auto primal = GetPrimalUser(j_user, primal_map);
    if (primal == nullptr) {
      continue;
    }
    MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
                  << " and J user is: " << j_user->DebugString();
    primal_user_to_j_users[primal].emplace_back(j_user);
  }
  return primal_user_to_j_users;
}
1038 
RemovePrimalUpdateStates(const FuncGraphManagerPtr & manager,const CNodePtr & primal_call)1039 static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
1040   auto &node_users = manager->node_users();
1041   auto iter = node_users.find(primal_call);
1042   if (iter == node_users.end()) {
1043     // Skip if user of primal_call not found.
1044     return;
1045   }
1046   // Find UpdateState nodes after the primal call.
1047   std::vector<CNodePtr> update_states;
1048   for (auto &user : iter->second) {
1049     auto &user_node = user.first;
1050     if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
1051       update_states.emplace_back(user_node->cast<CNodePtr>());
1052     }
1053   }
1054   // Remove UpdateStates by replace them with their monad input.
1055   for (auto &update_state : update_states) {
1056     auto &input_monad = update_state->inputs().at(1);
1057     (void)manager->Replace(update_state, input_monad);
1058   }
1059 }
1060 
CopyMonadArguments(const CNodePtr & primal_user,const CNodePtr & j_user,const FuncGraphManagerPtr & manager)1061 static bool CopyMonadArguments(const CNodePtr &primal_user, const CNodePtr &j_user,
1062                                const FuncGraphManagerPtr &manager) {
1063   auto &primal_inputs = primal_user->inputs();
1064   auto &j_user_inputs = j_user->inputs();
1065   bool has_monad = false;
1066   for (size_t i = 1; i < primal_inputs.size(); ++i) {
1067     auto &input = primal_inputs.at(i);
1068     if (HasAbstractMonad(input)) {
1069       // Copy monad input from primal to j_user.
1070       manager->SetEdge(j_user, i, input);
1071       has_monad = true;
1072     } else if (input != j_user_inputs.at(i)) {
1073       // Skip if there are different non-monad inputs.
1074       return false;
1075     }
1076   }
1077   return has_monad;
1078 }
1079 
//
// To replace the primal graph with k graph.
// Convert:
//   x = primal(args, u0)
//   u1 = update_state(u0, x)
//   ...
//   tuple = K(args, u1)
//   u2 = update_state(u1, tuple)
//   ...
// To:
//   tuple = K(args, u0)
//   x = get_item(tuple, 0)
//   ...
//   tuple = K(args, u0)
//   u2 = update_state(u0, tuple)
//   ...
//
void DFunctor::EliminatePrimalGraph() {
  // Find primal user and paired J user cnodes.
  auto manager = primal_graph_->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto primal_user_to_j_users = FindPrimalJPair(manager, primal_graph_);
  for (const auto &iter : primal_user_to_j_users) {
    auto primal_user = iter.first;
    auto &j_users = iter.second;
    MS_EXCEPTION_IF_NULL(primal_user);
    if (j_users.size() == 1) {
      // If both inputs are same except monads, we copy primal monad args to k graph
      // so that they can be combined in CSE (common subexpression elimination) pass.
      // Only do this when the size of j_users is 1 in order to keep the execution order.
      const bool has_monad = CopyMonadArguments(primal_user, j_users[0], manager);
      // Remove the UpdateState nodes after primal_user if need.
      if (has_monad) {
        RemovePrimalUpdateStates(manager, primal_user);
      }
    } else {
      MS_LOG(INFO) << "There are multiple j users with the same primal user " << primal_user->DebugString();
    }

    // Replace primal graph with k graph: the primal call now invokes K and its
    // (tuple) result replaces the primal result via a getitem below.
    auto k_vnode = NewValueNode(k_graph_);
    primal_user->set_input(0, k_vnode);
    if (j_users.empty()) {
      MS_LOG(INTERNAL_EXCEPTION) << "The J nodes for primal graph " << primal_graph_->ToString()
                                 << " should be used by at least one other node.";
    }
    primal_user->set_abstract(j_users[0]->abstract());
    // Insert tuple_getitem after primal user cnode: element 0 of K's result is
    // the forward value the original primal call produced.
    auto construct_wrapper = primal_user->func_graph();
    auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
    auto imm0 = std::make_shared<Int64Imm>(0);
    auto idx0 = NewValueNode(SizeToLong(0));
    idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
    auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0});
    getitem0->CloneCNodeInfo(primal_user);
    (void)manager->Replace(primal_user, getitem0);
  }
}
1138 }  // namespace ad
1139 }  // namespace mindspore
1140