• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2020-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef _WIN32
20 #include <dirent.h>
21 #endif
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include "ir/anf.h"
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "pybind_api/ir/primitive_py.h"
29 #include "ir/meta_func_graph.h"
30 #include "ir/func_graph_cloner.h"
31 #include "ir/manager.h"
32 #include "pipeline/jit/ps/resource.h"
33 #include "frontend/optimizer/ad/dfunctor.h"
34 #include "frontend/operator/composite/composite.h"
35 #include "frontend/expander/bprop/bprop.h"
36 #include "frontend/expander/bprop/bprop_meta_func_graph.h"
37 #include "include/common/utils/primitive_utils.h"
38 #include "include/common/utils/utils.h"
39 #include "utils/symbolic.h"
40 #include "utils/ms_context.h"
41 #include "utils/info.h"
42 #include "pipeline/jit/ps/debug/trace.h"
43 #include "utils/anf_utils.h"
44 #include "frontend/expander/utils.h"
45 
46 namespace mindspore {
47 namespace ad {
// Namespace-scope KPrim object of mindspore::ad.
48 KPrim g_k_prims;
49 
50 namespace {
// User-data key looked up (as a bool) on primal parameters in this file to recognise
// parameters that were lifted from free variables.
51 constexpr char kLiftedUserDataKey[] = "lifted_from_fv";
52 
GetBprop(const PrimitivePtr & prim,const pipeline::ResourceBasePtr & resources,const CNodePtr & cnode)53 FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) {
54   // Set a child scope named "grad'PrimitiveName'" for the bprop function,
55   // and add "Gradients" to the front.
56   static const std::string gradients_scope = "Gradients/";
57   static const std::string grad_op_child_scope_prefix = "/Grad_";
58   MS_EXCEPTION_IF_NULL(prim);
59   const auto &prim_name = prim->name();
60   auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
61                                        grad_op_child_scope_prefix + prim_name);
62   ScopeGuard scope_guard(scope);
63 
64   // Firstly we get bprop from expander. If failed, try mindir. If still failed, try the python bprop function.
65   FuncGraphPtr func_graph = expander::bprop::GetBpropMetaFuncGraph(prim, cnode);
66   if (func_graph != nullptr) {
67     return func_graph;
68   }
69 
70   py::function fn;
71   if (prim->is_base()) {
72     fn = GetBpropFunction(prim_name);
73   } else if (mindspore::ops::IsPrimitiveFunction(prim_name)) {
74     fn = GetBpropFunction(prim_name);
75   } else if (prim->isa<PrimitivePy>()) {
76     fn = prim->cast_ptr<PrimitivePy>()->GetBpropFunction();
77     if (py::isinstance<py::none>(fn)) {
78       fn = GetBpropFunction(prim_name);
79     }
80   } else {
81     MS_LOG(INTERNAL_EXCEPTION) << "Unexpected prim: " << prim->ToString();
82   }
83   if (!fn || py::isinstance<py::none>(fn)) {
84     MS_LOG(DEBUG) << "Fail to find bprop function for " << prim_name << ". fn: " << py::str(fn);
85     return nullptr;
86   }
87   func_graph = parse::ParsePythonCode(fn);
88   if (func_graph == nullptr) {
89     MS_LOG(ERROR) << "Fail to parse bprop function for " << prim_name << ".";
90     return nullptr;
91   }
92   auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
93   if (bprop_flag) {
94     func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
95   }
96   pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
97   (void)parse::ResolveFuncGraph(func_graph, res, false);
98   return func_graph;
99 }
100 }  // namespace
101 
GetPrimBprop(const PrimitivePtr & prim,const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources,const CNodePtr & cnode)102 FuncGraphPtr KPrim::GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node,
103                                  const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) {
104   MS_EXCEPTION_IF_NULL(prim);
105   MS_EXCEPTION_IF_NULL(value_node);
106   auto iter = bprop_registry_.find(prim);
107   if (iter != bprop_registry_.end() && !iter->second->dropped()) {
108     return iter->second;
109   }
110 
111   FuncGraphPtr bprop_fg = GetBprop(prim, resources, cnode);
112   if (bprop_fg != nullptr) {
113     // Set bprop_g graph cache
114     bprop_registry_[prim] = bprop_fg;
115   } else {
116     bprop_fg = FakeBprop(value_node, resources);
117   }
118 
119   return bprop_fg;
120 }
121 
GetFprop(const PrimitivePtr & prim) const122 FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) const {
123   static const std::string ad_module = "mindspore.ops._grad_experimental.grad_implementations";
124   std::string func_name = "_fprop_" + prim->name();
125   py::function fn = python_adapter::GetPyFn(ad_module, func_name);
126   auto func_graph = parse::ParsePythonCode(fn);
127   MS_EXCEPTION_IF_NULL(func_graph);
128   return BasicClone(func_graph);
129 }
130 
KMetaFuncGraph(const PrimitivePtr & prim)131 MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
132   MS_EXCEPTION_IF_NULL(prim);
133 
134   auto iter = bprop_registry_meta_.find(prim);
135   if (iter != bprop_registry_meta_.end()) {
136     return iter->second;
137   }
138 
139   if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
140     MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
141     bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
142     return meta;
143   }
144 
145   if (IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
146     MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
147     bprop_registry_meta_[prim::kPrimMakeList] = meta;
148     return meta;
149   }
150 
151   if (IsPrimitiveEquals(prim, prim::kPrimMakeDict)) {
152     MetaFuncGraphPtr meta = std::make_shared<prim::MakeDictGradient>("make_dict_gradient");
153     bprop_registry_meta_[prim::kPrimMakeDict] = meta;
154     return meta;
155   }
156 
157   if (IsPrimitiveEquals(prim, prim::kPrimMutable)) {
158     MetaFuncGraphPtr meta = std::make_shared<prim::MutableGradient>("MutableGradient");
159     bprop_registry_meta_[prim::kPrimMutable] = meta;
160     return meta;
161   }
162 
163   MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
164 }
165 
AddMonad(const FuncGraphPtr & bprop_fg,const CNodePtr & output,const AnfNodePtr & monad)166 static void AddMonad(const FuncGraphPtr &bprop_fg, const CNodePtr &output, const AnfNodePtr &monad) {
167   if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
168     constexpr char model_name[] = "mindspore.ops.composite.multitype_ops.add_impl";
169     constexpr char python_ops[] = "_tuple_add";
170     auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
171     auto maketuple_monad = bprop_fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), monad});
172     auto tuple_add_monad = bprop_fg->NewCNode({tuple_add_ops, output, maketuple_monad});
173     bprop_fg->set_output(tuple_add_monad);
174   } else {
175     output->add_input(monad);
176   }
177 }
178 
AppendMonadOutput(const FuncGraphPtr & bprop_fg,const AnfNodePtr & monad)179 static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &monad) {
180   const auto &output = bprop_fg->output();
181   MS_EXCEPTION_IF_NULL(output);
182   auto output_cnode = output->cast<CNodePtr>();
183   if (output_cnode != nullptr) {
184     // If output_cnode has the form like (make_tuple, x, y).
185     while (output_cnode->IsApply(prim::kPrimDepend)) {
186       const auto &real_input = output_cnode->input(kRealInputIndexInDepend);
187       MS_EXCEPTION_IF_NULL(real_input);
188       output_cnode = real_input->cast<CNodePtr>();
189     }
190   }
191   constexpr char u_monad_in_output[] = "u_monad_in_output";
192   constexpr char io_monad_in_output[] = "io_monad_in_output";
193   if (output_cnode != nullptr) {
194     if (HasAbstractUMonad(monad) && !bprop_fg->has_flag(u_monad_in_output)) {
195       AddMonad(bprop_fg, output_cnode, monad);
196       bprop_fg->set_flag(u_monad_in_output, true);
197     } else if (HasAbstractIOMonad(monad) && !bprop_fg->has_flag(io_monad_in_output)) {
198       AddMonad(bprop_fg, output_cnode, monad);
199       bprop_fg->set_flag(io_monad_in_output, true);
200     }
201     return;
202   }
203   // If output is an empty tuple, create a (make_tuple, monad) as the new output.
204   auto make_tuple = NewValueNode(prim::kPrimMakeTuple);
205   output_cnode = bprop_fg->NewCNode({make_tuple, monad});
206   if (HasAbstractUMonad(monad)) {
207     bprop_fg->set_flag(u_monad_in_output, true);
208   } else if (HasAbstractIOMonad(monad)) {
209     bprop_fg->set_flag(io_monad_in_output, true);
210   }
211   bprop_fg->set_output(output_cnode);
212 }
213 
214 // Append U or/and IO monad to output of Bprop funcgraph.
AdjustForAutoMonad(const PrimitivePtr & prim,const FuncGraphPtr & bprop_fg)215 static void AdjustForAutoMonad(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
216   auto effect_info = GetPrimEffectInfo(prim);
217   if (effect_info.memory) {
218     MS_LOG(DEBUG) << "Append U monad for Bprop FuncGraph of Primitive " << prim->ToString();
219     auto u = NewValueNode(kUMonad);
220     u->set_abstract(kUMonad->ToAbstract());
221     AppendMonadOutput(bprop_fg, u);
222   }
223   if (effect_info.io) {
224     MS_LOG(DEBUG) << "Append IO monad for Bprop FuncGraph of Primitive " << prim->ToString();
225     auto io = NewValueNode(kIOMonad);
226     io->set_abstract(kIOMonad->ToAbstract());
227     AppendMonadOutput(bprop_fg, io);
228   }
229 }
230 
GeneratePrimalDebugInfo(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)231 std::vector<NodeDebugInfoPtr> GeneratePrimalDebugInfo(const ValueNodePtr &value_node,
232                                                       const pipeline::ResourceBasePtr &resources) {
233   std::vector<NodeDebugInfoPtr> primal_debug_infos;
234   if (resources != nullptr) {
235     auto manager = resources->manager();
236     auto &users = manager->node_users()[value_node];
237     for (auto user_iter = users.begin(); user_iter != users.end(); ++user_iter) {
238       primal_debug_infos.push_back(user_iter->first->debug_info());
239     }
240   }
241   return primal_debug_infos;
242 }
243 
SetDumpFlag(const PrimitivePtr & prim,const FuncGraphPtr & bprop_fg)244 void SetDumpFlag(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
245   if (prim == nullptr || bprop_fg == nullptr) {
246     return;
247   }
248   auto attr = prim->GetAttr(kAttrDump);
249   if (attr != nullptr) {
250     if (attr->isa<StringImm>()) {
251       auto str_attr = attr->cast_ptr<StringImm>();
252       MS_EXCEPTION_IF_NULL(str_attr);
253       if (str_attr->value() == kValueTrue) {
254         bprop_fg->set_flag(FUNC_GRAPH_FLAG_DUMP, true);
255       }
256     }
257   }
258 }
259 
KPrimitive(const CNodePtr & cnode,const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)260 FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
261                                const pipeline::ResourceBasePtr &resources) {
262   if (!IsValueNode<Primitive>(value_node)) {
263     MS_LOG(INTERNAL_EXCEPTION) << "Primitive node is not valid.";
264   }
265 
266   auto prim = GetValueNode<PrimitivePtr>(value_node);
267   if (IsPrimitiveEquals(prim, prim::kPrimSwitchLayer)) {
268     auto fprop = GetFprop(prim);
269     fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
270     return fprop;
271   } else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList) ||
272              IsPrimitiveEquals(prim, prim::kPrimMakeDict) || IsPrimitiveEquals(prim, prim::kPrimMutable)) {
273     // Return null to use Meta bprop.
274     return nullptr;
275   }
276 
277   FuncGraphPtr bprop_fg = nullptr;
278   if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
279     if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
280       MS_LOG(EXCEPTION)
281         << "The Hook operation is not supported in graph mode, which is only supported in pynative mode.\n"
282         << trace::GetDebugInfoStr(cnode->debug_info());
283     }
284     bprop_fg = BpropCut(value_node, resources);
285   } else {
286     bprop_fg = GetPrimBprop(prim, value_node, resources, cnode);
287   }
288   MS_EXCEPTION_IF_NULL(bprop_fg);
289   MS_EXCEPTION_IF_NULL(bprop_fg->return_node());
290 
291   SetDumpFlag(prim, bprop_fg);
292   AdjustForAutoMonad(prim, bprop_fg);
293   mindspore::HashMap<std::string, ValuePtr> primal_attrs;
294   std::vector<NodeDebugInfoPtr> primal_debug_infos = GeneratePrimalDebugInfo(value_node, resources);
295   if (cnode != nullptr) {
296     primal_attrs = cnode->primal_attrs();
297     cnode->AddPrimalAttr(kPrimalAttrUniqueId, MakeValue(cnode->UniqueId()));
298     const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
299     primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
300     primal_attrs[kPrimalAttrForwardUniqueId] = MakeValue(cnode->UniqueId());
301   }
302   auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos);
303   if (expanded_fg == nullptr) {
304     MS_LOG(INTERNAL_EXCEPTION) << "Failed convert " << prim->name()
305                                << " prim bprop function to J expanded func graph. NodeInfo: "
306                                << trace::GetDebugInfoStr(bprop_fg->debug_info());
307   }
308   if (lift_fv_before_grad && IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
309     // Inline fprop_switch before renormalize;
310     expanded_fg->set_flag(FUNC_GRAPH_FLAG_FORCE_INLINE, true);
311     MS_LOG(DEBUG) << "set force_inline for fg: " << expanded_fg->ToString();
312   }
313 
314   return expanded_fg;
315 }
316 
BuildOutput(const FuncGraphPtr & bprop_fg,const FuncGraphPtr & current_primal_fg) const317 AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) const {
318   // The primal fg may have extra parameters from lifted fv or u_monad and io_monad.
319   std::vector<AnfNodePtr> extra_lifted_args;
320   std::vector<AnfNodePtr> extra_monad_args;
321   // caller had checked size() - 2 is greater than 0.
322   auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
323   if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) {
324     auto current_primal_fg_param_size = current_primal_fg->parameters().size();
325     MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so "
326                      "Insert it. Extra parameters size: "
327                   << current_primal_fg_param_size - bprop_fg_param_size;
328     // The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
329     for (size_t i = 0; i < current_primal_fg_param_size; ++i) {
330       auto primal_parameter = dyn_cast<Parameter>(current_primal_fg->parameters()[i]);
331       MS_EXCEPTION_IF_NULL(primal_parameter);
332       auto lifted = primal_parameter->user_data<bool>(kLiftedUserDataKey);
333       if (lifted == nullptr || !*lifted) {
334         break;
335       }
336       extra_lifted_args.push_back(
337         bprop_fg->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), primal_parameter}));
338       ++bprop_fg_param_size;
339     }
340     for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) {
341       const auto &primal_node = current_primal_fg->parameters()[i];
342       AnfNodePtr extra_node;
343       // Simplify zeros_like(primal_node) to U or IO, so extra_node in bprop_fg will not refer to primal_node
344       // as a free variable of primal_graph.
345       // Notes: if the implementation of zeros_like changes, here too.
346       if (HasAbstractUMonad(primal_node)) {
347         extra_node = NewValueNode(kUMonad);
348       } else if (HasAbstractIOMonad(primal_node)) {
349         extra_node = NewValueNode(kIOMonad);
350       } else {
351         MS_EXCEPTION(TypeError)
352           << "The params of function 'bprop' of Primitive or Cell requires the forward inputs as well "
353              "as the 'out' and 'dout'.\n"
354           << trace::GetDebugInfoStr(bprop_fg->debug_info());
355       }
356       extra_monad_args.push_back(extra_node);
357       MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString();
358     }
359   }
360   // bprop_fg has been checked in caller
361   if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
362     // Set bprop output as (env, dx, dy, dz, ...)
363     auto cbprop = bprop_fg->output()->cast_ptr<CNode>();
364     auto &inputs = cbprop->inputs();
365 
366     std::vector<AnfNodePtr> args;
367     args.push_back(NewValueNode(prim::kPrimMakeTuple));
368     args.push_back(NewEnviron(bprop_fg));
369     // The lifted parameters are put in front.
370     if (!extra_lifted_args.empty()) {
371       (void)args.insert(args.cend(), extra_lifted_args.cbegin(), extra_lifted_args.cend());
372     }
373     (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend());
374     if (!extra_monad_args.empty()) {
375       (void)args.insert(args.cend(), extra_monad_args.cbegin(), extra_monad_args.cend());
376     }
377     return NewCNode(args, bprop_fg);
378   }
379 
380   // Set bprop output as (env, dx)
381   constexpr char model_name[] = "mindspore.ops.composite.multitype_ops.add_impl";
382   constexpr char python_ops[] = "_tuple_add";
383   auto bprop_tuple_add_check_func = std::make_shared<std::function<bool(const std::vector<AbstractBasePtr> &args)>>(
384     [](const std::vector<AbstractBasePtr> &args) {
385       for (const auto &arg : args) {
386         if (!arg->isa<abstract::AbstractTuple>()) {
387           MS_EXCEPTION(TypeError) << "For bprop function, output should be a tuple, but got " << arg->ToString();
388         }
389       }
390       return true;
391     });
392   auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewEnviron(bprop_fg)}, bprop_fg);
393   auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
394   if (IsValueNode<FuncGraphBase>(tuple_add_ops)) {
395     auto tuple_add_func_graph = GetValueNode<FuncGraphBasePtr>(tuple_add_ops);
396     MS_LOG(DEBUG) << "Get tuple add func successful. Tuple add fg: " << tuple_add_func_graph->ToString();
397     auto checker = std::make_shared<FuncGraphChecker>();
398     checker->AddCheckFunc<const std::vector<AbstractBasePtr> &>(bprop_tuple_add_check_func);
399     tuple_add_func_graph->AddChecker("check_infer_inputs", checker);
400   }
401 
402   if (!extra_lifted_args.empty()) {
403     (void)extra_lifted_args.insert(extra_lifted_args.cbegin(), NewValueNode(prim::kPrimMakeTuple));
404     auto extra_tuple = NewCNode(extra_lifted_args, bprop_fg);
405     tuple_env = NewCNode({tuple_add_ops, tuple_env, extra_tuple}, bprop_fg);
406   }
407   if (!extra_monad_args.empty()) {
408     (void)extra_monad_args.insert(extra_monad_args.cbegin(), NewValueNode(prim::kPrimMakeTuple));
409     auto extra_tuple = NewCNode(extra_monad_args, bprop_fg);
410     auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg);
411     return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg);
412   }
413 
414   return NewCNode({tuple_add_ops, tuple_env, bprop_fg->output()}, bprop_fg);
415 }
416 
TransformNormalArgs(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * transf_args)417 static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
418                                 std::vector<AnfNodePtr> *transf_args) {
419   MS_EXCEPTION_IF_NULL(mng);
420   // bprop_fg has been checked in caller
421   // transform except the last 2 parameters: out, dout.
422   const size_t last_parameter_sizes = 2;
423   auto bprop_fg_param_size = bprop_fg->parameters().size() - last_parameter_sizes;
424   for (size_t i = 0; i < bprop_fg_param_size; ++i) {
425     auto p = bprop_fg->parameters()[i];
426     MS_EXCEPTION_IF_NULL(p);
427 
428     TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
429     auto transf_p = outer->add_parameter();
430 
431     (void)mng->Replace(p, transf_p);
432     transf_args->push_back(transf_p);
433   }
434 }
TransformArgsForPrimitive(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const PrimitivePtr & primitive,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * const transf_args) const435 void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
436                                       const PrimitivePtr &primitive, const FuncGraphPtr &outer,
437                                       std::vector<AnfNodePtr> *const transf_args) const {
438   TransformNormalArgs(mng, bprop_fg, outer, transf_args);
439   // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter.
440   auto effect_info = GetPrimEffectInfo(primitive);
441   if (effect_info.memory) {
442     MS_LOG(DEBUG) << "Append U monad to Fprop FuncGraph for Primitive " << primitive->ToString();
443     auto transf_p = outer->add_parameter();
444     transf_args->push_back(transf_p);
445   }
446   if (effect_info.io) {
447     MS_LOG(DEBUG) << "Append IO monad to Fprop FuncGraph for Primitive " << primitive->ToString();
448     auto transf_p = outer->add_parameter();
449     transf_args->push_back(transf_p);
450   }
451 }
452 
453 template <typename T>
TransformArgsForFuncGraph(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const T & current_primal_fg,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * const transf_args) const454 void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
455                                       const T &current_primal_fg, const FuncGraphPtr &outer,
456                                       std::vector<AnfNodePtr> *const transf_args) const {
457   constexpr size_t need_filter_size = 2;
458   auto bprop_fg_param_size = bprop_fg->parameters().size() - need_filter_size;
459   const auto &current_primal_fg_params = current_primal_fg->parameters();
460   // The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
461   for (size_t i = 0; i < current_primal_fg_params.size(); ++i) {
462     auto primal_parameter = dyn_cast_ptr<Parameter>(current_primal_fg_params[i]);
463     MS_EXCEPTION_IF_NULL(primal_parameter);
464     auto lifted = primal_parameter->template user_data<bool>(kLiftedUserDataKey);
465     if (lifted == nullptr || !*lifted) {
466       break;
467     }
468     TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal_parameter->debug_info()));
469     auto transf_p = outer->add_parameter();
470     transf_args->push_back(transf_p);
471     ++bprop_fg_param_size;
472   }
473   TransformNormalArgs(mng, bprop_fg, outer, transf_args);
474   // Current primal fg may have extra parameters after AutoMonad
475   if (bprop_fg_param_size < current_primal_fg_params.size()) {
476     for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) {
477       auto p = current_primal_fg_params[i];
478       MS_EXCEPTION_IF_NULL(p);
479       // extra parameters should be Monad.
480       if (!HasAbstractMonad(p)) {
481         continue;
482       }
483       MS_LOG(DEBUG) << "Function " << current_primal_fg->ToString()
484                     << ", has extra monad parameter: " << p->DebugString()
485                     << ", abstract: " << p->abstract()->ToString();
486 
487       TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
488       auto transf_p = outer->add_parameter();
489       // See also Notes on extra_node of BuildOutput.
490       // Notes: No need to replace p with transf_p as the only use of p is here.
491       // If extra_node in bprop_fg use p as free variable, a replacement of p is required here.
492       // This replacement will make the usage of p in current_primal_fg got replaced with transf_p
493       // of outer. outer will be released after it is being cloned to fprop_fg, so the func_graph_
494       // in transf_p will be nullptr.
495       // So the RULE is DONT tamper the current_primal_fg;
496       transf_args->push_back(transf_p);
497     }
498   }
499   if (transf_args->size() != current_primal_fg_params.size()) {
500     MS_EXCEPTION(TypeError) << "Function " << current_primal_fg->ToString()
501                             << ", The number of parameter of this primal function is "
502                             << current_primal_fg_params.size() << ", but the number of parameters of bprop is "
503                             << bprop_fg_param_size;
504   }
505 }
506 
CheckBprop(const FuncGraphPtr & bprop_fg,const string & prim_to_check) const507 void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) const {
508   TraceGuard guard(std::make_shared<TraceCopy>(bprop_fg->return_node()->debug_info()));
509   auto context = MsContext::GetInstance();
510   MS_EXCEPTION_IF_NULL(context);
511   bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
512   // Skip checking if check_bprop not set
513   if (!check_bprop_flag) {
514     return;
515   }
516 
517   // bprop_fg has been checked in caller
518   auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations._inner_ops");
519   MS_EXCEPTION_IF_NULL(check_bprop_class);
520   auto check_bprop =
521     bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
522 
523   std::vector<AnfNodePtr> inputs;
524   inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
525   constexpr int primitive_size = 1;
526   constexpr int brprop_offset_size = 2;
527   (void)inputs.insert(inputs.cbegin() + primitive_size, bprop_fg->parameters().cbegin(),
528                       bprop_fg->parameters().cend() - brprop_offset_size);
529   AnfNodePtr params = bprop_fg->NewCNode(inputs);
530 
531   inputs.clear();
532   inputs.push_back(check_bprop);
533   inputs.push_back(bprop_fg->output());
534   inputs.push_back(params);
535   AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
536   bprop_fg->set_output(bprop_out);
537 }
538 
// Expand a user-defined Cell bprop into a K FuncGraph.
// The primal FuncGraph is read from the "primal" transform of bprop_fg; an exception is raised
// when that transform is missing or when the expansion fails.
FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  // primal_fg is the FuncGraph just after convert. Refer to ConvertCellObjToFuncGraph.
  // current_primal_fg is the specialized and AutoMonaded primal_fg.
  const auto &transforms = bprop_fg->transforms();
  const auto primal_iter = transforms.find("primal");
  if (primal_iter == transforms.end()) {
    // Dereferencing the end iterator would be undefined behavior; report the problem instead.
    MS_LOG(INTERNAL_EXCEPTION) << "The 'primal' transform is not found in bprop func graph: " << bprop_fg->ToString();
  }
  auto primal_fg = primal_iter->second.func_graph();
  auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr, {}, {});
  if (expanded_fg == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Failed convert " << primal_fg->ToString()
                               << " Cell bprop function to K expanded func graph. NodeInfo: "
                               << trace::GetDebugInfoStr(primal_fg->debug_info());
  }
  return expanded_fg;
}
552 
BpropCut(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources) const553 FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const {
554   auto prim = GetValueNode<PrimitivePtr>(value_node);
555   MS_EXCEPTION_IF_NULL(prim);
556   auto &node_users = resources->manager()->node_users();
557 
558   auto &users = node_users[value_node];
559   auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
560     return IsPrimitiveCNode(user.first, prim);
561   });
562   if (cnode == users.end()) {
563     MS_LOG(INTERNAL_EXCEPTION) << "Fail to find cnode.";
564   }
565   auto cnode_first = cnode->first->cast_ptr<CNode>();
566   MS_EXCEPTION_IF_NULL(cnode_first);
567   auto inputs_num = cnode_first->size() - 1;
568 
569   auto func_graph = std::make_shared<FuncGraph>();
570   std::vector<AnfNodePtr> outputs;
571   auto prim_py = prim->cast<PrimitivePyPtr>();
572   MS_EXCEPTION_IF_NULL(prim_py);
573   auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
574   bprop_cut->CopyHookFunction(prim_py);
575 
576   auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
577   if (cell_id != "") {
578     (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
579     (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
580   }
581 
582   outputs.push_back(NewValueNode(bprop_cut));
583   for (size_t i = 0; i < inputs_num; ++i) {
584     auto param = func_graph->add_parameter();
585     outputs.push_back(param);
586   }
587   auto p1 = func_graph->add_parameter();
588   auto p2 = func_graph->add_parameter();
589   outputs.push_back(p1);
590   outputs.push_back(p2);
591 
592   func_graph->set_output(func_graph->NewCNode(outputs));
593   return func_graph;
594 }
595 
FakeBprop(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources) const596 FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const {
597   auto prim = value_node->value()->cast<PrimitivePtr>();
598   MS_EXCEPTION_IF_NULL(prim);
599   auto &node_users = resources->manager()->node_users();
600 
601   auto &users = node_users[value_node];
602   auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
603     return IsPrimitiveCNode(user.first, prim);
604   });
605   if (cnode == users.end()) {
606     MS_LOG(INTERNAL_EXCEPTION) << "Fail to find user for " << prim->ToString();
607   }
608   auto cnode_first = cnode->first->cast_ptr<CNode>();
609   MS_EXCEPTION_IF_NULL(cnode_first);
610   auto inputs_num = cnode_first->size() - 1;
611   auto effect_info = GetPrimEffectInfo(prim);
612   // Don't add U or IO monad parameters as it will be added later.
613   size_t monad_params_size = 0;
614   if (effect_info.memory) {
615     monad_params_size++;
616   }
617   if (effect_info.io) {
618     monad_params_size++;
619   }
620   if (inputs_num < monad_params_size) {
621     MS_LOG(INTERNAL_EXCEPTION) << "Arguments number should be greater than or equal to " << monad_params_size
622                                << ", but the CNode is: " << cnode->first->DebugString();
623   }
624   inputs_num -= monad_params_size;
625 
626   auto func_graph = std::make_shared<FuncGraph>();
627   std::vector<AnfNodePtr> outputs;
628   outputs.push_back(NewValueNode(prim::kPrimMakeTuple));
629 
630   auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
631   (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
632 
633   for (size_t i = 0; i < inputs_num; ++i) {
634     // Mock params for inputs
635     auto param = func_graph->add_parameter();
636     // Mock derivatives for each inputs
637     if (IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
638       outputs.push_back(func_graph->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), param}));
639     } else {
640       outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param}));
641     }
642   }
643   // mock params for out and dout
644   (void)func_graph->add_parameter();
645   (void)func_graph->add_parameter();
646   func_graph->set_output(func_graph->NewCNode(outputs));
647   return func_graph;
648 }
649 
CheckCustomVjp(const FuncGraphPtr & bprop_fg) const650 bool KPrim::CheckCustomVjp(const FuncGraphPtr &bprop_fg) const {
651   MS_EXCEPTION_IF_NULL(bprop_fg);
652   auto parameters_size = bprop_fg->parameters().size();
653   if (bprop_fg->has_flag("custom_vjp") && parameters_size == 1) {
654     return true;
655   }
656   return false;
657 }
658 
// Extract the bprop FuncGraph held as input 1 of the output CNode of a @custom_vjp graph,
// and hand the transforms of bprop_fg over to it.
// Raises a TypeError when that input is None or is not a FuncGraph.
FuncGraphPtr KPrim::GetCustomVjpBprop(const FuncGraphPtr &bprop_fg) const {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  auto output_cnode = dyn_cast<CNode>(bprop_fg->output());
  MS_EXCEPTION_IF_NULL(output_cnode);
  // A None here means no bprop function was registered with defbwd.
  if (IsValueNode<None>(output_cnode->input(1))) {
    MS_EXCEPTION(TypeError)
      << "The bprop function of @custom_vjp is undefined. Please use 'defbwd(bprop)' to define the 'bprop' function.";
  }

  auto user_bprop_fg = GetValueNode<FuncGraphPtr>(output_cnode->input(1));
  if (user_bprop_fg == nullptr) {
    MS_EXCEPTION(TypeError) << "The 'bprop' function defined by @custom_vjp defbwd(bprop) is illegal.";
  }
  user_bprop_fg->set_transforms(bprop_fg->transforms());
  return user_bprop_fg;
}
677 }  // namespace ad
678 }  // namespace mindspore
679