• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2020-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef _WIN32
20 #include <dirent.h>
21 #endif
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include "ir/anf.h"
26 #include "pybind_api/ir/primitive_py.h"
27 #include "ir/meta_func_graph.h"
28 #include "ir/func_graph_cloner.h"
29 #include "ir/manager.h"
30 #include "pipeline/jit/resource.h"
31 #include "pipeline/jit/parse/parse.h"
32 #include "pipeline/jit/parse/resolve.h"
33 #include "frontend/optimizer/ad/dfunctor.h"
34 #include "frontend/operator/ops.h"
35 #include "frontend/operator/composite/composite.h"
36 #include "utils/utils.h"
37 #include "utils/symbolic.h"
38 #include "utils/primitive_utils.h"
39 #include "utils/ms_context.h"
40 #include "utils/info.h"
41 #include "debug/trace.h"
42 #include "debug/common.h"
43 #include "debug/dump_proto.h"
44 #include "mindspore/core/load_mindir/load_model.h"
45 #include "utils/system/sha256.h"
46 #include "utils/file_utils.h"
47 
48 namespace mindspore {
49 namespace ad {
// Global KPrim instance defined in this translation unit.
KPrim g_k_prims;
51 
52 namespace {
// File name suffix of a cached bprop mindir file: <primitive name>_bprop.mindir.
constexpr char kBpropMindIRSuffix[] = "_bprop.mindir";
// Bprop mindir cache directory, appended to the directory returned by GetBpropDir().
constexpr char kBpropMindIRDir[] = "/../bprop_mindir/";
// Name of the python attribute (of bprop_mindir_module) that lists the primitives whose bprop is serializable.
constexpr char serializable_bprop_ops[] = "serializable_bprop_ops";
// Python module that holds the serializable bprop op list.
constexpr char bprop_mindir_module[] = "mindspore.ops.bprop_mindir";
57 
58 #ifndef _WIN32
GetBpropDir()59 std::string GetBpropDir() {
60   static std::string bprop_dir("");
61   if (bprop_dir.empty()) {
62     py::module mod = py::module::import("mindspore.ops._grad");
63     auto grad_file_path = mod.attr("__file__").cast<std::string>();
64     bprop_dir = grad_file_path.substr(0, grad_file_path.find_last_of('/'));
65   }
66   return bprop_dir;
67 }
68 
BpropMindirDirExists()69 bool BpropMindirDirExists() {
70   auto bprop_mindir_dir = GetBpropDir() + kBpropMindIRDir;
71   DIR *dir = opendir(bprop_mindir_dir.c_str());
72   if (dir != nullptr) {
73     if (closedir(dir) == -1) {
74       MS_LOG(WARNING) << "The bprop mindir dir \"" << bprop_mindir_dir << "\" close failed!";
75     }
76     return true;
77   }
78   MS_LOG(INFO) << "The bprop mindir dir \"" << bprop_mindir_dir << "\" doesn't exists.";
79   return false;
80 }
81 
82 // Get the serializable bprop list from the module mindspore.ops.bprop_mindir in python.
GetSerializableBpropList()83 std::unordered_set<std::string> GetSerializableBpropList() {
84   std::unordered_set<std::string> serializable_bprop_list;
85   if (!BpropMindirDirExists()) {
86     return serializable_bprop_list;
87   }
88   py::module mod = py::module::import(bprop_mindir_module);
89   py::object serializable_bprop_ops_attr = mod.attr(serializable_bprop_ops);
90   if (!py::isinstance<py::list>(serializable_bprop_ops_attr)) {
91     MS_LOG(WARNING) << "Can not get the the serializable bprop ops list from python, it is not a python list.";
92     return serializable_bprop_list;
93   }
94 
95   auto ops_list = serializable_bprop_ops_attr.cast<py::list>();
96   for (size_t i = 0; i < ops_list.size(); ++i) {
97     auto prim_adapter = ops_list[i].cast<PrimitivePyAdapterPtr>();
98     if (prim_adapter == nullptr) {
99       MS_LOG(EXCEPTION) << "The python obj in serializable bprop list should be a Primitive, but it is "
100                         << py::str(ops_list[i]);
101     }
102     (void)serializable_bprop_list.insert(prim_adapter->name());
103   }
104   return serializable_bprop_list;
105 }
106 
IsSerializableBprop(const std::string & prim_name)107 bool IsSerializableBprop(const std::string &prim_name) {
108   static std::unordered_set<std::string> serializable_bprop_list = GetSerializableBpropList();
109   return std::any_of(serializable_bprop_list.begin(), serializable_bprop_list.end(),
110                      [&prim_name](const std::string &serializable_bprop_prim_name) {
111                        return prim_name == serializable_bprop_prim_name;
112                      });
113 }
114 
GetBpropHash()115 std::string GetBpropHash() {
116   static std::string bprop_hash = std::string();
117   if (bprop_hash.empty()) {
118     auto bprop_dir = GetBpropDir();
119     auto realpath = FileUtils::GetRealPath(common::SafeCStr(bprop_dir));
120     if (!realpath.has_value()) {
121       MS_LOG(EXCEPTION) << "Get real path of bprop dir failed. path=" << bprop_dir;
122     }
123     bprop_hash = system::sha256::GetHashFromDir(realpath.value());
124   }
125   return bprop_hash;
126 }
127 
ImportBpropFromMindIR(const PrimitivePtr & prim)128 FuncGraphPtr ImportBpropFromMindIR(const PrimitivePtr &prim) {
129   MS_EXCEPTION_IF_NULL(prim);
130   std::string bprop_dir = GetBpropDir();
131   auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
132   std::optional<std::string> bprop_mindir_realpath =
133     FileUtils::GetRealPath(common::SafeCStr(bprop_mindir_path + prim->name() + kBpropMindIRSuffix));
134   bool bprop_cache_file_exists = bprop_mindir_realpath.has_value() && Common::FileExists(bprop_mindir_realpath.value());
135   if (!bprop_cache_file_exists) {
136     return nullptr;
137   }
138   auto bprop_fg = LoadMindIR(bprop_mindir_realpath.value());
139   if (bprop_fg != nullptr && bprop_fg->bprop_hash() != GetBpropHash()) {
140     MS_LOG(EXCEPTION) << "The bprop mindir files are not up to date. Please run the " << bprop_mindir_path
141                       << "generate_mindir.py to generate new mindir files.\n"
142                       << "bprop_fg hash: " << bprop_fg->bprop_hash() << "\n"
143                       << "bprop hash: " << GetBpropHash();
144   }
145   return bprop_fg;
146 }
147 
ExportBpropToMindIR(const PrimitivePtr & prim,const FuncGraphPtr & func_graph)148 void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_graph) {
149   MS_EXCEPTION_IF_NULL(prim);
150   std::string bprop_dir = GetBpropDir();
151   func_graph->set_bprop_hash(GetBpropHash());
152   auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
153   std::optional<std::string> bprop_mindir_realpath =
154     Common::CreatePrefixPath(bprop_mindir_path + prim->name() + kBpropMindIRSuffix, true);
155   if (!bprop_mindir_realpath.has_value()) {
156     MS_LOG(ERROR) << "Failed to get the realpath of bprop mindir: " << bprop_mindir_path << prim->name()
157                   << kBpropMindIRSuffix;
158     return;
159   }
160   std::ofstream fout(bprop_mindir_realpath.value());
161   mind_ir::ModelProto fg_model = GetBinaryProto(func_graph, false);
162   if (!fg_model.SerializeToOstream(&fout)) {
163     MS_LOG(WARNING) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
164                     << bprop_mindir_realpath.value() << "\".";
165   }
166   fout.close();
167   ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR);
168 }
169 
// If `prim` is one of the DoSignature primitives listed below ("S-Prim-zeros_like_leaf", "S-Prim-getitem"),
// builds a new CNode in `fg`. The CNode calls the mapped python op and keeps every input of `origin_node`
// except input 0. Returns nullptr for any other primitive.
AnfNodePtr GetPythonOps(const FuncGraphPtr &fg, const AnfNodePtr &origin_node, const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(prim);
  // DoSignaturePrimitive name -> {python op name, python module name}.
  // An empty module name selects the single-argument overload of prim::GetPythonOps.
  static const std::unordered_map<std::string, std::pair<std::string, std::string>> python_ops{
    {"S-Prim-zeros_like_leaf", {"zeros_like", ""}},
    {"S-Prim-getitem", {"getitem", "mindspore.ops.composite.multitype_ops.getitem_impl"}}};

  const auto found = python_ops.find(prim->name());
  if (found == python_ops.end()) {
    return nullptr;
  }
  const std::string &op_name = found->second.first;
  const std::string &module_name = found->second.second;
  ValuePtr op_value = module_name.empty() ? prim::GetPythonOps(op_name) : prim::GetPythonOps(op_name, module_name);

  auto origin_cnode = origin_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin_cnode);
  const auto &origin_inputs = origin_cnode->inputs();
  // Input 0 (the primitive) is replaced by the python op; the arguments are carried over unchanged.
  std::vector<AnfNodePtr> new_inputs{NewValueNode(op_value)};
  (void)new_inputs.insert(new_inputs.end(), origin_inputs.begin() + 1, origin_inputs.end());
  return fg->NewCNode(new_inputs);
}
195 
196 // Replace the nodes whose python obj of primitive is needed in the renormalize process,
197 // with the new created python ops, such as zeros_like.
ReplacePythonOps(const FuncGraphPtr & fg)198 void ReplacePythonOps(const FuncGraphPtr &fg) {
199   MS_EXCEPTION_IF_NULL(fg);
200   std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(fg->get_return());
201   for (const auto &node : all_nodes) {
202     MS_EXCEPTION_IF_NULL(node);
203     if (!node->isa<CNode>()) {
204       continue;
205     }
206     auto cnode = node->cast<CNodePtr>();
207     for (size_t i = 0; i < cnode->size(); ++i) {
208       auto prim = GetCNodePrimitive(cnode->input(i));
209       if (prim == nullptr) {
210         continue;
211       }
212       auto new_input = GetPythonOps(fg, cnode->input(i), prim);
213       if (new_input == nullptr) {
214         continue;
215       }
216       cnode->set_input(i, new_input);
217     }
218   }
219 }
220 #endif
221 }  // namespace
222 
223 #ifndef _WIN32
224 // Given a python primitive, export a mindir file from the bprop defined in python.
ExportBpropMindir(const py::object & obj)225 void KPrim::ExportBpropMindir(const py::object &obj) {
226   auto prim_adapter = obj.cast<PrimitivePyAdapterPtr>();
227   if (prim_adapter == nullptr) {
228     MS_LOG(EXCEPTION) << "The python obj to be exported to bprop mindir should be a Primitive, but it is "
229                       << py::str(obj);
230   }
231   auto prim = prim_adapter->attached_primitive();
232   if (prim == nullptr) {
233     prim = std::make_shared<PrimitivePy>(obj, prim_adapter);
234     prim_adapter->set_attached_primitive(prim);
235   }
236 
237   // Get the bprop function from python.
238   py::function fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
239   if (py::isinstance<py::none>(fn)) {
240     fn = GetBpropFunction(prim->name());
241   }
242   if (!fn || py::isinstance<py::none>(fn)) {
243     MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
244   }
245   auto func_graph = parse::ParsePythonCode(fn);
246   if (func_graph == nullptr) {
247     MS_LOG(EXCEPTION) << "Fail to parse bprop function for " << prim->name() << ".";
248   }
249   auto res = std::make_shared<pipeline::Resource>();
250   (void)parse::ResolveFuncGraph(func_graph, res);
251   ExportBpropToMindIR(prim, func_graph);
252 }
253 #endif
254 
GetBprop(const PrimitivePtr & prim,const pipeline::ResourceBasePtr & resources)255 FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
256   // Set a child scope named "grad'PrimitiveName'" for the bprop function,
257   // and add "Gradients" to the front.
258   static const std::string gradients_scope = "Gradients/";
259   static const std::string grad_op_child_scope_prefix = "/grad";
260   MS_EXCEPTION_IF_NULL(prim);
261   auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
262                                        grad_op_child_scope_prefix + prim->name());
263   ScopeGuard scope_guard(scope);
264 
265   // Firstly we get bprop from mindir. If failed, parse the python function registered.
266   FuncGraphPtr func_graph = nullptr;
267 #ifndef _WIN32
268   bool serializable = IsSerializableBprop(prim->name());
269   if (serializable) {
270     func_graph = ImportBpropFromMindIR(prim);
271     if (func_graph != nullptr) {
272       ReplacePythonOps(func_graph);
273       return func_graph;
274     }
275   }
276 #endif
277   py::function fn;
278   if (prim->is_base()) {
279     fn = GetBpropFunction(prim->name());
280   } else {
281     fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
282     if (py::isinstance<py::none>(fn)) {
283       fn = GetBpropFunction(prim->name());
284     }
285   }
286   if (!fn || py::isinstance<py::none>(fn)) {
287     MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
288     return nullptr;
289   }
290   func_graph = parse::ParsePythonCode(fn);
291   if (func_graph == nullptr) {
292     MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
293     return nullptr;
294   }
295   auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
296   if (bprop_flag) {
297     func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
298   }
299   pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
300   (void)parse::ResolveFuncGraph(func_graph, res);
301 #ifndef _WIN32
302   // Check whether the bprop needs to be exported.
303   if (serializable) {
304     ExportBpropToMindIR(prim, func_graph);
305   }
306 #endif
307   return func_graph;
308 }
309 
GetPossibleBprop(const PrimitivePtr & prim)310 FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) {
311   FuncGraphPtr bprop_fg = nullptr;
312   auto iter = bprop_registry_.find(prim);
313   if (iter != bprop_registry_.end()) {
314     bprop_fg = iter->second;
315   }
316 
317   if (bprop_fg == nullptr) {
318     bprop_fg = GetBprop(prim);
319     if (bprop_fg != nullptr) {
320       // Set bprop_g graph cache
321       bprop_registry_[prim] = bprop_fg;
322     }
323   }
324   return bprop_fg;
325 }
326 
GetFprop(const PrimitivePtr & prim)327 FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
328   static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
329   std::string func_name = "_fprop_" + prim->name();
330   py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
331   auto func_graph = parse::ParsePythonCode(fn);
332   MS_EXCEPTION_IF_NULL(func_graph);
333   return BasicClone(func_graph);
334 }
335 
KMetaFuncGraph(const PrimitivePtr & prim)336 MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
337   MS_EXCEPTION_IF_NULL(prim);
338 
339   auto iter = bprop_registry_meta_.find(prim);
340   if (iter != bprop_registry_meta_.end()) {
341     return iter->second;
342   }
343 
344   if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
345     MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
346     bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
347     return meta;
348   }
349 
350   if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
351     MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
352     bprop_registry_meta_[prim::kPrimMakeList] = meta;
353     return meta;
354   }
355 
356   MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
357 }
358 
AppendMonadOutput(const FuncGraphPtr & bprop_fg,const AnfNodePtr & monad)359 static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &monad) {
360   const auto &output = bprop_fg->output();
361   MS_EXCEPTION_IF_NULL(output);
362   auto output_cnode = output->cast<CNodePtr>();
363   if (output_cnode != nullptr) {
364     // If output_cnode has the form like (make_tuple, x, y).
365     output_cnode->add_input(monad);
366     return;
367   }
368   // If output is an empty tuple, create a (make_tuple, monad) as the new output.
369   auto make_tuple = NewValueNode(prim::kPrimMakeTuple);
370   output_cnode = bprop_fg->NewCNode({make_tuple, monad});
371   bprop_fg->set_output(output_cnode);
372 }
373 
374 // Append U or/and IO monad to output of Bprop funcgraph.
AdjustForAutoMonad(const PrimitivePtr & prim,const FuncGraphPtr & bprop_fg)375 static void AdjustForAutoMonad(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
376   auto effect_info = GetPrimEffectInfo(prim);
377   if (effect_info.memory) {
378     MS_LOG(DEBUG) << "Append U monad for Bprop FuncGraph of Primitive " << prim->ToString();
379     auto u = NewValueNode(kUMonad);
380     u->set_abstract(kUMonad->ToAbstract());
381     AppendMonadOutput(bprop_fg, u);
382   }
383   if (effect_info.io) {
384     MS_LOG(DEBUG) << "Append IO monad for Bprop FuncGraph of Primitive " << prim->ToString();
385     auto io = NewValueNode(kIOMonad);
386     io->set_abstract(kIOMonad->ToAbstract());
387     AppendMonadOutput(bprop_fg, io);
388   }
389 }
390 
GetBprop(const CNodePtr & cnode,const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources,const PrimitivePtr & prim)391 FuncGraphPtr KPrim::GetBprop(const CNodePtr &cnode, const ValueNodePtr &value_node,
392                              const pipeline::ResourceBasePtr &resources, const PrimitivePtr &prim) {
393   FuncGraphPtr bprop_fg = nullptr;
394   if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
395     if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
396       MS_LOG(EXCEPTION)
397         << "The Primitive 'HookBackward' is not supported in graph mode, which is only supported in pynative mode.\n"
398         << trace::GetDebugInfo(cnode->debug_info());
399     }
400     bprop_fg = BpropCut(value_node, resources);
401   } else {
402     auto iter = bprop_registry_.find(prim);
403     if (iter != bprop_registry_.end()) {
404       bprop_fg = iter->second;
405     }
406 
407     if (bprop_fg == nullptr) {
408       bprop_fg = GetBprop(prim, resources);
409       if (bprop_fg != nullptr) {
410         // Set bprop_g graph cache
411         bprop_registry_[prim] = bprop_fg;
412       } else {
413         bprop_fg = FakeBprop(value_node, resources);
414       }
415     }
416   }
417   return bprop_fg;
418 }
419 
KPrimitive(const CNodePtr & cnode,const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)420 FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
421                                const pipeline::ResourceBasePtr &resources) {
422   if (!IsValueNode<Primitive>(value_node)) {
423     MS_LOG(EXCEPTION) << "Primitive node is not valid.";
424   }
425 
426   auto prim = GetValueNode<PrimitivePtr>(value_node);
427   if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
428     auto fprop = GetFprop(prim);
429     fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
430     return fprop;
431   } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
432     return nullptr;
433   } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
434     return nullptr;
435   }
436 
437   FuncGraphPtr bprop_fg = GetBprop(cnode, value_node, resources, prim);
438   AdjustForAutoMonad(prim, bprop_fg);
439   std::unordered_map<std::string, ValuePtr> primal_attrs;
440   std::vector<NodeDebugInfoPtr> primal_debug_infos;
441   if (resources != nullptr) {
442     auto manager = resources->manager();
443     auto &users = manager->node_users()[value_node];
444     for (auto user_iter = users.begin(); user_iter != users.end(); ++user_iter) {
445       primal_debug_infos.push_back(user_iter->first->debug_info());
446     }
447   }
448   if (cnode != nullptr) {
449     primal_attrs = cnode->primal_attrs();
450     const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
451     primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
452   }
453   auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos);
454   if (expanded_fg == nullptr) {
455     MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
456                       << " prim bprop function to J expanded func graph. NodeInfo: "
457                       << trace::GetDebugInfo(bprop_fg->debug_info());
458   }
459   if (lift_fv_before_grad && IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
460     // Inline fprop_switch before renormalize;
461     expanded_fg->set_flag(FUNC_GRAPH_FLAG_FORCE_INLINE, true);
462     MS_LOG(DEBUG) << "set force_inline for fg: " << expanded_fg->ToString();
463   }
464 
465   return expanded_fg;
466 }
467 
// Builds, inside bprop_fg, the node used as the new bprop output with `newenv` prepended:
// - If bprop_fg's output is a (make_tuple, dx, dy, ...) CNode, returns (make_tuple, newenv, dx, dy, ...).
// - Otherwise returns _tuple_add((make_tuple, newenv), output) via the python op from
//   mindspore.ops.composite.multitype_ops.add_impl. This presumes the output is tuple-valued (TODO confirm).
// When current_primal_fg has more parameters than bprop_fg's inputs (all parameters except the
// trailing two), one U or IO monad value node per extra parameter is appended at the end.
// Raises TypeError if such an extra parameter is neither a U nor an IO monad.
AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  // current_primal_fg may have extra parameters like u_monad, io_monad
  std::vector<AnfNodePtr> extra_args;
  // caller had checked size() - 2 is greater than 0.
  auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
  if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) {
    auto current_primal_fg_param_size = current_primal_fg->parameters().size();
    MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so "
                     "Insert it. Extra parameters size: "
                  << current_primal_fg_param_size - bprop_fg_param_size;
    // Map each extra primal parameter to a monad value node of the same kind.
    for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) {
      const auto &primal_node = current_primal_fg->parameters()[i];
      AnfNodePtr extra_node;
      // Simplify zeros_like(primal_node) to U or IO, so extra_node in bprop_fg will not refer to primal_node
      // as a free variable of primal_graph.
      // Notes: if the implementation of zeros_like changes, here too.
      if (HasAbstractUMonad(primal_node)) {
        extra_node = NewValueNode(kUMonad);
      } else if (HasAbstractIOMonad(primal_node)) {
        extra_node = NewValueNode(kIOMonad);
      } else {
        MS_EXCEPTION(TypeError)
          << "The params of function 'bprop' of Primitive or Cell requires the forward inputs as well "
             "as the 'out' and 'dout'.\n"
          << trace::GetDebugInfo(bprop_fg->debug_info());
      }
      extra_args.push_back(extra_node);
      MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString();
    }
  }
  // bprop_fg has been checked in caller
  if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
    // Set bprop output as (env, dx, dy, dz, ...)
    auto cbprop = bprop_fg->output()->cast<CNodePtr>();
    auto &inputs = cbprop->inputs();

    std::vector<AnfNodePtr> args;
    args.push_back(NewValueNode(prim::kPrimMakeTuple));
    args.push_back(NewValueNode(newenv));
    // Skip input 0 (the make_tuple primitive) of the original output.
    (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
    if (!extra_args.empty()) {
      args.insert(args.end(), extra_args.cbegin(), extra_args.cend());
    }
    return NewCNode(args, bprop_fg);
  }

  // Set bprop output as (env, dx)
  std::string model_name("mindspore.ops.composite.multitype_ops.add_impl");
  std::string python_ops("_tuple_add");
  auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg);
  auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
  if (!extra_args.empty()) {
    // Result: (env,) + (output + (monads...)).
    extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple));
    auto extra_tuple = NewCNode(extra_args, bprop_fg);
    auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg);
    return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg);
  }

  return NewCNode({tuple_add_ops, tuple_env, bprop_fg->output()}, bprop_fg);
}
528 
TransformNormalArgs(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * const transf_args)529 static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
530                                 std::vector<AnfNodePtr> *const transf_args) {
531   // bprop_fg has been checked in caller
532   // transform except the last 2 parameters: out, dout.
533   const size_t last_parameter_sizes = 2;
534   auto bprop_fg_param_size = bprop_fg->parameters().size() - last_parameter_sizes;
535   for (size_t i = 0; i < bprop_fg_param_size; ++i) {
536     auto p = bprop_fg->parameters()[i];
537     MS_EXCEPTION_IF_NULL(p);
538 
539     TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
540     auto transf_p = outer->add_parameter();
541 
542     (void)mng->Replace(p, transf_p);
543     transf_args->push_back(transf_p);
544   }
545 }
TransformArgsForPrimitive(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const PrimitivePtr & primitive,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * const transf_args)546 void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
547                                       const PrimitivePtr &primitive, const FuncGraphPtr &outer,
548                                       std::vector<AnfNodePtr> *const transf_args) {
549   MS_EXCEPTION_IF_NULL(mng);
550   TransformNormalArgs(mng, bprop_fg, outer, transf_args);
551   // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter.
552   auto effect_info = GetPrimEffectInfo(primitive);
553   if (effect_info.memory) {
554     MS_LOG(DEBUG) << "Append U monad to Fprop FuncGraph for Primitive " << primitive->ToString();
555     auto transf_p = outer->add_parameter();
556     transf_args->push_back(transf_p);
557   }
558   if (effect_info.io) {
559     MS_LOG(DEBUG) << "Append IO monad to Fprop FuncGraph for Primitive " << primitive->ToString();
560     auto transf_p = outer->add_parameter();
561     transf_args->push_back(transf_p);
562   }
563 }
564 
// Builds the parameters of `outer` for the user-defined bprop of a func graph.
// It first adds the parameters mirroring bprop_fg's inputs (TransformNormalArgs). It then adds one new
// parameter of `outer` for each parameter of current_primal_fg beyond bprop_fg's input count that has
// a monad abstract. Extra parameters without a monad abstract are skipped.
// Raises TypeError when transf_args ends up with a different size than current_primal_fg's parameters.
// NOTE(review): the comparison assumes transf_args is empty on entry — confirm with callers.
// NOTE(review): the error message reports bprop_fg_param_size, not transf_args->size().
template <typename T>
void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
                                      const T &current_primal_fg, const FuncGraphPtr &outer,
                                      std::vector<AnfNodePtr> *const transf_args) {
  MS_EXCEPTION_IF_NULL(mng);
  TransformNormalArgs(mng, bprop_fg, outer, transf_args);
  // The trailing two bprop parameters (out, dout) are not inputs of the primal function.
  constexpr size_t need_filter_size = 2;
  auto bprop_fg_param_size = bprop_fg->parameters().size() - need_filter_size;
  // current_primal_fg may have extra parameters after AutoMonad
  const auto &current_primal_fg_params = current_primal_fg->parameters();
  if (bprop_fg_param_size < current_primal_fg_params.size()) {
    for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) {
      auto p = current_primal_fg_params[i];
      MS_EXCEPTION_IF_NULL(p);
      // extra parameters should be Monad.
      if (!HasAbstractMonad(p)) {
        continue;
      }
      MS_LOG(DEBUG) << "Function " << current_primal_fg->ToString()
                    << ", has extra monad parameter: " << p->DebugString()
                    << ", abstract: " << p->abstract()->ToString();

      TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
      auto transf_p = outer->add_parameter();
      // See also Notes on extra_node of BuildOutput.
      // Notes: No need to replace p with transf_p as the only use of p is here.
      // If extra_node in bprop_fg use p as free variable, a replacement of p is required here.
      // This replacement will make the usage of p in current_primal_fg got replaced with transf_p
      // of outer. outer will be released after it is being cloned to fprop_fg, so the func_graph_
      // in transf_p will be nullptr.
      // So the RULE is DONT tamper the current_primal_fg;
      transf_args->push_back(transf_p);
    }
  }
  if (transf_args->size() != current_primal_fg_params.size()) {
    MS_EXCEPTION(TypeError) << "Function " << current_primal_fg->ToString()
                            << ", The number of parameter of this primal function is "
                            << current_primal_fg_params.size() << ", but the number of parameters of bprop is "
                            << bprop_fg_param_size;
  }
}
606 
CheckBprop(const FuncGraphPtr & bprop_fg,const string & prim_to_check)607 void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
608   auto context = MsContext::GetInstance();
609   MS_EXCEPTION_IF_NULL(context);
610   bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
611   // Skip checking if check_bprop not set
612   if (!check_bprop_flag) {
613     return;
614   }
615 
616   // bprop_fg has been checked in caller
617   auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
618   MS_EXCEPTION_IF_NULL(check_bprop_class);
619   auto check_bprop =
620     bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
621 
622   std::vector<AnfNodePtr> inputs;
623   inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
624   constexpr int primitive_size = 1;
625   constexpr int brprop_offset_size = 2;
626   (void)inputs.insert(inputs.begin() + primitive_size, bprop_fg->parameters().begin(),
627                       bprop_fg->parameters().end() - brprop_offset_size);
628   AnfNodePtr params = bprop_fg->NewCNode(inputs);
629 
630   inputs.clear();
631   inputs.push_back(check_bprop);
632   inputs.push_back(bprop_fg->output());
633   inputs.push_back(params);
634   AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
635   bprop_fg->set_output(bprop_out);
636 }
637 
// Converts the user-defined cell bprop graph `bprop_fg` into its K-expanded graph via BpropToK.
// primal_fg is FuncGraph just after convert. Refer ConvertCellObjToFuncGraph.
// current_primal_fg is specialized and AutoMonaded primal_fg.
// Raises if `bprop_fg` has no valid "primal" transform, or if the expansion fails.
FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  auto &transforms = bprop_fg->transforms();
  auto primal_iter = transforms.find("primal");
  // Dereferencing end() here would be undefined behavior; report the missing transform instead.
  if (primal_iter == transforms.end()) {
    MS_LOG(EXCEPTION) << "The bprop func graph " << bprop_fg->ToString() << " has no 'primal' transform.";
  }
  auto primal_fg = primal_iter->second.func_graph();
  MS_EXCEPTION_IF_NULL(primal_fg);
  auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr, {}, {});
  if (expanded_fg == nullptr) {
    MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString()
                      << " Cell bprop function to K expanded func graph. NodeInfo: "
                      << trace::GetDebugInfo(primal_fg->debug_info());
  }
  return expanded_fg;
}
651 
BpropCut(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)652 FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
653   auto prim = GetValueNode<PrimitivePtr>(value_node);
654   MS_EXCEPTION_IF_NULL(prim);
655   auto &node_users = resources->manager()->node_users();
656 
657   auto &users = node_users[value_node];
658   auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
659     return IsPrimitiveCNode(user.first, prim);
660   });
661   if (cnode == users.end()) {
662     MS_LOG(EXCEPTION) << "Fail to find cnode.";
663   }
664   auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
665 
666   auto func_graph = std::make_shared<FuncGraph>();
667   std::vector<AnfNodePtr> outputs;
668   auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
669   bprop_cut->CopyHookFunction(prim);
670 
671   auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
672   if (cell_id != "") {
673     (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
674     (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
675   }
676 
677   outputs.push_back(NewValueNode(bprop_cut));
678   for (size_t i = 0; i < inputs_num; ++i) {
679     auto param = func_graph->add_parameter();
680     outputs.push_back(param);
681   }
682   auto p1 = func_graph->add_parameter();
683   auto p2 = func_graph->add_parameter();
684   outputs.push_back(p1);
685   outputs.push_back(p2);
686 
687   func_graph->set_output(func_graph->NewCNode(outputs));
688   return func_graph;
689 }
690 
FakeBprop(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)691 FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
692   auto prim = value_node->value()->cast<PrimitivePtr>();
693   MS_EXCEPTION_IF_NULL(prim);
694   auto &node_users = resources->manager()->node_users();
695 
696   auto &users = node_users[value_node];
697   auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
698     return IsPrimitiveCNode(user.first, prim);
699   });
700   if (cnode == users.end()) {
701     MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
702   }
703   auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
704   auto effect_info = GetPrimEffectInfo(prim);
705   // Don't add U or IO monad parameters as it will be added later.
706   size_t monad_params_size = 0;
707   if (effect_info.memory) {
708     monad_params_size++;
709   }
710   if (effect_info.io) {
711     monad_params_size++;
712   }
713   if (inputs_num < monad_params_size) {
714     MS_LOG(EXCEPTION) << "Arguments number should be greater than or equal to " << monad_params_size
715                       << ", but the CNode is: " << cnode->first->DebugString();
716   }
717   inputs_num -= monad_params_size;
718 
719   auto func_graph = std::make_shared<FuncGraph>();
720   std::vector<AnfNodePtr> outputs;
721   outputs.push_back(NewValueNode(prim::kPrimMakeTuple));
722 
723   auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
724   (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
725 
726   for (size_t i = 0; i < inputs_num; ++i) {
727     // Mock params for inputs
728     auto param = func_graph->add_parameter();
729     // Mock derivatives for each inputs
730     outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param}));
731   }
732   // mock params for out and dout
733   (void)func_graph->add_parameter();
734   (void)func_graph->add_parameter();
735   func_graph->set_output(func_graph->NewCNode(outputs));
736   return func_graph;
737 }
738 }  // namespace ad
739 }  // namespace mindspore
740