1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2020-2023 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #ifndef _WIN32
20 #include <dirent.h>
21 #endif
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include "ir/anf.h"
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "pybind_api/ir/primitive_py.h"
29 #include "ir/meta_func_graph.h"
30 #include "ir/func_graph_cloner.h"
31 #include "ir/manager.h"
32 #include "pipeline/jit/ps/resource.h"
33 #include "frontend/optimizer/ad/dfunctor.h"
34 #include "frontend/operator/composite/composite.h"
35 #include "frontend/expander/bprop/bprop.h"
36 #include "frontend/expander/bprop/bprop_meta_func_graph.h"
37 #include "include/common/utils/primitive_utils.h"
38 #include "include/common/utils/utils.h"
39 #include "utils/symbolic.h"
40 #include "utils/ms_context.h"
41 #include "utils/info.h"
42 #include "pipeline/jit/ps/debug/trace.h"
43 #include "utils/anf_utils.h"
44 #include "frontend/expander/utils.h"
45
46 namespace mindspore {
47 namespace ad {
// Namespace-scope KPrim instance defined in mindspore::ad.
KPrim g_k_prims;
49
50 namespace {
51 constexpr char kLiftedUserDataKey[] = "lifted_from_fv";
52
GetBprop(const PrimitivePtr & prim,const pipeline::ResourceBasePtr & resources,const CNodePtr & cnode)53 FuncGraphPtr GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) {
54 // Set a child scope named "grad'PrimitiveName'" for the bprop function,
55 // and add "Gradients" to the front.
56 static const std::string gradients_scope = "Gradients/";
57 static const std::string grad_op_child_scope_prefix = "/Grad_";
58 MS_EXCEPTION_IF_NULL(prim);
59 const auto &prim_name = prim->name();
60 auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
61 grad_op_child_scope_prefix + prim_name);
62 ScopeGuard scope_guard(scope);
63
64 // Firstly we get bprop from expander. If failed, try mindir. If still failed, try the python bprop function.
65 FuncGraphPtr func_graph = expander::bprop::GetBpropMetaFuncGraph(prim, cnode);
66 if (func_graph != nullptr) {
67 return func_graph;
68 }
69
70 py::function fn;
71 if (prim->is_base()) {
72 fn = GetBpropFunction(prim_name);
73 } else if (mindspore::ops::IsPrimitiveFunction(prim_name)) {
74 fn = GetBpropFunction(prim_name);
75 } else if (prim->isa<PrimitivePy>()) {
76 fn = prim->cast_ptr<PrimitivePy>()->GetBpropFunction();
77 if (py::isinstance<py::none>(fn)) {
78 fn = GetBpropFunction(prim_name);
79 }
80 } else {
81 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected prim: " << prim->ToString();
82 }
83 if (!fn || py::isinstance<py::none>(fn)) {
84 MS_LOG(DEBUG) << "Fail to find bprop function for " << prim_name << ". fn: " << py::str(fn);
85 return nullptr;
86 }
87 func_graph = parse::ParsePythonCode(fn);
88 if (func_graph == nullptr) {
89 MS_LOG(ERROR) << "Fail to parse bprop function for " << prim_name << ".";
90 return nullptr;
91 }
92 auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
93 if (bprop_flag) {
94 func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
95 }
96 pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
97 (void)parse::ResolveFuncGraph(func_graph, res, false);
98 return func_graph;
99 }
100 } // namespace
101
GetPrimBprop(const PrimitivePtr & prim,const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources,const CNodePtr & cnode)102 FuncGraphPtr KPrim::GetPrimBprop(const PrimitivePtr &prim, const ValueNodePtr &value_node,
103 const pipeline::ResourceBasePtr &resources, const CNodePtr &cnode) {
104 MS_EXCEPTION_IF_NULL(prim);
105 MS_EXCEPTION_IF_NULL(value_node);
106 auto iter = bprop_registry_.find(prim);
107 if (iter != bprop_registry_.end() && !iter->second->dropped()) {
108 return iter->second;
109 }
110
111 FuncGraphPtr bprop_fg = GetBprop(prim, resources, cnode);
112 if (bprop_fg != nullptr) {
113 // Set bprop_g graph cache
114 bprop_registry_[prim] = bprop_fg;
115 } else {
116 bprop_fg = FakeBprop(value_node, resources);
117 }
118
119 return bprop_fg;
120 }
121
GetFprop(const PrimitivePtr & prim) const122 FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) const {
123 static const std::string ad_module = "mindspore.ops._grad_experimental.grad_implementations";
124 std::string func_name = "_fprop_" + prim->name();
125 py::function fn = python_adapter::GetPyFn(ad_module, func_name);
126 auto func_graph = parse::ParsePythonCode(fn);
127 MS_EXCEPTION_IF_NULL(func_graph);
128 return BasicClone(func_graph);
129 }
130
KMetaFuncGraph(const PrimitivePtr & prim)131 MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
132 MS_EXCEPTION_IF_NULL(prim);
133
134 auto iter = bprop_registry_meta_.find(prim);
135 if (iter != bprop_registry_meta_.end()) {
136 return iter->second;
137 }
138
139 if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
140 MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
141 bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
142 return meta;
143 }
144
145 if (IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
146 MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
147 bprop_registry_meta_[prim::kPrimMakeList] = meta;
148 return meta;
149 }
150
151 if (IsPrimitiveEquals(prim, prim::kPrimMakeDict)) {
152 MetaFuncGraphPtr meta = std::make_shared<prim::MakeDictGradient>("make_dict_gradient");
153 bprop_registry_meta_[prim::kPrimMakeDict] = meta;
154 return meta;
155 }
156
157 if (IsPrimitiveEquals(prim, prim::kPrimMutable)) {
158 MetaFuncGraphPtr meta = std::make_shared<prim::MutableGradient>("MutableGradient");
159 bprop_registry_meta_[prim::kPrimMutable] = meta;
160 return meta;
161 }
162
163 MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
164 }
165
AddMonad(const FuncGraphPtr & bprop_fg,const CNodePtr & output,const AnfNodePtr & monad)166 static void AddMonad(const FuncGraphPtr &bprop_fg, const CNodePtr &output, const AnfNodePtr &monad) {
167 if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
168 constexpr char model_name[] = "mindspore.ops.composite.multitype_ops.add_impl";
169 constexpr char python_ops[] = "_tuple_add";
170 auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
171 auto maketuple_monad = bprop_fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), monad});
172 auto tuple_add_monad = bprop_fg->NewCNode({tuple_add_ops, output, maketuple_monad});
173 bprop_fg->set_output(tuple_add_monad);
174 } else {
175 output->add_input(monad);
176 }
177 }
178
AppendMonadOutput(const FuncGraphPtr & bprop_fg,const AnfNodePtr & monad)179 static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &monad) {
180 const auto &output = bprop_fg->output();
181 MS_EXCEPTION_IF_NULL(output);
182 auto output_cnode = output->cast<CNodePtr>();
183 if (output_cnode != nullptr) {
184 // If output_cnode has the form like (make_tuple, x, y).
185 while (output_cnode->IsApply(prim::kPrimDepend)) {
186 const auto &real_input = output_cnode->input(kRealInputIndexInDepend);
187 MS_EXCEPTION_IF_NULL(real_input);
188 output_cnode = real_input->cast<CNodePtr>();
189 }
190 }
191 constexpr char u_monad_in_output[] = "u_monad_in_output";
192 constexpr char io_monad_in_output[] = "io_monad_in_output";
193 if (output_cnode != nullptr) {
194 if (HasAbstractUMonad(monad) && !bprop_fg->has_flag(u_monad_in_output)) {
195 AddMonad(bprop_fg, output_cnode, monad);
196 bprop_fg->set_flag(u_monad_in_output, true);
197 } else if (HasAbstractIOMonad(monad) && !bprop_fg->has_flag(io_monad_in_output)) {
198 AddMonad(bprop_fg, output_cnode, monad);
199 bprop_fg->set_flag(io_monad_in_output, true);
200 }
201 return;
202 }
203 // If output is an empty tuple, create a (make_tuple, monad) as the new output.
204 auto make_tuple = NewValueNode(prim::kPrimMakeTuple);
205 output_cnode = bprop_fg->NewCNode({make_tuple, monad});
206 if (HasAbstractUMonad(monad)) {
207 bprop_fg->set_flag(u_monad_in_output, true);
208 } else if (HasAbstractIOMonad(monad)) {
209 bprop_fg->set_flag(io_monad_in_output, true);
210 }
211 bprop_fg->set_output(output_cnode);
212 }
213
214 // Append U or/and IO monad to output of Bprop funcgraph.
AdjustForAutoMonad(const PrimitivePtr & prim,const FuncGraphPtr & bprop_fg)215 static void AdjustForAutoMonad(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
216 auto effect_info = GetPrimEffectInfo(prim);
217 if (effect_info.memory) {
218 MS_LOG(DEBUG) << "Append U monad for Bprop FuncGraph of Primitive " << prim->ToString();
219 auto u = NewValueNode(kUMonad);
220 u->set_abstract(kUMonad->ToAbstract());
221 AppendMonadOutput(bprop_fg, u);
222 }
223 if (effect_info.io) {
224 MS_LOG(DEBUG) << "Append IO monad for Bprop FuncGraph of Primitive " << prim->ToString();
225 auto io = NewValueNode(kIOMonad);
226 io->set_abstract(kIOMonad->ToAbstract());
227 AppendMonadOutput(bprop_fg, io);
228 }
229 }
230
GeneratePrimalDebugInfo(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)231 std::vector<NodeDebugInfoPtr> GeneratePrimalDebugInfo(const ValueNodePtr &value_node,
232 const pipeline::ResourceBasePtr &resources) {
233 std::vector<NodeDebugInfoPtr> primal_debug_infos;
234 if (resources != nullptr) {
235 auto manager = resources->manager();
236 auto &users = manager->node_users()[value_node];
237 for (auto user_iter = users.begin(); user_iter != users.end(); ++user_iter) {
238 primal_debug_infos.push_back(user_iter->first->debug_info());
239 }
240 }
241 return primal_debug_infos;
242 }
243
SetDumpFlag(const PrimitivePtr & prim,const FuncGraphPtr & bprop_fg)244 void SetDumpFlag(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
245 if (prim == nullptr || bprop_fg == nullptr) {
246 return;
247 }
248 auto attr = prim->GetAttr(kAttrDump);
249 if (attr != nullptr) {
250 if (attr->isa<StringImm>()) {
251 auto str_attr = attr->cast_ptr<StringImm>();
252 MS_EXCEPTION_IF_NULL(str_attr);
253 if (str_attr->value() == kValueTrue) {
254 bprop_fg->set_flag(FUNC_GRAPH_FLAG_DUMP, true);
255 }
256 }
257 }
258 }
259
KPrimitive(const CNodePtr & cnode,const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)260 FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
261 const pipeline::ResourceBasePtr &resources) {
262 if (!IsValueNode<Primitive>(value_node)) {
263 MS_LOG(INTERNAL_EXCEPTION) << "Primitive node is not valid.";
264 }
265
266 auto prim = GetValueNode<PrimitivePtr>(value_node);
267 if (IsPrimitiveEquals(prim, prim::kPrimSwitchLayer)) {
268 auto fprop = GetFprop(prim);
269 fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
270 return fprop;
271 } else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList) ||
272 IsPrimitiveEquals(prim, prim::kPrimMakeDict) || IsPrimitiveEquals(prim, prim::kPrimMutable)) {
273 // Return null to use Meta bprop.
274 return nullptr;
275 }
276
277 FuncGraphPtr bprop_fg = nullptr;
278 if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
279 if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
280 MS_LOG(EXCEPTION)
281 << "The Hook operation is not supported in graph mode, which is only supported in pynative mode.\n"
282 << trace::GetDebugInfoStr(cnode->debug_info());
283 }
284 bprop_fg = BpropCut(value_node, resources);
285 } else {
286 bprop_fg = GetPrimBprop(prim, value_node, resources, cnode);
287 }
288 MS_EXCEPTION_IF_NULL(bprop_fg);
289 MS_EXCEPTION_IF_NULL(bprop_fg->return_node());
290
291 SetDumpFlag(prim, bprop_fg);
292 AdjustForAutoMonad(prim, bprop_fg);
293 mindspore::HashMap<std::string, ValuePtr> primal_attrs;
294 std::vector<NodeDebugInfoPtr> primal_debug_infos = GeneratePrimalDebugInfo(value_node, resources);
295 if (cnode != nullptr) {
296 primal_attrs = cnode->primal_attrs();
297 cnode->AddPrimalAttr(kPrimalAttrUniqueId, MakeValue(cnode->UniqueId()));
298 const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
299 primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
300 primal_attrs[kPrimalAttrForwardUniqueId] = MakeValue(cnode->UniqueId());
301 }
302 auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos);
303 if (expanded_fg == nullptr) {
304 MS_LOG(INTERNAL_EXCEPTION) << "Failed convert " << prim->name()
305 << " prim bprop function to J expanded func graph. NodeInfo: "
306 << trace::GetDebugInfoStr(bprop_fg->debug_info());
307 }
308 if (lift_fv_before_grad && IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
309 // Inline fprop_switch before renormalize;
310 expanded_fg->set_flag(FUNC_GRAPH_FLAG_FORCE_INLINE, true);
311 MS_LOG(DEBUG) << "set force_inline for fg: " << expanded_fg->ToString();
312 }
313
314 return expanded_fg;
315 }
316
BuildOutput(const FuncGraphPtr & bprop_fg,const FuncGraphPtr & current_primal_fg) const317 AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg) const {
318 // The primal fg may have extra parameters from lifted fv or u_monad and io_monad.
319 std::vector<AnfNodePtr> extra_lifted_args;
320 std::vector<AnfNodePtr> extra_monad_args;
321 // caller had checked size() - 2 is greater than 0.
322 auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
323 if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) {
324 auto current_primal_fg_param_size = current_primal_fg->parameters().size();
325 MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so "
326 "Insert it. Extra parameters size: "
327 << current_primal_fg_param_size - bprop_fg_param_size;
328 // The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
329 for (size_t i = 0; i < current_primal_fg_param_size; ++i) {
330 auto primal_parameter = dyn_cast<Parameter>(current_primal_fg->parameters()[i]);
331 MS_EXCEPTION_IF_NULL(primal_parameter);
332 auto lifted = primal_parameter->user_data<bool>(kLiftedUserDataKey);
333 if (lifted == nullptr || !*lifted) {
334 break;
335 }
336 extra_lifted_args.push_back(
337 bprop_fg->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), primal_parameter}));
338 ++bprop_fg_param_size;
339 }
340 for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) {
341 const auto &primal_node = current_primal_fg->parameters()[i];
342 AnfNodePtr extra_node;
343 // Simplify zeros_like(primal_node) to U or IO, so extra_node in bprop_fg will not refer to primal_node
344 // as a free variable of primal_graph.
345 // Notes: if the implementation of zeros_like changes, here too.
346 if (HasAbstractUMonad(primal_node)) {
347 extra_node = NewValueNode(kUMonad);
348 } else if (HasAbstractIOMonad(primal_node)) {
349 extra_node = NewValueNode(kIOMonad);
350 } else {
351 MS_EXCEPTION(TypeError)
352 << "The params of function 'bprop' of Primitive or Cell requires the forward inputs as well "
353 "as the 'out' and 'dout'.\n"
354 << trace::GetDebugInfoStr(bprop_fg->debug_info());
355 }
356 extra_monad_args.push_back(extra_node);
357 MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString();
358 }
359 }
360 // bprop_fg has been checked in caller
361 if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
362 // Set bprop output as (env, dx, dy, dz, ...)
363 auto cbprop = bprop_fg->output()->cast_ptr<CNode>();
364 auto &inputs = cbprop->inputs();
365
366 std::vector<AnfNodePtr> args;
367 args.push_back(NewValueNode(prim::kPrimMakeTuple));
368 args.push_back(NewEnviron(bprop_fg));
369 // The lifted parameters are put in front.
370 if (!extra_lifted_args.empty()) {
371 (void)args.insert(args.cend(), extra_lifted_args.cbegin(), extra_lifted_args.cend());
372 }
373 (void)args.insert(args.cend(), inputs.cbegin() + 1, inputs.cend());
374 if (!extra_monad_args.empty()) {
375 (void)args.insert(args.cend(), extra_monad_args.cbegin(), extra_monad_args.cend());
376 }
377 return NewCNode(args, bprop_fg);
378 }
379
380 // Set bprop output as (env, dx)
381 constexpr char model_name[] = "mindspore.ops.composite.multitype_ops.add_impl";
382 constexpr char python_ops[] = "_tuple_add";
383 auto bprop_tuple_add_check_func = std::make_shared<std::function<bool(const std::vector<AbstractBasePtr> &args)>>(
384 [](const std::vector<AbstractBasePtr> &args) {
385 for (const auto &arg : args) {
386 if (!arg->isa<abstract::AbstractTuple>()) {
387 MS_EXCEPTION(TypeError) << "For bprop function, output should be a tuple, but got " << arg->ToString();
388 }
389 }
390 return true;
391 });
392 auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewEnviron(bprop_fg)}, bprop_fg);
393 auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
394 if (IsValueNode<FuncGraphBase>(tuple_add_ops)) {
395 auto tuple_add_func_graph = GetValueNode<FuncGraphBasePtr>(tuple_add_ops);
396 MS_LOG(DEBUG) << "Get tuple add func successful. Tuple add fg: " << tuple_add_func_graph->ToString();
397 auto checker = std::make_shared<FuncGraphChecker>();
398 checker->AddCheckFunc<const std::vector<AbstractBasePtr> &>(bprop_tuple_add_check_func);
399 tuple_add_func_graph->AddChecker("check_infer_inputs", checker);
400 }
401
402 if (!extra_lifted_args.empty()) {
403 (void)extra_lifted_args.insert(extra_lifted_args.cbegin(), NewValueNode(prim::kPrimMakeTuple));
404 auto extra_tuple = NewCNode(extra_lifted_args, bprop_fg);
405 tuple_env = NewCNode({tuple_add_ops, tuple_env, extra_tuple}, bprop_fg);
406 }
407 if (!extra_monad_args.empty()) {
408 (void)extra_monad_args.insert(extra_monad_args.cbegin(), NewValueNode(prim::kPrimMakeTuple));
409 auto extra_tuple = NewCNode(extra_monad_args, bprop_fg);
410 auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg);
411 return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg);
412 }
413
414 return NewCNode({tuple_add_ops, tuple_env, bprop_fg->output()}, bprop_fg);
415 }
416
TransformNormalArgs(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * transf_args)417 static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
418 std::vector<AnfNodePtr> *transf_args) {
419 MS_EXCEPTION_IF_NULL(mng);
420 // bprop_fg has been checked in caller
421 // transform except the last 2 parameters: out, dout.
422 const size_t last_parameter_sizes = 2;
423 auto bprop_fg_param_size = bprop_fg->parameters().size() - last_parameter_sizes;
424 for (size_t i = 0; i < bprop_fg_param_size; ++i) {
425 auto p = bprop_fg->parameters()[i];
426 MS_EXCEPTION_IF_NULL(p);
427
428 TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
429 auto transf_p = outer->add_parameter();
430
431 (void)mng->Replace(p, transf_p);
432 transf_args->push_back(transf_p);
433 }
434 }
TransformArgsForPrimitive(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const PrimitivePtr & primitive,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * const transf_args) const435 void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
436 const PrimitivePtr &primitive, const FuncGraphPtr &outer,
437 std::vector<AnfNodePtr> *const transf_args) const {
438 TransformNormalArgs(mng, bprop_fg, outer, transf_args);
439 // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter.
440 auto effect_info = GetPrimEffectInfo(primitive);
441 if (effect_info.memory) {
442 MS_LOG(DEBUG) << "Append U monad to Fprop FuncGraph for Primitive " << primitive->ToString();
443 auto transf_p = outer->add_parameter();
444 transf_args->push_back(transf_p);
445 }
446 if (effect_info.io) {
447 MS_LOG(DEBUG) << "Append IO monad to Fprop FuncGraph for Primitive " << primitive->ToString();
448 auto transf_p = outer->add_parameter();
449 transf_args->push_back(transf_p);
450 }
451 }
452
453 template <typename T>
TransformArgsForFuncGraph(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const T & current_primal_fg,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * const transf_args) const454 void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
455 const T ¤t_primal_fg, const FuncGraphPtr &outer,
456 std::vector<AnfNodePtr> *const transf_args) const {
457 constexpr size_t need_filter_size = 2;
458 auto bprop_fg_param_size = bprop_fg->parameters().size() - need_filter_size;
459 const auto ¤t_primal_fg_params = current_primal_fg->parameters();
460 // The lifted parameters are put in front: {lifted parameters, origin parameters, u/io monad}.
461 for (size_t i = 0; i < current_primal_fg_params.size(); ++i) {
462 auto primal_parameter = dyn_cast_ptr<Parameter>(current_primal_fg_params[i]);
463 MS_EXCEPTION_IF_NULL(primal_parameter);
464 auto lifted = primal_parameter->template user_data<bool>(kLiftedUserDataKey);
465 if (lifted == nullptr || !*lifted) {
466 break;
467 }
468 TraceGuard trace_guard(std::make_shared<TraceGradFprop>(primal_parameter->debug_info()));
469 auto transf_p = outer->add_parameter();
470 transf_args->push_back(transf_p);
471 ++bprop_fg_param_size;
472 }
473 TransformNormalArgs(mng, bprop_fg, outer, transf_args);
474 // Current primal fg may have extra parameters after AutoMonad
475 if (bprop_fg_param_size < current_primal_fg_params.size()) {
476 for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) {
477 auto p = current_primal_fg_params[i];
478 MS_EXCEPTION_IF_NULL(p);
479 // extra parameters should be Monad.
480 if (!HasAbstractMonad(p)) {
481 continue;
482 }
483 MS_LOG(DEBUG) << "Function " << current_primal_fg->ToString()
484 << ", has extra monad parameter: " << p->DebugString()
485 << ", abstract: " << p->abstract()->ToString();
486
487 TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
488 auto transf_p = outer->add_parameter();
489 // See also Notes on extra_node of BuildOutput.
490 // Notes: No need to replace p with transf_p as the only use of p is here.
491 // If extra_node in bprop_fg use p as free variable, a replacement of p is required here.
492 // This replacement will make the usage of p in current_primal_fg got replaced with transf_p
493 // of outer. outer will be released after it is being cloned to fprop_fg, so the func_graph_
494 // in transf_p will be nullptr.
495 // So the RULE is DONT tamper the current_primal_fg;
496 transf_args->push_back(transf_p);
497 }
498 }
499 if (transf_args->size() != current_primal_fg_params.size()) {
500 MS_EXCEPTION(TypeError) << "Function " << current_primal_fg->ToString()
501 << ", The number of parameter of this primal function is "
502 << current_primal_fg_params.size() << ", but the number of parameters of bprop is "
503 << bprop_fg_param_size;
504 }
505 }
506
CheckBprop(const FuncGraphPtr & bprop_fg,const string & prim_to_check) const507 void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) const {
508 TraceGuard guard(std::make_shared<TraceCopy>(bprop_fg->return_node()->debug_info()));
509 auto context = MsContext::GetInstance();
510 MS_EXCEPTION_IF_NULL(context);
511 bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
512 // Skip checking if check_bprop not set
513 if (!check_bprop_flag) {
514 return;
515 }
516
517 // bprop_fg has been checked in caller
518 auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations._inner_ops");
519 MS_EXCEPTION_IF_NULL(check_bprop_class);
520 auto check_bprop =
521 bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
522
523 std::vector<AnfNodePtr> inputs;
524 inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
525 constexpr int primitive_size = 1;
526 constexpr int brprop_offset_size = 2;
527 (void)inputs.insert(inputs.cbegin() + primitive_size, bprop_fg->parameters().cbegin(),
528 bprop_fg->parameters().cend() - brprop_offset_size);
529 AnfNodePtr params = bprop_fg->NewCNode(inputs);
530
531 inputs.clear();
532 inputs.push_back(check_bprop);
533 inputs.push_back(bprop_fg->output());
534 inputs.push_back(params);
535 AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
536 bprop_fg->set_output(bprop_out);
537 }
538
// Converts a user-defined Cell bprop graph into its K expanded FuncGraph.
// The primal graph is read from the "primal" transform of `bprop_fg`.
// Raises an internal exception when that transform is missing or when the conversion fails.
FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  // primal_fg is FuncGraph just after convert. Refer ConvertCellObjToFuncGraph.
  // current_primal_fg is specialized and AutoMonaded primal_fg;
  const auto &transforms = bprop_fg->transforms();
  const auto primal_iter = transforms.find("primal");
  if (primal_iter == transforms.end()) {
    // Dereferencing end() would be undefined behaviour, so report the missing transform explicitly.
    MS_LOG(INTERNAL_EXCEPTION) << "The bprop func graph " << bprop_fg->ToString() << " has no 'primal' transform.";
  }
  auto primal_fg = primal_iter->second.func_graph();
  auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr, {}, {});
  if (expanded_fg == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Failed convert " << primal_fg->ToString()
                               << " Cell bprop function to K expanded func graph. NodeInfo: "
                               << trace::GetDebugInfoStr(primal_fg->debug_info());
  }
  return expanded_fg;
}
552
BpropCut(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources) const553 FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const {
554 auto prim = GetValueNode<PrimitivePtr>(value_node);
555 MS_EXCEPTION_IF_NULL(prim);
556 auto &node_users = resources->manager()->node_users();
557
558 auto &users = node_users[value_node];
559 auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
560 return IsPrimitiveCNode(user.first, prim);
561 });
562 if (cnode == users.end()) {
563 MS_LOG(INTERNAL_EXCEPTION) << "Fail to find cnode.";
564 }
565 auto cnode_first = cnode->first->cast_ptr<CNode>();
566 MS_EXCEPTION_IF_NULL(cnode_first);
567 auto inputs_num = cnode_first->size() - 1;
568
569 auto func_graph = std::make_shared<FuncGraph>();
570 std::vector<AnfNodePtr> outputs;
571 auto prim_py = prim->cast<PrimitivePyPtr>();
572 MS_EXCEPTION_IF_NULL(prim_py);
573 auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
574 bprop_cut->CopyHookFunction(prim_py);
575
576 auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
577 if (cell_id != "") {
578 (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
579 (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
580 }
581
582 outputs.push_back(NewValueNode(bprop_cut));
583 for (size_t i = 0; i < inputs_num; ++i) {
584 auto param = func_graph->add_parameter();
585 outputs.push_back(param);
586 }
587 auto p1 = func_graph->add_parameter();
588 auto p2 = func_graph->add_parameter();
589 outputs.push_back(p1);
590 outputs.push_back(p2);
591
592 func_graph->set_output(func_graph->NewCNode(outputs));
593 return func_graph;
594 }
595
FakeBprop(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources) const596 FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) const {
597 auto prim = value_node->value()->cast<PrimitivePtr>();
598 MS_EXCEPTION_IF_NULL(prim);
599 auto &node_users = resources->manager()->node_users();
600
601 auto &users = node_users[value_node];
602 auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
603 return IsPrimitiveCNode(user.first, prim);
604 });
605 if (cnode == users.end()) {
606 MS_LOG(INTERNAL_EXCEPTION) << "Fail to find user for " << prim->ToString();
607 }
608 auto cnode_first = cnode->first->cast_ptr<CNode>();
609 MS_EXCEPTION_IF_NULL(cnode_first);
610 auto inputs_num = cnode_first->size() - 1;
611 auto effect_info = GetPrimEffectInfo(prim);
612 // Don't add U or IO monad parameters as it will be added later.
613 size_t monad_params_size = 0;
614 if (effect_info.memory) {
615 monad_params_size++;
616 }
617 if (effect_info.io) {
618 monad_params_size++;
619 }
620 if (inputs_num < monad_params_size) {
621 MS_LOG(INTERNAL_EXCEPTION) << "Arguments number should be greater than or equal to " << monad_params_size
622 << ", but the CNode is: " << cnode->first->DebugString();
623 }
624 inputs_num -= monad_params_size;
625
626 auto func_graph = std::make_shared<FuncGraph>();
627 std::vector<AnfNodePtr> outputs;
628 outputs.push_back(NewValueNode(prim::kPrimMakeTuple));
629
630 auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
631 (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
632
633 for (size_t i = 0; i < inputs_num; ++i) {
634 // Mock params for inputs
635 auto param = func_graph->add_parameter();
636 // Mock derivatives for each inputs
637 if (IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
638 outputs.push_back(func_graph->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), param}));
639 } else {
640 outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param}));
641 }
642 }
643 // mock params for out and dout
644 (void)func_graph->add_parameter();
645 (void)func_graph->add_parameter();
646 func_graph->set_output(func_graph->NewCNode(outputs));
647 return func_graph;
648 }
649
CheckCustomVjp(const FuncGraphPtr & bprop_fg) const650 bool KPrim::CheckCustomVjp(const FuncGraphPtr &bprop_fg) const {
651 MS_EXCEPTION_IF_NULL(bprop_fg);
652 auto parameters_size = bprop_fg->parameters().size();
653 if (bprop_fg->has_flag("custom_vjp") && parameters_size == 1) {
654 return true;
655 }
656 return false;
657 }
658
// Extracts the bprop FuncGraph of a @custom_vjp function from input 1 of the output CNode of `bprop_fg`
// and gives it the transforms of `bprop_fg`.
// Raises TypeError when that input is None (bprop not defined) or does not hold a FuncGraph.
FuncGraphPtr KPrim::GetCustomVjpBprop(const FuncGraphPtr &bprop_fg) const {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  auto output_cnode = dyn_cast<CNode>(bprop_fg->output());
  MS_EXCEPTION_IF_NULL(output_cnode);
  // Check the definition of the bprop function
  if (IsValueNode<None>(output_cnode->input(1))) {
    MS_EXCEPTION(TypeError)
      << "The bprop function of @custom_vjp is undefined. Please use 'defbwd(bprop)' to define the 'bprop' function.";
  }

  auto custom_vjp_bprop_fg = GetValueNode<FuncGraphPtr>(output_cnode->input(1));
  if (custom_vjp_bprop_fg == nullptr) {
    MS_EXCEPTION(TypeError) << "The 'bprop' function defined by @custom_vjp defbwd(bprop) is illegal.";
  }
  custom_vjp_bprop_fg->set_transforms(bprop_fg->transforms());
  return custom_vjp_bprop_fg;
}
677 } // namespace ad
678 } // namespace mindspore
679