1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/ad/dfunctor.h"
18
19 #include <map>
20 #include <memory>
21 #include <string>
22
23 #include "ir/anf.h"
24 #include "utils/info.h"
25 #include "ir/func_graph_cloner.h"
26 #include "ir/manager.h"
27 #include "pipeline/jit/resource.h"
28 #include "frontend/optimizer/ad/adjoint.h"
29 #include "frontend/operator/ops.h"
30 #include "utils/symbolic.h"
31 #include "utils/ms_context.h"
32 #include "pipeline/jit/action.h"
33 #include "pipeline/jit/parse/resolve.h"
34 #include "pipeline/pynative/pynative_execute.h"
35 #include "debug/anf_ir_dump.h"
36
37 namespace mindspore {
38 namespace ad {
// Shared cache: maps each primal FuncGraph to the DFunctor that transforms it.
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
// Global registry of adjoint definitions, keyed by primal node; used to patch
// "k holes" recorded by other functors (see UpdateAdjoint/FindAdjoint).
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;

// When true, free variables are lifted before the grad transform, so the
// fv back-propagation paths below are expected to be unreachable (they raise).
bool lift_fv_before_grad = true;
43
// Builds the two graphs of the K transformation for `primal_graph`:
// k_graph_ (forward computation returning (result, bprop)) and tape_ (the
// backward closure). dout_ is the tape's incoming sensitivity parameter.
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
    : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
  {
    // Scope the TraceGuard so the fprop trace info applies only to k_graph_ creation.
    TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
    k_graph_ = std::make_shared<FuncGraph>();
  }
  // To keep switch or switch_layer's inputs from being inlined
  k_graph_->set_switch_input(primal_graph->switch_input());
  k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
  k_graph_->set_stage(primal_graph->stage());

  {
    // Scope the TraceGuard so the bprop trace info applies only to tape_ creation.
    TraceGuard guard(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
    tape_ = std::make_shared<FuncGraph>();
  }
  tape_->set_stage(primal_graph->stage());

  // First tape parameter: the sensitivity (dout) flowing into the backward pass.
  dout_ = tape_->add_parameter();
}
63
Init(bool is_top)64 void DFunctor::Init(bool is_top) {
65 func_graph_to_functor_[primal_graph_] = shared_from_this();
66 is_top_ = is_top;
67 }
68
// Finalization pass: first resolve every pending dout hole on the tape, then
// eliminate primal-graph usages made redundant by the K transformation.
void DFunctor::Finish() {
  CallDoutHoleOnTape();
  EliminatePrimalGraph();
}
73
Clear()74 void DFunctor::Clear() {
75 func_graph_to_functor_.clear();
76 anfnode_to_adjoin_definition_.clear();
77 }
78
// Back-propagates the sensitivity `din` to free variable `fv` via the env
// mechanism: dfv = env_getitem(din, embed(k(fv)), zeros_like(k(fv))), which is
// then accumulated into the fv's dout.
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
  MS_EXCEPTION_IF_NULL(fv);
  // When free variables are lifted before grad, no fv should reach this path.
  if (lift_fv_before_grad) {
    MS_EXCEPTION_IF_NULL(fv->func_graph());
    MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv:"
                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  }
  auto fv_adjoint = anfnode_to_adjoin_.find(fv);
  if (fv_adjoint == anfnode_to_adjoin_.end()) {
    MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString()
                  << " " << fv->ToString() << ".";

    if (fv->func_graph() == primal_graph_) {
      // If this fv is not mapped by MapMorphism because of cnode order, then map it now.
      (void)MapMorphism(fv);
      fv_adjoint = anfnode_to_adjoin_.find(fv);
      if (fv_adjoint == anfnode_to_adjoin_.end()) {
        MS_LOG(EXCEPTION) << "Can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() << " "
                          << fv->ToString() << ".";
      }
    } else {
      // fv belongs to another graph: fall back to the indirect-fv table.
      fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
      if (fv_adjoint == anfnode_to_adjoin_indirect_fv_.end()) {
        MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_indirect_fv_ fv "
                      << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
        auto parent_adjoint = FindAdjoint(fv);
        AdjointPtr adjoint = nullptr;
        if (parent_adjoint != nullptr) {
          adjoint = std::make_shared<Adjoint>(fv, parent_adjoint->k(), tape_);
        } else {
          // No adjoint definition yet: record a "k hole" (k == nullptr) to be
          // patched later by UpdateAdjoint when the definition appears.
          MS_LOG(DEBUG) << "BackPropagateFv failed can not find adjoint definition fv, add a k hole "
                        << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
          adjoint = std::make_shared<Adjoint>(fv, nullptr, tape_);
        }
        anfnode_to_adjoin_indirect_fv_[fv] = adjoint;
        fv_adjoint = anfnode_to_adjoin_indirect_fv_.find(fv);
      }
    }
  }
  // Build (or reuse from cache) the embed/default-value pair keying this fv in the env.
  auto fv_node = fv_adjoint->second->k();
  auto cached_envitem_iter = anfnode_to_envitem_.find(fv_node);
  CNodePtr embed_node, default_val_node;
  if (cached_envitem_iter != anfnode_to_envitem_.end()) {
    embed_node = cached_envitem_iter->second.first;
    default_val_node = cached_envitem_iter->second.second;
  } else {
    embed_node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_node});
    default_val_node = tape_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), fv_node});
    fv_adjoint->second->RegisterKUser(embed_node, 1);
    fv_adjoint->second->RegisterKUser(default_val_node, 1);
    anfnode_to_envitem_[fv_node] = std::make_pair(embed_node, default_val_node);
  }
  auto dfv = tape_->NewCNode({NewValueNode(prim::kPrimEnvGetItem), din, embed_node, default_val_node});
  MS_LOG(DEBUG) << "BackPropagateFv find adjoint in anfnode_to_adjoin_ or anfnode_to_adjoin_indirect_fv_ fv "
                << fv->func_graph()->ToString() << " " << fv->ToString() << ".";
  MS_LOG(DEBUG) << "BackPropagateFv get item from " << din->ToString() << " key " << embed_node->ToString() << ".";
  fv_adjoint->second->AccumulateDout(dfv);
}
137
BackPropagateSwitchLayer(const CNodePtr & cnode_morph,const CNodePtr & env)138 void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
139 // Take switch_layer as a set of candidate functions.
140 constexpr size_t input_tuple_index = 2;
141 auto input = cnode_morph->input(input_tuple_index);
142 if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
143 MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
144 }
145 std::unordered_map<AnfNodePtr, FuncGraphPtr> node_to_fg;
146 auto tuple_graphs = input->cast<CNodePtr>();
147 for (size_t i = 1; i < tuple_graphs->size(); ++i) {
148 auto graph = tuple_graphs->input(i);
149 if (!IsValueNode<FuncGraph>(graph)) {
150 MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString()
151 << " as the " << i << "th element.";
152 }
153 auto func_graph = GetValueNode<FuncGraphPtr>(graph);
154 auto functor = func_graph_to_functor_.find(func_graph);
155 if (functor == func_graph_to_functor_.end()) {
156 MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] "
157 << func_graph->ToString() << ".";
158 }
159 // Consider direct and indirect fvs.
160 for (auto fv : func_graph->free_variables_nodes()) {
161 if (node_to_fg.find(fv) != node_to_fg.end()) {
162 continue;
163 }
164 node_to_fg[fv] = func_graph;
165 BackPropagateFv(fv, env);
166 }
167 for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
168 MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
169 << indirect_fv.first->ToString() << ".";
170 if (node_to_fg.find(indirect_fv.first) != node_to_fg.end()) {
171 continue;
172 }
173 node_to_fg[indirect_fv.first] = func_graph;
174 BackPropagateFv(indirect_fv.first, env);
175 }
176 }
177 }
178
HasSideEffectBackProp(const CNodePtr & cnode)179 static bool HasSideEffectBackProp(const CNodePtr &cnode) {
180 if (IsPrimitiveCNode(cnode)) {
181 const auto &prim = GetCNodePrimitive(cnode);
182 MS_EXCEPTION_IF_NULL(prim);
183 auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
184 return bprop_flag;
185 }
186 return false;
187 }
188
BackPropagate(const CNodePtr & cnode_morph,const CNodePtr & k_app,const AdjointPtr & node_adjoint)189 void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
190 auto bprop =
191 k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(1))});
192 // Call with delimited continuation dout.
193 CNodePtr bprop_app;
194 if (HasSideEffectBackProp(cnode_morph)) {
195 // as MapMorphism is called recursively, so the order of bprop_app should reversed as visited order.
196 bprop_app = tape_->NewCNodeInFront({bprop, node_adjoint->dout()});
197 tape_->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
198 } else {
199 bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
200 }
201 node_adjoint->RegisterDoutUser(bprop_app, 1);
202 // Special case for switch_layer
203 if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
204 auto din =
205 tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(static_cast<int64_t>(0))});
206 BackPropagateSwitchLayer(cnode_morph, din);
207 return;
208 }
209 for (size_t i = 0; i < cnode_morph->size(); i++) {
210 auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
211 auto input = cnode_morph->input(i);
212 // Skip HookBackward op
213 if (IsPrimitiveCNode(input, prim::kPrimHookBackward)) {
214 auto inp_i = input->cast<CNodePtr>();
215 input = inp_i->input(1);
216 }
217 // Backprop sens wrt fvs.
218 if (IsValueNode<FuncGraph>(input)) {
219 auto func_graph = GetValueNode<FuncGraphPtr>(input);
220 auto functor = func_graph_to_functor_.find(func_graph);
221 if (functor == func_graph_to_functor_.end()) {
222 MS_LOG(EXCEPTION) << "BackPropagate failed functor for subgraph does not exist input[" << i << "] "
223 << func_graph->ToString() << ".";
224 }
225 // Consider direct and indirect fvs.
226 for (auto fv : func_graph->free_variables_nodes()) {
227 BackPropagateFv(fv, din);
228 }
229 for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
230 MS_LOG(DEBUG) << "BackPropagate backprop indirect fv " << func_graph->ToString() << " "
231 << indirect_fv.first->ToString() << ".";
232 BackPropagateFv(indirect_fv.first, din);
233 }
234 continue;
235 }
236 // Backprop sens wrt inputs.
237 auto input_adjoint = anfnode_to_adjoin_.find(input);
238 if (input_adjoint == anfnode_to_adjoin_.end()) {
239 MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString() << ".";
240 }
241 input_adjoint->second->AccumulateDout(din);
242 }
243 }
244
245 // Map a morphism.
// Maps one CNode of the primal graph to its K counterpart, creating its
// adjoint and (unless stop_gradient) emitting its backward computation.
// Non-CNodes return nullptr — they are handled by MapObject beforehand.
AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  MS_LOG(DEBUG) << "start MapMorphism:" << morph->DebugString(4);
  // MapMorphism All type except CNode should already be mapped by MapObject.
  if (!morph->isa<CNode>()) {
    return nullptr;
  }
  // for free variable, which may be handled in MapValueObject, just return it
  auto node_adjoint_found = anfnode_to_adjoin_.find(morph);
  if (node_adjoint_found != anfnode_to_adjoin_.end()) {
    return node_adjoint_found->second;
  }
  ScopeGuard scope_guard(morph->scope());
  auto cnode_morph = morph->cast<CNodePtr>();

  // Recursively build the K counterpart of every input, collecting both the
  // transformed inputs and their adjoints.
  std::vector<AnfNodePtr> inputs;
  std::vector<AdjointPtr> param_adjoints;
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto node = cnode_morph->input(i);
    // Skip HookBackward op
    if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
      auto input_i = node->cast<CNodePtr>();
      MS_LOG(WARNING)
        << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
      node = input_i->input(1);
    }
    AdjointPtr node_adjoint = nullptr;
    auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
    if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
      node_adjoint = node_adjoint_iter->second;
    } else {
      // Input might be a CNode that needs to be handled previously.
      node_adjoint = MapMorphism(node);
    }
    MS_EXCEPTION_IF_NULL(node_adjoint);
    AnfNodePtr k = node_adjoint->k();
    if (k == nullptr) {
      MS_LOG(EXCEPTION) << "MapMorphism adjoint node does not exist, input[" << i << "] " << node->ToString() << ".";
    }
    inputs.push_back(k);
    param_adjoints.push_back(node_adjoint);
  }
  CNodePtr k_app = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
    k_app = k_graph_->NewCNode(inputs);
  }
  // Run in pynative mode, when @ms_function is used.
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    auto pynative_exec = pynative::PynativeExecutor::GetInstance();
    auto grad_exec = pynative_exec->grad_executor();
    if (grad_exec->eliminate_forward()) {
      // Substitute the recorded forward value into k_app, then drop the cache.
      PynativeDFunctor::ReplaceEquivdout(k_app, cnode_morph);
      cnode_morph->clear_inputs_value();
    }
  }

  for (size_t i = 0; i < param_adjoints.size(); ++i) {
    param_adjoints[i]->RegisterKUser(k_app, i);
  }
  // Do forward computation: element 0 of the K pair is the forward result.
  auto foward_app =
    k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(0))});
  // K:: cnode -> forward_app
  auto node_adjoint = std::make_shared<Adjoint>(morph, foward_app, tape_);
  UpdateAdjoint(node_adjoint);
  anfnode_to_adjoin_[morph] = node_adjoint;
  // Nodes marked stop_gradient contribute no sensitivity: skip backprop.
  if (cnode_morph->stop_gradient()) {
    MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
    return node_adjoint;
  }

  // Do sens backpropagation
  BackPropagate(cnode_morph, k_app, node_adjoint);
  MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << ".";
  return node_adjoint;
}
322
IsFreeMorphism(const AnfNodePtr & node)323 bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
324 // Do not care about non-CNode
325 if (!node->isa<CNode>()) {
326 return false;
327 }
328 // Do not care about kPrimReturn
329 if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
330 return false;
331 }
332 MS_EXCEPTION_IF_NULL(primal_graph_->manager());
333 auto &node_users = primal_graph_->manager()->node_users();
334 auto iter = node_users.find(node);
335 if (iter == node_users.end()) {
336 return false;
337 }
338 auto &users = iter->second;
339 // Do not care about isolated morphisms
340 if (users.empty()) {
341 return false;
342 }
343 // Not free if it's used by some node in primal_graph
344 bool nonfree = std::any_of(std::begin(users), std::end(users), [&](const auto &kv) {
345 auto &user = kv.first;
346 return user->func_graph() == primal_graph_;
347 });
348 return !nonfree;
349 }
350
MapFreeMorphism()351 void DFunctor::MapFreeMorphism() {
352 // Handle cnode not attached to output, that might be referred in other functions.
353 for (auto &node : primal_graph_->nodes()) {
354 if (!IsFreeMorphism(node)) {
355 continue;
356 }
357 MS_LOG(DEBUG) << "MapFreeMorphism map nonoutput cnode after MapMorphism " << node->ToString() << ".";
358 (void)MapMorphism(node);
359 }
360 }
361
// Folds the dout of every direct free variable into the env value `grad_fv`
// on the tape via env_setitem(grad_fv, embed(k(fv)), dout(fv)); returns the
// extended env chain.
AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) {
  AnfNodePtr new_grad_fv = grad_fv;
  // Add grads wrt fv.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  // With fv lifting enabled, a non-top graph must have no direct fvs left.
  if (!is_top_ && free_variables_nodes.size() != 0) {
    if (lift_fv_before_grad) {
      MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString()
                        << ".";
    }
  }

  for (auto &fv : free_variables_nodes) {
    if (IsPrimitiveCNode(fv, prim::kPrimJ)) {  // Ignore if FV is a J CNode.
      continue;
    }
    auto fv_adjoint = anfnode_to_adjoin_.find(fv);
    if (fv_adjoint == anfnode_to_adjoin_.end()) {
      MS_LOG(EXCEPTION) << "AttachFvDoutToTape fv adjoint does not exist " << fv->ToString() << ".";
    }
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint->second->k()});
    fv_adjoint->second->RegisterKUser(node, 1);
    auto sens = fv_adjoint->second->dout();
    // new_grad_fv = env_setitem(new_grad_fv, embed(k(fv)), sens)
    new_grad_fv = tape_->NewCNode({
      NewValueNode(prim::kPrimEnvSetItem),
      new_grad_fv,
      node,
      sens,
    });
    constexpr size_t sens_index = 3;
    fv_adjoint->second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), sens_index);
    MS_LOG(DEBUG) << "AttachFvDoutToTape add fv sens " << sens->ToString() << " to " << new_grad_fv->ToString() << " "
                  << fv->ToString() << " " << primal_graph_->ToString() << ".";
  }
  return new_grad_fv;
}
397
// Folds the dout of every indirect free variable (fvs reached through nested
// graphs) into the env value `grad_fv` on the tape; returns the extended chain.
// Unreachable when fvs are lifted before grad.
AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) {
  if (lift_fv_before_grad) {
    MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv "
                      << grad_fv->ToString() << " " << primal_graph_->ToString() << ".";
  }
  AnfNodePtr new_grad_fv = grad_fv;
  // Add indirect fv bprop.
  for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) {
    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape backprop indirect fv " << fv_adjoint.first->ToString() << " "
                  << primal_graph_->ToString() << ".";
    auto node = tape_->NewCNode({NewValueNode(prim::kPrimEmbed), fv_adjoint.second->k()});
    fv_adjoint.second->RegisterKUser(node, 1);
    auto sens = fv_adjoint.second->dout();
    // new_grad_fv = env_setitem(new_grad_fv, embed(k(fv)), sens)
    new_grad_fv = tape_->NewCNode({
      NewValueNode(prim::kPrimEnvSetItem),
      new_grad_fv,
      node,
      sens,
    });
    constexpr size_t sens_index = 3;
    fv_adjoint.second->RegisterDoutUser(new_grad_fv->cast<CNodePtr>(), sens_index);
    MS_LOG(DEBUG) << "AttachIndirectFvDoutToTape add indirect fv sens " << sens->ToString() << " to "
                  << new_grad_fv->ToString() << ".";
  }
  return new_grad_fv;
}
424
MapMorphism()425 void DFunctor::MapMorphism() {
426 // Set stop_gradient before MapMorphism.
427 BroadCastStopFlag();
428
429 // Handle free morphism before output, because in some case, free morphism might depend on output's fv tangent
430 MapFreeMorphism();
431 // Skip HookBackward when it is the output node.
432 auto output_node = primal_graph_->output();
433 if (IsPrimitiveCNode(output_node, prim::kPrimHookBackward)) {
434 auto output_cnode = output_node->cast<CNodePtr>();
435 MS_LOG(WARNING)
436 << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
437 output_node = output_cnode->input(1);
438 }
439 // Handle morphism from output.
440 (void)MapMorphism(output_node);
441
442 // Construct K for primal_graph_.
443 auto output_adjoint = anfnode_to_adjoin_.find(output_node);
444 // Attach dout_ parameter to output_adjoint.
445 output_adjoint->second->AccumulateDout(dout_);
446
447 // Set output for tape closure.
448 AnfNodePtr grad_fv;
449 if (lift_fv_before_grad) {
450 grad_fv = AttachFvDoutToTape(NewValueNode(newenv));
451 } else {
452 grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv)));
453 }
454
455 std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv};
456 // Add grads wrt inputs.
457 std::vector<AdjointPtr> param_adjoints;
458 for (auto ¶m : primal_graph_->parameters()) {
459 auto param_adjoint = anfnode_to_adjoin_.find(param);
460 inputs.push_back(param_adjoint->second->dout());
461 param_adjoints.push_back(param_adjoint->second);
462 }
463 auto tape_output = tape_->NewCNode(inputs);
464 constexpr size_t offset_num = 2;
465 for (size_t i = 0; i < param_adjoints.size(); ++i) {
466 param_adjoints[i]->RegisterDoutUser(tape_output, i + offset_num);
467 }
468 tape_->set_output(tape_output);
469 // Set output for k_graph_, K:: cnode->forward_app.
470 auto forward_app = output_adjoint->second->k();
471 auto output = k_graph_->NewCNode({NewValueNode(prim::kPrimMakeTuple), forward_app, NewValueNode(tape_)});
472 output_adjoint->second->RegisterKUser(output, 1);
473 k_graph_->set_output(output);
474 (void)primal_graph_->transforms().insert(std::make_pair("grad", FuncGraphTransform(k_graph_)));
475 (void)k_graph_->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal_graph_)));
476 }
477
// Returns the K graph expanded from a user-defined Cell bprop if `primal`
// carries a "bprop" transform; otherwise returns nullptr so the caller falls
// back to the regular K transformation.
FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) {
  // K user defined cell bprop.
  auto bprop = primal->transforms().find("bprop");
  if (bprop != primal->transforms().end()) {
    FuncGraphPtr bprop_graph = bprop->second.func_graph();
    resources_->manager()->AddFuncGraph(bprop_graph);

    // Custom-bprop cells may not capture free variables (e.g. Parameters).
    if (!bprop_graph->free_variables_nodes().empty() || !primal->free_variables_nodes().empty()) {
      MS_LOG(EXCEPTION) << "The Cell with user defined 'bprop' function in scope " << primal->output()->scope()->name()
                        << " does not support Parameter data type.\n"
                        << trace::GetDebugInfo(bprop_graph->debug_info());
    }
    bprop_graph->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true);
    bprop_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);

    auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph, primal);
    if (fg == nullptr) {
      MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope "
                        << primal->output()->scope()->name() << ".";
    }

    // Cache the grad func
    (void)primal->transforms().insert(std::make_pair("grad", FuncGraphTransform(fg)));
    (void)fg->transforms().insert(std::make_pair("primal", FuncGraphTransform(primal)));
    // Reset defer_inline to enable successive inlining
    primal->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);

    // Register a functor so other graphs referring to `primal` reuse this K graph.
    auto functor = std::make_shared<DFunctor>(primal, resources_);
    functor->Init();
    functor->k_graph_ = fg;

    return fg;
  }
  return nullptr;
}
513
514 // Construct representation graph for {CNode, Index} of Primitive.
MapPrimitiveToK(const CNodePtr & primitive_user,size_t index)515 AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t index) {
516 auto primal = primitive_user->input(index);
517 if (!IsValueNode<Primitive>(primal)) {
518 MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Primitive.";
519 }
520 ScopeGuard scope_guard(primal->scope());
521 // Map Primitive to K
522 auto value_node = primal->cast<ValueNodePtr>();
523 auto prim = GetValueNode<PrimitivePtr>(value_node);
524 if ((prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) ||
525 (prim->Hash() == prim::kPrimUpdateState->Hash() && prim->name() == prim::kPrimUpdateState->name())) {
526 MS_LOG(DEBUG) << "Should stop gradient for " << prim->ToString();
527 need_cut_ = true;
528 }
529 auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_);
530 if (k_prim != nullptr) {
531 return NewValueNode(k_prim);
532 }
533 // When failed to find k_prim, try k_meta.
534 auto k_meta = g_k_prims.KMetaFuncGraph(prim);
535 if (k_meta != nullptr) {
536 return NewValueNode(k_meta);
537 }
538 MS_LOG(EXCEPTION) << "Fail to map Primitive of \"" << primal->ToString() << "\" to K.";
539 }
540
541 // Construct representation graph for ValueNode of FuncGraph.
AnfNodePtr DFunctor::MapFuncGraphToK(const AnfNodePtr &primal) {
  if (!IsValueNode<FuncGraph>(primal)) {
    MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of FuncGraph.";
  }
  ScopeGuard scope_guard(primal->scope());
  // Map func graph to K
  auto func_graph = GetValueNode<FuncGraphPtr>(primal);
  // Reuse an existing functor's K graph if this graph was already transformed.
  auto f = func_graph_to_functor_.find(func_graph);
  if (f != func_graph_to_functor_.end()) {
    MS_LOG(DEBUG) << "K graph functor already exist " << func_graph->ToString() << ".";
    return NewValueNode(f->second->k_graph_);
  }
  // A user-defined Cell bprop takes precedence over the generic transformation.
  auto k_user_defined = KUserDefined(func_graph);
  if (k_user_defined != nullptr) {
    MS_LOG(DEBUG) << "K graph functor user defined bprop " << func_graph->ToString() << ".";
    return NewValueNode(k_user_defined);
  }
  // Otherwise run the full K transformation on the subgraph recursively.
  auto functor = std::make_shared<DFunctor>(func_graph, resources_);
  functor->Init();
  functor->MapObject();
  functor->MapMorphism();

  MS_LOG(DEBUG) << "Map \"" << func_graph->ToString() << "\" to \"" << functor->k_graph_->ToString() << "\"";
  return NewValueNode(functor->k_graph_);
}
567
568 // Construct for ValueNode of Parameter.
AnfNodePtr DFunctor::MapParameterToK(const AnfNodePtr &primal) {
  // Only Parameter nodes may be mapped here.
  if (!primal->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Primal graph \"" << primal->ToString() << "\" is not a ValueNode of Parameter.";
  }
  ScopeGuard scope_guard(primal->scope());
  // Map Parameter to K: add a fresh parameter on k_graph_, traced to the primal.
  TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal->debug_info()));
  return k_graph_->add_parameter();
}
579
// Creates an adjoint for every free variable of the primal graph, reusing the
// parent's K when available and falling back to a k hole otherwise.
void DFunctor::MapFvObject() {
  // Map free variable.
  const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
  for (auto &node : free_variables_nodes) {
    ScopeGuard scope_guard(node->scope());
    MS_LOG(DEBUG) << "MapFvObject free variable " << node->ToString() << ".";
    // Find fv's K from parent.
    AdjointPtr adjoint = nullptr;
    auto parent_adjoint = FindAdjoint(node);
    if (parent_adjoint != nullptr) {
      adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
    } else {
      if (is_top_ || node->isa<Parameter>()) {
        // Out of ad scope, add adjoint for free variables.
        adjoint = std::make_shared<Adjoint>(node, node, tape_);
        UpdateAdjoint(adjoint);
      } else {
        // No definition yet: record a k hole to be patched by UpdateAdjoint.
        MS_LOG(DEBUG) << "MapFvObject fail to find parent adjoint for nontop fv " << node->ToString() << ".";
        adjoint = std::make_shared<Adjoint>(node, nullptr, tape_);
      }
    }
    if (adjoint == nullptr) {
      MS_LOG(EXCEPTION) << "MapFvObject failed for free variable " << node->ToString() << ".";
    }
    anfnode_to_adjoin_[node] = adjoint;
  }
}
607
MapParamObject()608 void DFunctor::MapParamObject() {
609 // Map parameter.
610 for (auto &p : primal_graph_->parameters()) {
611 ScopeGuard scope_guard(p->scope());
612 MS_LOG(DEBUG) << "MapParamObject parameter " << p->ToString() << ".";
613 auto adjoint = std::make_shared<Adjoint>(p, MapParameterToK(p), tape_);
614 UpdateAdjoint(adjoint);
615 anfnode_to_adjoin_[p] = adjoint;
616 }
617 }
618
// Creates an adjoint for every ValueNode of the primal graph, dispatching on
// the node kind: Primitive -> K primitive/meta graph, FuncGraph -> K graph,
// Parameter -> new k_graph_ parameter, anything else maps to itself.
void DFunctor::MapValueObject() {
  // Map ValueNode.
  auto manager = resources_->manager();
  auto &value_nodes = primal_graph_->value_nodes();
  for (const auto &value_pair : value_nodes) {
    auto node = value_pair.first;
    // Reuse the parent's K if an adjoint definition already exists.
    auto parent_adjoint = FindAdjoint(node);
    if (parent_adjoint != nullptr) {
      auto adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
      anfnode_to_adjoin_[node] = adjoint;
      continue;
    }

    AdjointPtr adjoint = nullptr;
    if (IsValueNode<Primitive>(node)) {  // Primitive.
      auto prim = GetValueNode<PrimitivePtr>(node);
      // Return and HookBackward have no K mapping here.
      if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
          (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name())) {
        continue;
      }
      MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
      auto &users = manager->node_users()[node];
      if (users.size() == 0) {
        MS_LOG(ERROR) << "\"" << node->DebugString() << "\" has no user.";
        continue;
      } else if (users.size() > 1) {
        MS_LOG(DEBUG) << "\"" << node->DebugString() << "\" supposed to be used once, but users size: " << users.size();
      }
      auto cnode = users.begin()->first->cast<CNodePtr>();  // We just use the first user.
      auto index = users.begin()->second;
      adjoint = std::make_shared<Adjoint>(node, MapPrimitiveToK(cnode, index), tape_);
    } else if (IsValueNode<FuncGraph>(node)) {  // FuncGraph
      MS_LOG(DEBUG) << "Map FuncGraph node " << node->DebugString() << ".";
      adjoint = std::make_shared<Adjoint>(node, MapFuncGraphToK(node), tape_);
    } else if (node->isa<Parameter>()) {  // Parameter, hardly reach here.
      MS_LOG(DEBUG) << "Map Parameter node " << node->DebugString() << ".";
      adjoint = std::make_shared<Adjoint>(node, MapParameterToK(node), tape_);
    } else {
      // Plain values (tensors, scalars, ...) are their own K.
      adjoint = std::make_shared<Adjoint>(node, node, tape_);
    }
    UpdateAdjoint(adjoint);
    anfnode_to_adjoin_[node] = adjoint;
  }
}
663
664 // Skip morphism.
// Maps all non-CNode objects (free variables, parameters, value nodes) before
// morphisms are handled. Skip morphism.
void DFunctor::MapObject() {
  // The order does not matter
  MapFvObject();
  MapParamObject();
  MapValueObject();
}
671
// Registers `adjoint_definition` as the global adjoint for its primal node and
// patches the k hole of every functor that recorded one for that primal.
void DFunctor::UpdateAdjoint(const AdjointPtr &adjoint_definition) {
  auto primal = adjoint_definition->primal();
  // Each primal may only ever get one adjoint definition.
  if (anfnode_to_adjoin_definition_.find(primal) != anfnode_to_adjoin_definition_.end()) {
    MS_LOG(EXCEPTION) << "UpdateAdjoint adjoint definition already exists " << primal_graph_->ToString() << " "
                      << primal->ToString() << ".";
  }
  anfnode_to_adjoin_definition_[primal] = adjoint_definition;
  // Update k hole for primal.
  for (auto &f : func_graph_to_functor_) {
    auto adjoint = f.second->anfnode_to_adjoin_.find(primal);
    if (adjoint != f.second->anfnode_to_adjoin_.end()) {
      adjoint->second->UpdateK(adjoint_definition->k());
    }
    adjoint = f.second->anfnode_to_adjoin_indirect_fv_.find(primal);
    if (adjoint != f.second->anfnode_to_adjoin_indirect_fv_.end()) {
      adjoint->second->UpdateK(adjoint_definition->k());
    }
  }
}
691
FindAdjoint(const AnfNodePtr & primal)692 AdjointPtr DFunctor::FindAdjoint(const AnfNodePtr &primal) {
693 auto adjoint = anfnode_to_adjoin_definition_.find(primal);
694 if (adjoint != anfnode_to_adjoin_definition_.end()) {
695 MS_LOG(DEBUG) << "FindAdjoint found adjoint definition for free variable " << primal->ToString() << ".";
696 return adjoint->second;
697 }
698 MS_LOG(DEBUG) << "FindAdjoint adjoint definition for free variable not defined yet " << primal->ToString() << ".";
699 return nullptr;
700 }
701
CallDoutHoleOnTape()702 void DFunctor::CallDoutHoleOnTape() {
703 if (!is_top_) {
704 return;
705 }
706
707 // Call dout hole of all adjoint.
708 for (auto &f : func_graph_to_functor_) {
709 for (auto &adjoint : f.second->anfnode_to_adjoin_) {
710 adjoint.second->CallDoutHole();
711 }
712 for (auto &adjoint : f.second->anfnode_to_adjoin_indirect_fv_) {
713 adjoint.second->CallDoutHole();
714 }
715 }
716 }
717
// Accessor: the K (forward + bprop pair) graph built for the primal graph.
FuncGraphPtr DFunctor::k_graph() { return k_graph_; }
719
// Accessor: the tape (backward closure) graph built for the primal graph.
FuncGraphPtr DFunctor::tape() { return tape_; }
721
// Propagates stop_gradient over the primal graph to a fixed point: a cnode is
// stopped when it is StopGradient/UpdateState or when all of its users are
// already stopped. need_cut_ is set by MapPrimitiveToK when such ops appear.
void DFunctor::BroadCastStopFlag() {
  // As stop set expanding, all directly or indirectly stopped CNode will be cut off
  while (need_cut_) {
    need_cut_ = false;
    for (auto &node : primal_graph_->nodes()) {
      if (node->isa<CNode>()) {
        auto cnode = node->cast<CNodePtr>();
        if (!cnode->stop_gradient()) {
          // Cut off the cnode only when it's not referred any more
          if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
              AllReferencesStopped(cnode)) {
            MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << ".";
            cnode->set_stop_gradient(true);
            // The stop set changed, more cut required
            need_cut_ = true;
          }
        }
      }
    }
  }
}
743
AllReferencesStopped(const CNodePtr & node)744 bool DFunctor::AllReferencesStopped(const CNodePtr &node) {
745 auto &users = primal_graph_->manager()->node_users()[node];
746 // Only care about stop_gradient caused cutting
747 if (users.empty()) {
748 return false;
749 }
750 for (auto &kv : users) {
751 auto &user = kv.first;
752 if (!user->isa<CNode>() || !user->cast<CNodePtr>()->stop_gradient()) {
753 return false;
754 }
755 }
756 return true;
757 }
758
// Resolves the single CNode that *calls* the J node `cnode` (user index 0).
// Extra users are tolerated when the J node is also captured as a free
// variable; more than one call user is an error.
CNodePtr GetJUser(const NodeUsersMap &node_user_map, const CNodePtr &cnode, int index) {
  auto it = node_user_map.find(cnode);
  if (it == node_user_map.end()) {
    MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}";
  }
  auto &j_users = it->second;
  auto size = j_users.size();
  if (size != 1) {
    bool has_multiple_j_call_user = false;
    CNodePtr j_call_user = nullptr;
    for (auto &user : j_users) {
      // If J CNode is used as a FV, the j_users.size may exceed 1 user. It is allowed.
      if (user.second == 0) {
        // Real J CNode call user.
        if (j_call_user == nullptr) {  // First user.
          j_call_user = user.first->cast<CNodePtr>();
        } else {  // More than 1 call user. Not allowed.
          has_multiple_j_call_user = true;
        }
      }
    }
    if (has_multiple_j_call_user) {  // Has multiple J CNode call user.
      // Collect all users for the diagnostic before raising.
      std::ostringstream user_info;
      for (auto &user : j_users) {
        user_info << " user: " << user.first->DebugString() << ", index: " << user.second << "\n";
      }
#ifdef ENABLE_DUMP_IR
      DumpIR("J_User_Ex_" + cnode->func_graph()->ToString() + ".ir", cnode->func_graph());
#endif
      MS_LOG(EXCEPTION) << "Incorrect J CNode user size: " << size << ", of {" << cnode->DebugString(2) << "/" << index
                        << "}\nUser Info:\n"
                        << user_info.str();
    } else {
      // NOTE(review): if no user has index 0, this returns nullptr — callers
      // appear to rely on upstream guarantees; confirm before changing.
      return j_call_user;
    }
  }
  // Exactly one user: it is the call user.
  return j_users.begin()->first->cast<CNodePtr>();
}
797
// Finds the primal call paired with the given J user `j_user`.
// A valid pair requires: a primal call in the same func graph, exactly one
// such call, and matching input sizes. Returns nullptr (after logging) when
// any of these checks fails.
CNodePtr GetPrimalUser(const CNodePtr &j_user, const std::map<FuncGraphPtr, std::vector<CNodePtr>> &primal_map) {
  // Check if J operation has relevant primal call in the same graph.
  auto graph = j_user->func_graph();
  auto iter = primal_map.find(graph);
  if (iter == primal_map.end()) {
    MS_LOG(WARNING) << "J operation has no relevant primal call in the same graph. Func graph: " << graph->ToString()
                    << ", J user: " << j_user->DebugString();
    return nullptr;
  }

  // Check if there is only one primal call corresponding to the specified j user.
  // Bind by const reference; copying the vector here would be wasteful.
  const auto &primal_users = iter->second;
  if (primal_users.size() != 1) {
    MS_LOG(WARNING) << "It is recommended to call the forward network only once.";
    MS_LOG(INFO) << "There is " << primal_users.size()
                 << " primal calls for same J operation in the same graph. Func graph: " << graph->ToString()
                 << ", J operation: " << j_user->DebugString() << ", Primal call: ";
    size_t count = 0;
    for (const auto &user : primal_users) {
      MS_LOG(INFO) << "[ " << ++count << " ] : " << user->DebugString(2) << ", trace: " << trace::DumpSourceLines(user);
    }
    return nullptr;
  }

  // Check input size.
  auto primal_user = primal_users[0];
  if (primal_user->size() != j_user->size()) {
    MS_LOG(WARNING) << "Input size incorrect, the input size of primal " << primal_user->DebugString() << " is "
                    << primal_user->size() << ", and J user " << j_user->DebugString() << " is " << j_user->size();
    return nullptr;
  }
  return primal_user;
}
831
FindPrimalJPair(const FuncGraphManagerPtr & manager,const FuncGraphPtr & primal_graph)832 static std::unordered_map<CNodePtr, std::vector<CNodePtr>> FindPrimalJPair(const FuncGraphManagerPtr &manager,
833 const FuncGraphPtr &primal_graph) {
834 std::vector<CNodePtr> j_users;
835 std::map<FuncGraphPtr, std::vector<CNodePtr>> primal_map;
836 const auto &node_user_map = manager->node_users();
837 // Search primal graph user cnodes.
838 for (auto &entry : primal_graph->func_graph_cnodes_index()) {
839 auto cnode = entry.first->first->cast<CNodePtr>();
840 auto index = entry.first->second;
841 if (index == 0) {
842 // To find real calling.
843 auto fg = cnode->func_graph();
844 MS_EXCEPTION_IF_NULL(fg);
845 auto iter = primal_map.find(fg);
846 if (iter != primal_map.end()) {
847 iter->second.push_back(cnode);
848 continue;
849 }
850 primal_map[fg] = {cnode};
851 } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) {
852 // To find J user.
853 j_users.emplace_back(GetJUser(node_user_map, cnode, index));
854 }
855 }
856
857 std::unordered_map<CNodePtr, std::vector<CNodePtr>> primal_user_to_j_users;
858 for (const auto &j_user : j_users) {
859 MS_EXCEPTION_IF_NULL(j_user);
860 auto primal = GetPrimalUser(j_user, primal_map);
861 if (primal == nullptr) {
862 continue;
863 }
864 MS_LOG(DEBUG) << "Primal_J pair is found, where primal is: " << primal->DebugString()
865 << " and J user is: " << j_user->DebugString();
866 primal_user_to_j_users[primal].emplace_back(j_user);
867 }
868 return primal_user_to_j_users;
869 }
870
RemovePrimalUpdateStates(const FuncGraphManagerPtr & manager,const CNodePtr & primal_call)871 static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
872 auto &node_users = manager->node_users();
873 auto iter = node_users.find(primal_call);
874 if (iter == node_users.end()) {
875 // Skip if user of primal_call not found.
876 return;
877 }
878 // Find UpdateState nodes after the primal call.
879 std::vector<CNodePtr> update_states;
880 for (auto &user : iter->second) {
881 auto &user_node = user.first;
882 if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
883 update_states.emplace_back(user_node->cast<CNodePtr>());
884 }
885 }
886 // Remove UpdateStates by replace them with their monad input.
887 for (auto &update_state : update_states) {
888 auto &input_monad = update_state->inputs().at(1);
889 manager->Replace(update_state, input_monad);
890 }
891 }
892
CopyMonadArguments(const CNodePtr & primal_user,const CNodePtr & j_user)893 static bool CopyMonadArguments(const CNodePtr &primal_user, const CNodePtr &j_user) {
894 auto &primal_inputs = primal_user->inputs();
895 auto &j_user_inputs = j_user->inputs();
896 bool has_monad = false;
897 for (size_t i = 1; i < primal_inputs.size(); ++i) {
898 auto &input = primal_inputs.at(i);
899 if (HasAbstractMonad(input)) {
900 // Copy monad input from primal to j_user.
901 j_user->set_input(i, input);
902 has_monad = true;
903 } else if (input != j_user_inputs.at(i)) {
904 // Skip if there are different non-monad inputs.
905 return false;
906 }
907 }
908 return has_monad;
909 }
910
911 //
912 // To replace the primal graph with k graph.
913 // Convert:
914 // x = primal(args, u0)
915 // u1 = update_state(u0, x)
916 // ...
917 // tuple = K(args, u1)
918 // u2 = update_state(u1, tuple)
919 // ...
920 // To:
921 // tuple = K(args, u0)
922 // x = get_item(tuple, 0)
923 // ...
924 // tuple = K(args, u0)
925 // u2 = update_state(u0, tuple)
926 // ...
927 //
EliminatePrimalGraph()928 void DFunctor::EliminatePrimalGraph() {
929 // Find primal user and paired J user cnodes.
930 auto manager = primal_graph_->manager();
931 MS_EXCEPTION_IF_NULL(manager);
932 auto primal_user_to_j_users = FindPrimalJPair(manager, primal_graph_);
933 for (const auto &iter : primal_user_to_j_users) {
934 auto primal_user = iter.first;
935 auto &j_users = iter.second;
936 MS_EXCEPTION_IF_NULL(primal_user);
937 if (j_users.size() == 1) {
938 // If both inputs are same except monads, we copy primal monad args to k graph
939 // so that they can be combined in CSE (common subexpression elimination) pass.
940 // Only do this when the size of j_users is 1 in order to keep the execution order.
941 const bool has_monad = CopyMonadArguments(primal_user, j_users[0]);
942 // Remove the UpdateState nodes after primal_user if need.
943 if (has_monad) {
944 RemovePrimalUpdateStates(manager, primal_user);
945 }
946 } else {
947 MS_LOG(INFO) << "There are multiple j users with the same primal user " << primal_user->DebugString();
948 }
949
950 // Replace primal graph with k graph.
951 auto k_vnode = NewValueNode(k_graph_);
952 primal_user->set_input(0, k_vnode);
953 if (j_users.empty()) {
954 MS_LOG(EXCEPTION) << "The J nodes for primal graph " << primal_graph_->ToString()
955 << " should be used by at least one other node.";
956 }
957 primal_user->set_abstract(j_users[0]->abstract());
958 // Insert tuple_getitem after primal user cnode.
959 auto construct_wrapper = primal_user->func_graph();
960 auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem);
961 auto imm0 = std::make_shared<Int64Imm>(0);
962 auto idx0 = NewValueNode(SizeToLong(0));
963 idx0->set_abstract(std::make_shared<abstract::AbstractScalar>(imm0));
964 auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0});
965 getitem0->CloneCNodeInfo(primal_user);
966 (void)manager->Replace(primal_user, getitem0);
967 }
968 }
969 } // namespace ad
970 } // namespace mindspore
971