/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/ad/dfunctor.h"

#include <algorithm>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir/anf.h"
#include "utils/info.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/ad/adjoint.h"
#include "frontend/operator/ops.h"
#include "utils/symbolic.h"
#include "utils/ms_context.h"
#include "pipeline/jit/action.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/pynative/pynative_execute.h"
#include "debug/anf_ir_dump.h"

namespace mindspore {
namespace ad {
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;

bool lift_fv_before_grad = true;

DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
    : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
  {
    TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
    k_graph_ = std::make_shared<FuncGraph>();
  }
  // To keep switch or switch_layer's inputs from being inlined
  k_graph_->set_switch_input(primal_graph->switch_input());
  k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
  k_graph_->set_stage(primal_graph->stage());

  {
    TraceGuard guard(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
    tape_ = std::make_shared<FuncGraph>();
  }
  tape_->set_stage(primal_graph->stage());

  dout_ = tape_->add_parameter();
}

void DFunctor::Init(bool is_top) {
  func_graph_to_functor_[primal_graph_] = shared_from_this();
  is_top_ = is_top;
}

void DFunctor::Finish() {
  CallDoutHoleOnTape();
  EliminatePrimalGraph();
}

void DFunctor::Clear() {
  func_graph_to_functor_.clear();
  anfnode_to_adjoin_definition_.clear();
}

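// Accumulate the gradient of a free variable: find (or create) its adjoint, fetch its sens from the env `din`
// via env_getitem with an embed key, and accumulate the result into the adjoint's dout.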
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  MS_EXCEPTION_IF_NULL(fv);
  if (lift_fv_before_grad) {
    MS_EXCEPTION_IF_NULL(fv->func_graph());
    MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv:"
                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  }
  auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  if (fv_adjoint == anfnode_to_adjoin_.end()) {
    MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
                  << " " << fv->ToString() << ".";

    if (fv->func_graph() == primal_graph_) {
      // If this fv is not mapped by MapMorphism because of cnode order, then map it now.
      (void)MapMorphism(fv);
      fv_adjoint = anfnode_to_adjoin_.find(fv);
      if (fv_adjoint == anfnode_to_adjoin_.end()) {
        MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
                          << fv->ToString() << ".";
      }
    } else {
      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
      if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
        MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
        auto parent_adjoint = FindAdjoint(fv);
        AdjointPtr adjoint = nullptr;
        if (parent_adjoint != nullptr) {
          adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
        } else {
          MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
                        << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
          adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
        }
        anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
        fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
      }
    }
  }
  auto fv_node = fv_adjoint->second->k();
  auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
  CNodePtr embed_node, default_val_node;
  if (cached_envitem_iter != anfnode_to_envitem_.end()) {
    embed_node = cached_envitem_iter->second.first;
    default_val_node = cached_envitem_iter->second.second;
  } else {
    embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node});
    default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node});
    fv_adjoint->second->RegisterKUser(embed_node, 1);
    fv_adjoint->second->RegisterKUser(default_val_node, 1);
    anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
  }
  auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node});
  MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
                << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
  fv_adjoint->second->AccumulateDout(dfv);
}

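// Backpropagate sens to the free variables of every candidate graph in a switch_layer's tuple of functions.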
void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
  // Take switch_layer as a set of candidate functions.
  constexpr size_t input_tuple_index = 2;
  auto input = cnode_morph->input(input_tuple_index);
  if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
    MS_LOG(EXCEPTION) << "The 2nd input of switch_layer expects a tuple of graphs, but got " << input->ToString()
                      << ".";
  }
  std::unordered_map<AnfNodePtr, FuncGraphPtr> node_to_fg;
  auto tuple_graphs = input->cast<CNodePtr>();
  for (size_t i = 1; i < tuple_graphs->size(); ++i) {
    auto graph = tuple_graphs->input(i);
    if (!IsValueNode<FuncGraph>(graph)) {
      MS_LOG(EXCEPTION) << "The 2nd input of switch_layer expects a tuple of graphs, but got " << graph->ToString()
                        << " as the " << i << "th element.";
    }
    auto func_graph = GetValueNode<FuncGraphPtr>(graph);
    auto functor = func_graph_to_functor_.find(func_graph);
    if (functor == func_graph_to_functor_.end()) {
      MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] "
                        << func_graph->ToString() << ".";
    }
    // Consider direct and indirect fvs.
    for (auto fv : func_graph->free_variables_nodes()) {
      if (node_to_fg.find(fv) != node_to_fg.end()) {
        continue;
      }
      node_to_fg[fv] = func_graph;
      BackPropagateFv(fv, env);
    }
    for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
      MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
                    << indirect_fv.first->ToString() << ".";
      if (node_to_fg.find(indirect_fv.first) != node_to_fg.end()) {
        continue;
      }
      node_to_fg[indirect_fv.first] = func_graph;
      BackPropagateFv(indirect_fv.first, env);
    }
  }
}

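// Return true if the cnode is a primitive call whose primitive carries the GRAPH_FLAG_SIDE_EFFECT_BACKPROP flag.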
static bool HasSideEffectBackProp(const CNodePtr &cnode) {
  if (IsPrimitiveCNode(cnode)) {
    const auto &prim = GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
    return bprop_flag;
  }
  return false;
}

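// Apply the bprop of cnode_morph on the tape and distribute the resulting sens to its inputs and free variables.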
void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
  auto bprop =
    k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(1))});
  // Call with delimited continuation dout.
  CNodePtr bprop_app;
  if (HasSideEffectBackProp(cnode_morph)) {
    // Since MapMorphism is called recursively, the order of bprop_app should be the reverse of the visiting order.
    bprop_app = tape_->NewCNodeInFront({bprop, node_adjoint->dout()});
    tape_->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
  } else {
    bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
  }
  node_adjoint->RegisterDoutUser(bprop_app, 1);
  // Special case for switch_layer
  if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
    auto din =
      tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(static_cast<int64_t>(0))});
    BackPropagateSwitchLayer(cnode_morph, din);
    return;
  }
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
    auto input = cnode_morph->input(i);
    // Skip HookBackward op
    if (IsPrimitiveCNode(input, prim::kPrimHookBackward)) {
      auto inp_i = input->cast<CNodePtr>();
      input = inp_i->input(1);
    }
    // Backprop sens wrt fvs.
    if (IsValueNode<FuncGraph>(input)) {
      auto func_graph = GetValueNode<FuncGraphPtr>(input);
      auto functor = func_graph_to_functor_.find(func_graph);
      if (functor == func_graph_to_functor_.end()) {
        MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] "
                          << func_graph->ToString() << ".";
      }
      // Consider direct and indirect fvs.
      for (auto fv : func_graph->free_variables_nodes()) {
        BackPropagateFv(fv, din);
      }
      for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
        MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " "
                      << indirect_fv.first->ToString() << ".";
        BackPropagateFv(indirect_fv.first, din);
      }
      continue;
    }
    // Backprop sens wrt inputs.
    auto input_adjoint = anfnode_to_adjoin_.find(input);
    if (input_adjoint == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << ".";
    }
    input_adjoint->second->AccumulateDout(din);
  }
}

// Map a morphism.
AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  MS_LOG(DEBUG) << "start MapMorphism:" << morph->DebugString(4);
  // All node types except CNode should already have been mapped by MapObject.
  if (!morph->isa<CNode>()) {
    return nullptr;
  }
  // For a free variable, which may already be handled in MapValueObject, just return its adjoint.
  auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
  if (node_adjoint_found != anfnode_to_adjoin_.end()) {
    return node_adjoint_found->second;
  }
  ScopeGuard scope_guard(morph->scope());
  auto cnode_morph = morph->cast<CNodePtr>();

  std::vector<AnfNodePtr> inputs;
  std::vector<AdjointPtr> param_adjoints;
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto node = cnode_morph->input(i);
    // Skip HookBackward op
    if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
      auto input_i = node->cast<CNodePtr>();
      MS_LOG(WARNING)
        << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
      node = input_i->input(1);
    }
    AdjointPtr node_adjoint = nullptr;
    auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
    if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
      node_adjoint = node_adjoint_iter->second;
    } else {
      // The input might be a CNode that needs to be handled first.
      node_adjoint = MapMorphism(node);
    }
    MS_EXCEPTION_IF_NULL(node_adjoint);
    AnfNodePtr k = node_adjoint->k();
    if (k == nullptr) {
      MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
    }
    inputs.push_back(k);
    param_adjoints.push_back(node_adjoint);
  }
  CNodePtr k_app = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
    k_app = k_graph_->NewCNode(inputs);
  }
  // Run in PyNative mode when @ms_function is used.
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    auto pynative_exec = pynative::PynativeExecutor::GetInstance();
    auto grad_exec = pynative_exec->grad_executor();
    if (grad_exec->eliminate_forward()) {
      PynativeDFunctor::ReplaceEquivdout(k_app, cnode_morph);
      cnode_morph->clear_inputs_value();
    }
  }

  for (size_t i = 0; i < param_adjoints.size(); ++i) {
    param_adjoints[i]->RegisterKUser(k_app, i);
  }
  // Do forward computation.
  auto forward_app =
    k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(0))});
  // K:: cnode -> forward_app
  auto node_adjoint = std::make_shared<Adjoint>(morph, forward_app, tape_);
  UpdateAdjoint(node_adjoint);
  anfnode_to_adjoin_[morph] = node_adjoint;
  if (cnode_morph->stop_gradient()) {
    MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
    return node_adjoint;
  }

  // Do sens backpropagation.
  BackPropagate(cnode_morph, k_app, node_adjoint);
  MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << ".";
  return node_adjoint;
}

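// A CNode is a free morphism if it is not a return node and none of its users belong to primal_graph_.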
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
  // Do not care about non-CNode
  if (!node->isa<CNode>()) {
    return false;
  }
  // Do not care about kPrimReturn
  if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
    return false;
  }
  MS_EXCEPTION_IF_NULL(primal_graph_->manager());
  auto &node_users = primal_graph_->manager()->node_users();
  auto iter = node_users.find(node);
  if (iter == node_users.end()) {
    return false;
  }
  auto &users = iter->second;
  // Do not care about isolated morphisms
  if (users.empty()) {
    return false;
  }
  // Not free if it's used by some node in primal_graph
  bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
    auto &user = kv.first;
    return user->func_graph() == primal_graph_;
  });
  return !nonfree;
}

void DFunctor::MapFreeMorphism() {
  // Handle cnodes that are not attached to the output but might be referred to in other functions.
  for (auto &node : primal_graph_->nodes()) {
    if (!IsFreeMorphism(node)) {
      continue;
    }
    MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
    (void)MapMorphism(node);
  }
}

AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
  AnfNodePtr new_grad_fv = grad_fv;
  // Add grads wrt fv.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  if (!is_top_ && free_variables_nodes.size() != 0) {
    if (lift_fv_before_grad) {
      MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString()
                        << ".";
    }
  }

  for (auto &fv : free_variables_nodes) {
    if (IsPrimitiveCNode(fv, prim::kPrimJ)) {  // Ignore if FV is a J CNode.
      continue;
    }
    auto fv_adjoint = anfnode_to_adjoin_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
    }
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
    fv_adjoint->second->RegisterKUser(node, 1);
    auto sens = fv_adjoint->second->dout();
    new_grad_fv = tape_->NewCNode({
      NewValueNode(prim::kPrimEnvSetItem),
      new_grad_fv,
      node,
      sens,
    });
    constexpr size_t sens_index = 3;
    fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), sens_index);
    MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " "
                  << fv->ToString() << " " << primal_graph_->ToString() << ".";
  }
  return new_grad_fv;
}

AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
  if (lift_fv_before_grad) {
    MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv "
                      << grad_fv->ToString() << " " << primal_graph_->ToString() << ".";
  }
  AnfNodePtr new_grad_fv = grad_fv;
  // Add indirect fv bprop.
  for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " "
                  << primal_graph_->ToString() << ".";
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()});
    fv_adjoint.second->RegisterKUser(node, 1);
    auto sens = fv_adjoint.second->dout();
    new_grad_fv = tape_->NewCNode({
      NewValueNode(prim::kPrimEnvSetItem),
      new_grad_fv,
      node,
      sens,
    });
    constexpr size_t sens_index = 3;
    fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), sens_index);
    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to "
                  << new_grad_fv->ToString() << ".";
  }
  return new_grad_fv;
}

void DFunctor::MapMorphism() {
  // Set stop_gradient before MapMorphism.
  BroadCastStopFlag();

  // Handle free morphisms before the output, because in some cases a free morphism might depend on the output's
  // fv tangent.
  MapFreeMorphism();
  // Skip HookBackward when it is the output node.
  auto output_node = primal_graph_->output();
  if (IsPrimitiveCNode(output_node, prim::kPrimHookBackward)) {
    auto output_cnode = output_node->cast<CNodePtr>();
    MS_LOG(WARNING)
      << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
    output_node = output_cnode->input(1);
  }
  // Handle morphism from output.
  (void)MapMorphism(output_node);

  // Construct K for primal_graph_.
  auto output_adjoint = anfnode_to_adjoin_.find(output_node);
  // Attach dout_ parameter to output_adjoint.
  output_adjoint->second->AccumulateDout(dout_);

  // Set output for tape closure.
  AnfNodePtr grad_fv;
  if (lift_fv_before_grad) {
    grad_fv = AttachFvDoutToTape(NewValueNode(newenv));
  } else {
    grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
  }

  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
  // Add grads wrt inputs.
  std::vector<AdjointPtr> param_adjoints;
  for (auto &param : primal_graph_->parameters()) {
    auto param_adjoint = anfnode_to_adjoin_.find(param);
    inputs.push_back(param_adjoint->second->dout());
    param_adjoints.push_back(param_adjoint->second);
  }
  auto tape_output = tape_->NewCNode(inputs);
  constexpr size_t offset_num = 2;
  for (size_t i = 0; i < param_adjoints.size(); ++i) {
    param_adjoints[i]->RegisterDoutUser(tape_output, i + offset_num);
  }
  tape_->set_output(tape_output);
  // Set output for k_graph_, K:: cnode->forward_app.
  auto forward_app = output_adjoint->second->k();
  auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
  output_adjoint->second->RegisterKUser(output, 1);
  k_graph_->set_output(output);
  (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_)));
  (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_)));
}

FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
  // K user defined cell bprop.
  auto bprop = primal->transforms().find("bprop");
  if (bprop != primal->transforms().end()) {
    FuncGraphPtr bprop_graph = bprop->second.func_graph();
    resources_->manager()->AddFuncGraph(bprop_graph);

    if (!bprop_graph->free_variables_nodes().empty() || !primal->free_variables_nodes().empty()) {
      MS_LOG(EXCEPTION) << "The Cell with user defined 'bprop' function in scope " << primal->output()->scope()->name()
                        << " does not support Parameter data type.\n"
                        << trace::GetDebugInfo(bprop_graph->debug_info());
    }
    bprop_graph->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true);
    bprop_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);

    auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph, primal);
    if (fg == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
                        << primal->output()->scope()->name() << ".";
    }

    // Cache the grad func
    (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
    (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
    // Reset defer_inline to enable successive inlining
    primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);

    auto functor = std::make_shared<DFunctor>(primal, resources_);
    functor->Init();
    functor->k_graph_ = fg;

    return fg;
  }
  return nullptr;
}

// Construct representation graph for {CNode, Index} of Primitive.
AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) {
  auto primal = primitive_user->input(index);
  if (!IsValueNode<Primitive>(primal)) {
    MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive.";
  }
  ScopeGuard scope_guard(primal->scope());
  // Map Primitive to K
  auto value_node = primal->cast<ValueNodePtr>();
  auto prim = GetValueNode<PrimitivePtr>(value_node);
  if ((prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) ||
      (prim->Hash() == prim::kPrimUpdateState->Hash() && prim->name() == prim::kPrimUpdateState->name())) {
    MS_LOG(DEBUG) << "Should stop gradient for " << prim->ToString();
    need_cut_ = true;
  }
  auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_);
  if (k_prim != nullptr) {
    return NewValueNode(k_prim);
  }
  // When k_prim is not found, fall back to k_meta.
  auto k_meta = g_k_prims.KMetaFuncGraph(prim);
  if (k_meta != nullptr) {
    return NewValueNode(k_meta);
  }
  MS_LOG(EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K.";
}

// Construct representation graph for ValueNode of FuncGraph.
AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) {
  if (!IsValueNode<FuncGraph>(primal)) {
    MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph.";
  }
  ScopeGuard scope_guard(primal->scope());
  // Map func graph to K
  auto func_graph = GetValueNode<FuncGraphPtr>(primal);
  auto f = func_graph_to_functor_.find(func_graph);
  if (f != func_graph_to_functor_.end()) {
    MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << ".";
    return NewValueNode(f->second->k_graph_);
  }
  auto k_user_defined = KUserDefined(func_graph);
  if (k_user_defined != nullptr) {
    MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << ".";
    return NewValueNode(k_user_defined);
  }
  auto functor = std::make_shared<DFunctor>(func_graph, resources_);
  functor->Init();
  functor->MapObject();
  functor->MapMorphism();

  MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\"";
  return NewValueNode(functor->k_graph_);
}

// Construct for ValueNode of Parameter.
AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) {
  if (!primal->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter.";
  }
  ScopeGuard scope_guard(primal->scope());
  // Map Parameter to K
  TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
  auto ret = k_graph_->add_parameter();
  return ret;
}

void DFunctor::MapFvObject() {
  // Map free variable.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  for (auto &node : free_variables_nodes) {
    ScopeGuard scope_guard(node->scope());
    MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << ".";
    // Find fv's K from parent.
    AdjointPtr adjoint = nullptr;
    auto parent_adjoint = FindAdjoint(node);
    if (parent_adjoint != nullptr) {
      adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
    } else {
      if (is_top_ || node->isa<Parameter>()) {
        // Out of ad scope, add adjoint for free variables.
        adjoint = std::make_shared<Adjoint>(node, node, tape_);
        UpdateAdjoint(adjoint);
      } else {
        MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << ".";
        adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
      }
    }
    if (adjoint == nullptr) {
      MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << ".";
    }
    anfnode_to_adjoin_[node] = adjoint;
  }
}

void DFunctor::MapParamObject() {
  // Map parameter.
  for (auto &p : primal_graph_->parameters()) {
    ScopeGuard scope_guard(p->scope());
    MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
    auto adjoint = std::make_shared<Adjoint>(p, MapParameterToK(p), tape_);
    UpdateAdjoint(adjoint);
    anfnode_to_adjoin_[p] = adjoint;
  }
}

void DFunctor::MapValueObject() {
  // Map ValueNode.
  auto manager = resources_->manager();
  auto &value_nodes = primal_graph_->value_nodes();
  for (const auto &value_pair : value_nodes) {
    auto node = value_pair.first;
    auto parent_adjoint = FindAdjoint(node);
    if (parent_adjoint != nullptr) {
      auto adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
      anfnode_to_adjoin_[node] = adjoint;
      continue;
    }

    AdjointPtr adjoint = nullptr;
    if (IsValueNode<Primitive>(node)) {  // Primitive.
      auto prim = GetValueNode<PrimitivePtr>(node);
      if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
          (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name())) {
        continue;
      }
      MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
      auto &users = manager->node_users()[node];
      if (users.size() == 0) {
        MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user.";
        continue;
      } else if (users.size() > 1) {
        MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size();
      }
      auto cnode = users.begin()->first->cast<CNodePtr>();  // We just use the first user.
      auto index = users.begin()->second;
      adjoint = std::make_shared<Adjoint>(node, MapPrimitiveToK(cnode, index), tape_);
    } else if (IsValueNode<FuncGraph>(node)) {  // FuncGraph
      MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << ".";
      adjoint = std::make_shared<Adjoint>(node, MapFuncGraphToK(node), tape_);
    } else if (node->isa<Parameter>()) {  // Parameter; hardly ever reached here.
      MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << ".";
      adjoint = std::make_shared<Adjoint>(node, MapParameterToK(node), tape_);
    } else {
      adjoint = std::make_shared<Adjoint>(node, node, tape_);
    }
    UpdateAdjoint(adjoint);
    anfnode_to_adjoin_[node] = adjoint;
  }
}

// Skip morphism.
void DFunctor::MapObject() {
  // The order does not matter
  MapFvObject();
  MapParamObject();
  MapValueObject();
}

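// Register the adjoint definition for its primal node and fill any existing k holes that refer to that primal.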
void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
  auto primal = adjoint_definition->primal();
  if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) {
    MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
                      << primal->ToString() << ".";
  }
  anfnode_to_adjoin_definition_[primal] = adjoint_definition;
  // Update k hole for primal.
  for (auto &f : func_graph_to_functor_) {
    auto adjoint = f.second->anfnode_to_adjoin_.find(primal);
    if (adjoint != f.second->anfnode_to_adjoin_.end()) {
      adjoint->second->UpdateK(adjoint_definition->k());
    }
    adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal);
    if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) {
      adjoint->second->UpdateK(adjoint_definition->k());
    }
  }
}

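// Look up the adjoint definition registered for the given primal node; return nullptr if not defined yet.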
AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
  auto adjoint = anfnode_to_adjoin_definition_.find(primal);
  if (adjoint != anfnode_to_adjoin_definition_.end()) {
    MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << ".";
    return adjoint->second;
  }
  MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << ".";
  return nullptr;
}

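// Only the top functor finalizes dout holes; it does so for every adjoint recorded by every registered functor.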
void DFunctor::CallDoutHoleOnTape() {
  if (!is_top_) {
    return;
  }

  // Call dout hole of all adjoint.
  for (auto &f : func_graph_to_functor_) {
    for (auto &adjoint : f.second->anfnode_to_adjoin_) {
      adjoint.second->CallDoutHole();
    }
    for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) {
      adjoint.second->CallDoutHole();
    }
  }
}

FuncGraphPtr DFunctor::k_graph() { return k_graph_; }

FuncGraphPtr DFunctor::tape() { return tape_; }

void DFunctor::BroadCastStopFlag() {
  // As the stop set expands, all directly or indirectly stopped CNodes will be cut off.
  while (need_cut_) {
    need_cut_ = false;
    for (auto &node : primal_graph_->nodes()) {
      if (node->isa<CNode>()) {
        auto cnode = node->cast<CNodePtr>();
        if (!cnode->stop_gradient()) {
          // Cut off the cnode only when it is not referred to any more.
          if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
              AllReferencesStopped(cnode)) {
            MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
            cnode->set_stop_gradient(true);
            // The stop set changed, so more cutting is required.
            need_cut_ = true;
          }
        }
      }
    }
  }
}

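// Return true if every user of the node is a CNode that is already marked as stop_gradient.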
bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
  auto &users = primal_graph_->manager()->node_users()[node];
  // Only care about cutting caused by stop_gradient.
  if (users.empty()) {
    return false;
  }
  for (auto &kv : users) {
    auto &user = kv.first;
    if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
      return false;
    }
  }
  return true;
}

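// Find the CNode that calls the result of the given J CNode; raise an error if there are multiple call users.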
CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int index) {
  auto it = node_user_map.find(cnode);
  if (it == node_user_map.end()) {
    MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}";
  }
  auto &j_users = it->second;
  auto size = j_users.size();
  if (size != 1) {
    bool has_multiple_j_call_user = false;
    CNodePtr j_call_user = nullptr;
    for (auto &user : j_users) {
      // If the J CNode is used as a free variable, j_users may contain more than one user. That is allowed.
      if (user.second == 0) {
        // Real J CNode call user.
        if (j_call_user == nullptr) {  // First user.
          j_call_user = user.first->cast<CNodePtr>();
        } else {  // More than one call user. Not allowed.
          has_multiple_j_call_user = true;
        }
      }
    }
    if (has_multiple_j_call_user) {  // Has multiple J CNode call users.
      std::ostringstream user_info;
      for (auto &user : j_users) {
        user_info << "    user: " << user.first->DebugString() << ", index: " << user.second << "\n";
      }
#ifdef ENABLE_DUMP_IR
      DumpIR("J_User_Ex_" + cnode->func_graph()->ToString() + ".ir", cnode->func_graph());
#endif
      MS_LOG(EXCEPTION) << "Incorrect J CNode user size: " << size << ", of {" << cnode->DebugString(2) << "/" << index
                        << "}\nUser Info:\n"
                        << user_info.str();
    } else {
      return j_call_user;
    }
  }
  return j_users.begin()->first->cast<CNodePtr>();
}

CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std::vector<CNodePtr>> &primal_map) {
  // Check if the J operation has a relevant primal call in the same graph.
  auto graph = j_user->func_graph();
  auto iter = primal_map.find(graph);
  if (iter == primal_map.end()) {
    MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString()
                    << ", J user: " << j_user->DebugString();
    return nullptr;
  }

  // Check if there is only one primal call corresponding to the specified J user.
  auto primal_users = iter->second;
  if (primal_users.size() != 1) {
    MS_LOG(WARNING) << "It is recommended to call the forward network only once.";
    MS_LOG(INFO) << "There is " << primal_users.size()
                 << " primal calls for same J operation in the same graph. Func graph: " << graph->ToString()
                 << ", J operation: " << j_user->DebugString() << ", Primal call: ";
    size_t count = 0;
    for (const auto &user : primal_users) {
      MS_LOG(INFO) << "[ " << ++count << " ] : " << user->DebugString(2) << ", trace: " << trace::DumpSourceLines(user);
    }
    return nullptr;
  }

  // Check input size.
  auto primal_user = primal_users[0];
  if (primal_user->size() != j_user->size()) {
    MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal_user->DebugString() << " is "
                    << primal_user->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
    return nullptr;
  }
  return primal_user;
}

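// Collect, for each primal call cnode, the J user cnodes that are paired with it in the same graph.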
static std::unordered_map<CNodePtr, std::vector<CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
                                                                           const FuncGraphPtr &primal_graph) {
  std::vector<CNodePtr> j_users;
  std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
  const auto &node_user_map = manager->node_users();
  // Search primal graph user cnodes.
  for (auto &entry : primal_graph->func_graph_cnodes_index()) {
    auto cnode = entry.first->first->cast<CNodePtr>();
    auto index = entry.first->second;
    if (index == 0) {
      // To find real calling.
      auto fg = cnode->func_graph();
      MS_EXCEPTION_IF_NULL(fg);
      auto iter = primal_map.find(fg);
      if (iter != primal_map.end()) {
        iter->second.push_back(cnode);
        continue;
      }
      primal_map[fg] = {cnode};
    } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
      // To find J user.
      j_users.emplace_back(GetJUser(node_user_map, cnode, index));
    }
  }

  std::unordered_map<CNodePtr, std::vector<CNodePtr>> primal_user_to_j_users;
  for (const auto &j_user : j_users) {
    MS_EXCEPTION_IF_NULL(j_user);
    auto primal = GetPrimalUser(j_user, primal_map);
    if (primal == nullptr) {
      continue;
    }
    MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
                  << " and J user is: " << j_user->DebugString();
    primal_user_to_j_users[primal].emplace_back(j_user);
  }
  return primal_user_to_j_users;
}

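// Detach UpdateState users from the primal call by replacing each of them with its input monad.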
static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
  auto &node_users = manager->node_users();
  auto iter = node_users.find(primal_call);
  if (iter == node_users.end()) {
    // Skip if no user of primal_call is found.
    return;
  }
  // Find UpdateState nodes after the primal call.
  std::vector<CNodePtr> update_states;
  for (auto &user : iter->second) {
    auto &user_node = user.first;
    if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
      update_states.emplace_back(user_node->cast<CNodePtr>());
    }
  }
  // Remove UpdateStates by replacing them with their monad input.
  for (auto &update_state : update_states) {
    auto &input_monad = update_state->inputs().at(1);
    manager->Replace(update_state, input_monad);
  }
}

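// Copy monad inputs from the primal call to the J user call; return true if any monad input was copied.
// Return false if the two calls differ in any non-monad input.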
static bool CopyMonadArguments(const CNodePtr &primal_user, const CNodePtr &j_user) {
  auto &primal_inputs = primal_user->inputs();
  auto &j_user_inputs = j_user->inputs();
  bool has_monad = false;
  for (size_t i = 1; i < primal_inputs.size(); ++i) {
    auto &input = primal_inputs.at(i);
    if (HasAbstractMonad(input)) {
      // Copy monad input from primal to j_user.
      j_user->set_input(i, input);
      has_monad = true;
    } else if (input != j_user_inputs.at(i)) {
      // Skip if there are different non-monad inputs.
      return false;
    }
  }
  return has_monad;
}

//
// To replace the primal graph with k graph.
// Convert:
//   x = primal(args, u0)
//   u1 = update_state(u0, x)
//   ...
//   tuple = K(args, u1)
//   u2 = update_state(u1, tuple)
//   ...
// To:
//   tuple = K(args, u0)
//   x = get_item(tuple, 0)
//   ...
//   tuple = K(args, u0)
//   u2 = update_state(u0, tuple)
//   ...
//
void DFunctor::EliminatePrimalGraph() {
  // Find primal user and paired J user cnodes.
  auto manager = primal_graph_->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto primal_user_to_j_users = FindPrimalJPair(manager, primal_graph_);
  for (const auto &iter : primal_user_to_j_users) {
    auto primal_user = iter.first;
    auto &j_users = iter.second;
    MS_EXCEPTION_IF_NULL(primal_user);
    if (j_users.size() == 1) {
      // If both inputs are the same except for monads, we copy the primal monad args to the k graph
      // so that they can be combined in the CSE (common subexpression elimination) pass.
      // Only do this when the size of j_users is 1 in order to keep the execution order.
      const bool has_monad = CopyMonadArguments(primal_user, j_users[0]);
      // Remove the UpdateState nodes after primal_user if needed.
      if (has_monad) {
        RemovePrimalUpdateStates(manager, primal_user);
      }
    } else {
      MS_LOG(INFO) << "There are multiple j users with the same primal user " << primal_user->DebugString();
    }

    // Replace primal graph with k graph.
    auto k_vnode = NewValueNode(k_graph_);
    primal_user->set_input(0, k_vnode);
    if (j_users.empty()) {
      MS_LOG(EXCEPTION) << "The J nodes for primal graph " << primal_graph_->ToString()
                        << " should be used by at least one other node.";
    }
    primal_user->set_abstract(j_users[0]->abstract());
    // Insert tuple_getitem after primal user cnode.
    auto construct_wrapper = primal_user->func_graph();
    auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
    auto imm0 = std::make_shared<Int64Imm>(0);
    auto idx0 = NewValueNode(SizeToLong(0));
    idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
    auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0});
    getitem0->CloneCNodeInfo(primal_user);
    (void)manager->Replace(primal_user, getitem0);
  }
}
}  // namespace ad
}  // namespace mindspore