1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2020-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#ifndef _WIN32
#include <dirent.h>
#endif
#include <fstream>
#include <memory>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "pybind_api/ir/primitive_py.h"
#include "ir/meta_func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
#include "pipeline/jit/resource.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/resolve.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h"
#include "utils/utils.h"
#include "utils/symbolic.h"
#include "utils/primitive_utils.h"
#include "utils/ms_context.h"
#include "utils/info.h"
#include "debug/trace.h"
#include "debug/common.h"
#include "debug/dump_proto.h"
#include "mindspore/core/load_mindir/load_model.h"
#include "utils/system/sha256.h"
#include "utils/file_utils.h"
47
48 namespace mindspore {
49 namespace ad {
50 KPrim g_k_prims;
51
52 namespace {
53 constexpr char kBpropMindIRSuffix[] = "_bprop.mindir";
54 constexpr char kBpropMindIRDir[] = "/../bprop_mindir/";
55 constexpr char serializable_bprop_ops[] = "serializable_bprop_ops";
56 constexpr char bprop_mindir_module[] = "mindspore.ops.bprop_mindir";
57
58 #ifndef _WIN32
GetBpropDir()59 std::string GetBpropDir() {
60 static std::string bprop_dir("");
61 if (bprop_dir.empty()) {
62 py::module mod = py::module::import("mindspore.ops._grad");
63 auto grad_file_path = mod.attr("__file__").cast<std::string>();
64 bprop_dir = grad_file_path.substr(0, grad_file_path.find_last_of('/'));
65 }
66 return bprop_dir;
67 }
68
BpropMindirDirExists()69 bool BpropMindirDirExists() {
70 auto bprop_mindir_dir = GetBpropDir() + kBpropMindIRDir;
71 DIR *dir = opendir(bprop_mindir_dir.c_str());
72 if (dir != nullptr) {
73 if (closedir(dir) == -1) {
74 MS_LOG(WARNING) << "The bprop mindir dir \"" << bprop_mindir_dir << "\" close failed!";
75 }
76 return true;
77 }
78 MS_LOG(INFO) << "The bprop mindir dir \"" << bprop_mindir_dir << "\" doesn't exists.";
79 return false;
80 }
81
82 // Get the serializable bprop list from the module mindspore.ops.bprop_mindir in python.
GetSerializableBpropList()83 std::unordered_set<std::string> GetSerializableBpropList() {
84 std::unordered_set<std::string> serializable_bprop_list;
85 if (!BpropMindirDirExists()) {
86 return serializable_bprop_list;
87 }
88 py::module mod = py::module::import(bprop_mindir_module);
89 py::object serializable_bprop_ops_attr = mod.attr(serializable_bprop_ops);
90 if (!py::isinstance<py::list>(serializable_bprop_ops_attr)) {
91 MS_LOG(WARNING) << "Can not get the the serializable bprop ops list from python, it is not a python list.";
92 return serializable_bprop_list;
93 }
94
95 auto ops_list = serializable_bprop_ops_attr.cast<py::list>();
96 for (size_t i = 0; i < ops_list.size(); ++i) {
97 auto prim_adapter = ops_list[i].cast<PrimitivePyAdapterPtr>();
98 if (prim_adapter == nullptr) {
99 MS_LOG(EXCEPTION) << "The python obj in serializable bprop list should be a Primitive, but it is "
100 << py::str(ops_list[i]);
101 }
102 (void)serializable_bprop_list.insert(prim_adapter->name());
103 }
104 return serializable_bprop_list;
105 }
106
IsSerializableBprop(const std::string & prim_name)107 bool IsSerializableBprop(const std::string &prim_name) {
108 static std::unordered_set<std::string> serializable_bprop_list = GetSerializableBpropList();
109 return std::any_of(serializable_bprop_list.begin(), serializable_bprop_list.end(),
110 [&prim_name](const std::string &serializable_bprop_prim_name) {
111 return prim_name == serializable_bprop_prim_name;
112 });
113 }
114
GetBpropHash()115 std::string GetBpropHash() {
116 static std::string bprop_hash = std::string();
117 if (bprop_hash.empty()) {
118 auto bprop_dir = GetBpropDir();
119 auto realpath = FileUtils::GetRealPath(common::SafeCStr(bprop_dir));
120 if (!realpath.has_value()) {
121 MS_LOG(EXCEPTION) << "Get real path of bprop dir failed. path=" << bprop_dir;
122 }
123 bprop_hash = system::sha256::GetHashFromDir(realpath.value());
124 }
125 return bprop_hash;
126 }
127
ImportBpropFromMindIR(const PrimitivePtr & prim)128 FuncGraphPtr ImportBpropFromMindIR(const PrimitivePtr &prim) {
129 MS_EXCEPTION_IF_NULL(prim);
130 std::string bprop_dir = GetBpropDir();
131 auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
132 std::optional<std::string> bprop_mindir_realpath =
133 FileUtils::GetRealPath(common::SafeCStr(bprop_mindir_path + prim->name() + kBpropMindIRSuffix));
134 bool bprop_cache_file_exists = bprop_mindir_realpath.has_value() && Common::FileExists(bprop_mindir_realpath.value());
135 if (!bprop_cache_file_exists) {
136 return nullptr;
137 }
138 auto bprop_fg = LoadMindIR(bprop_mindir_realpath.value());
139 if (bprop_fg != nullptr && bprop_fg->bprop_hash() != GetBpropHash()) {
140 MS_LOG(EXCEPTION) << "The bprop mindir files are not up to date. Please run the " << bprop_mindir_path
141 << "generate_mindir.py to generate new mindir files.\n"
142 << "bprop_fg hash: " << bprop_fg->bprop_hash() << "\n"
143 << "bprop hash: " << GetBpropHash();
144 }
145 return bprop_fg;
146 }
147
ExportBpropToMindIR(const PrimitivePtr & prim,const FuncGraphPtr & func_graph)148 void ExportBpropToMindIR(const PrimitivePtr &prim, const FuncGraphPtr &func_graph) {
149 MS_EXCEPTION_IF_NULL(prim);
150 std::string bprop_dir = GetBpropDir();
151 func_graph->set_bprop_hash(GetBpropHash());
152 auto bprop_mindir_path = bprop_dir + kBpropMindIRDir;
153 std::optional<std::string> bprop_mindir_realpath =
154 Common::CreatePrefixPath(bprop_mindir_path + prim->name() + kBpropMindIRSuffix, true);
155 if (!bprop_mindir_realpath.has_value()) {
156 MS_LOG(ERROR) << "Failed to get the realpath of bprop mindir: " << bprop_mindir_path << prim->name()
157 << kBpropMindIRSuffix;
158 return;
159 }
160 std::ofstream fout(bprop_mindir_realpath.value());
161 mind_ir::ModelProto fg_model = GetBinaryProto(func_graph, false);
162 if (!fg_model.SerializeToOstream(&fout)) {
163 MS_LOG(WARNING) << "Failed to cache the bprop of op \"" << prim->name() << "\" to file \""
164 << bprop_mindir_realpath.value() << "\".";
165 }
166 fout.close();
167 ChangeFileMode(bprop_mindir_realpath.value(), S_IRUSR | S_IWUSR);
168 }
169
// If `prim` is one of the DoSignature primitives handled here ("S-Prim-zeros_like_leaf",
// "S-Prim-getitem"), build a CNode in `fg` that calls the corresponding python op with the arguments of
// `origin_node` (all its inputs except input 0). Returns nullptr for any other primitive.
AnfNodePtr GetPythonOps(const FuncGraphPtr &fg, const AnfNodePtr &origin_node, const PrimitivePtr &prim) {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(origin_node);
  MS_EXCEPTION_IF_NULL(prim);
  // DoSignaturePrimitive name -> (python op name, python module name). An empty module name selects the
  // single-argument prim::GetPythonOps overload.
  static const std::unordered_map<std::string, std::pair<std::string, std::string>> python_ops_table{
    {"S-Prim-zeros_like_leaf", {"zeros_like", ""}},
    {"S-Prim-getitem", {"getitem", "mindspore.ops.composite.multitype_ops.getitem_impl"}}};
  const auto entry = python_ops_table.find(prim->name());
  if (entry == python_ops_table.cend()) {
    return nullptr;
  }
  const std::string &op_name = entry->second.first;
  const std::string &module_name = entry->second.second;
  const ValuePtr op_value =
    module_name.empty() ? prim::GetPythonOps(op_name) : prim::GetPythonOps(op_name, module_name);

  auto origin_cnode = origin_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(origin_cnode);
  const auto &origin_inputs = origin_cnode->inputs();
  std::vector<AnfNodePtr> call_inputs;
  call_inputs.reserve(origin_inputs.size());
  call_inputs.push_back(NewValueNode(op_value));
  (void)call_inputs.insert(call_inputs.end(), origin_inputs.begin() + 1, origin_inputs.end());
  return fg->NewCNode(call_inputs);
}
195
196 // Replace the nodes whose python obj of primitive is needed in the renormalize process,
197 // with the new created python ops, such as zeros_like.
ReplacePythonOps(const FuncGraphPtr & fg)198 void ReplacePythonOps(const FuncGraphPtr &fg) {
199 MS_EXCEPTION_IF_NULL(fg);
200 std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(fg->get_return());
201 for (const auto &node : all_nodes) {
202 MS_EXCEPTION_IF_NULL(node);
203 if (!node->isa<CNode>()) {
204 continue;
205 }
206 auto cnode = node->cast<CNodePtr>();
207 for (size_t i = 0; i < cnode->size(); ++i) {
208 auto prim = GetCNodePrimitive(cnode->input(i));
209 if (prim == nullptr) {
210 continue;
211 }
212 auto new_input = GetPythonOps(fg, cnode->input(i), prim);
213 if (new_input == nullptr) {
214 continue;
215 }
216 cnode->set_input(i, new_input);
217 }
218 }
219 }
220 #endif
221 } // namespace
222
223 #ifndef _WIN32
224 // Given a python primitive, export a mindir file from the bprop defined in python.
ExportBpropMindir(const py::object & obj)225 void KPrim::ExportBpropMindir(const py::object &obj) {
226 auto prim_adapter = obj.cast<PrimitivePyAdapterPtr>();
227 if (prim_adapter == nullptr) {
228 MS_LOG(EXCEPTION) << "The python obj to be exported to bprop mindir should be a Primitive, but it is "
229 << py::str(obj);
230 }
231 auto prim = prim_adapter->attached_primitive();
232 if (prim == nullptr) {
233 prim = std::make_shared<PrimitivePy>(obj, prim_adapter);
234 prim_adapter->set_attached_primitive(prim);
235 }
236
237 // Get the bprop function from python.
238 py::function fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
239 if (py::isinstance<py::none>(fn)) {
240 fn = GetBpropFunction(prim->name());
241 }
242 if (!fn || py::isinstance<py::none>(fn)) {
243 MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
244 }
245 auto func_graph = parse::ParsePythonCode(fn);
246 if (func_graph == nullptr) {
247 MS_LOG(EXCEPTION) << "Fail to parse bprop function for " << prim->name() << ".";
248 }
249 auto res = std::make_shared<pipeline::Resource>();
250 (void)parse::ResolveFuncGraph(func_graph, res);
251 ExportBpropToMindIR(prim, func_graph);
252 }
253 #endif
254
GetBprop(const PrimitivePtr & prim,const pipeline::ResourceBasePtr & resources)255 FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim, const pipeline::ResourceBasePtr &resources) {
256 // Set a child scope named "grad'PrimitiveName'" for the bprop function,
257 // and add "Gradients" to the front.
258 static const std::string gradients_scope = "Gradients/";
259 static const std::string grad_op_child_scope_prefix = "/grad";
260 MS_EXCEPTION_IF_NULL(prim);
261 auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
262 grad_op_child_scope_prefix + prim->name());
263 ScopeGuard scope_guard(scope);
264
265 // Firstly we get bprop from mindir. If failed, parse the python function registered.
266 FuncGraphPtr func_graph = nullptr;
267 #ifndef _WIN32
268 bool serializable = IsSerializableBprop(prim->name());
269 if (serializable) {
270 func_graph = ImportBpropFromMindIR(prim);
271 if (func_graph != nullptr) {
272 ReplacePythonOps(func_graph);
273 return func_graph;
274 }
275 }
276 #endif
277 py::function fn;
278 if (prim->is_base()) {
279 fn = GetBpropFunction(prim->name());
280 } else {
281 fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
282 if (py::isinstance<py::none>(fn)) {
283 fn = GetBpropFunction(prim->name());
284 }
285 }
286 if (!fn || py::isinstance<py::none>(fn)) {
287 MS_LOG(WARNING) << "Fail to find bprop function for " << prim->name() << ". fn: " << py::str(fn);
288 return nullptr;
289 }
290 func_graph = parse::ParsePythonCode(fn);
291 if (func_graph == nullptr) {
292 MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << ".";
293 return nullptr;
294 }
295 auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
296 if (bprop_flag) {
297 func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
298 }
299 pipeline::ResourceBasePtr res = (resources != nullptr) ? resources : std::make_shared<pipeline::Resource>();
300 (void)parse::ResolveFuncGraph(func_graph, res);
301 #ifndef _WIN32
302 // Check whether the bprop needs to be exported.
303 if (serializable) {
304 ExportBpropToMindIR(prim, func_graph);
305 }
306 #endif
307 return func_graph;
308 }
309
GetPossibleBprop(const PrimitivePtr & prim)310 FuncGraphPtr KPrim::GetPossibleBprop(const PrimitivePtr &prim) {
311 FuncGraphPtr bprop_fg = nullptr;
312 auto iter = bprop_registry_.find(prim);
313 if (iter != bprop_registry_.end()) {
314 bprop_fg = iter->second;
315 }
316
317 if (bprop_fg == nullptr) {
318 bprop_fg = GetBprop(prim);
319 if (bprop_fg != nullptr) {
320 // Set bprop_g graph cache
321 bprop_registry_[prim] = bprop_fg;
322 }
323 }
324 return bprop_fg;
325 }
326
GetFprop(const PrimitivePtr & prim)327 FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
328 static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
329 std::string func_name = "_fprop_" + prim->name();
330 py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
331 auto func_graph = parse::ParsePythonCode(fn);
332 MS_EXCEPTION_IF_NULL(func_graph);
333 return BasicClone(func_graph);
334 }
335
KMetaFuncGraph(const PrimitivePtr & prim)336 MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
337 MS_EXCEPTION_IF_NULL(prim);
338
339 auto iter = bprop_registry_meta_.find(prim);
340 if (iter != bprop_registry_meta_.end()) {
341 return iter->second;
342 }
343
344 if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
345 MetaFuncGraphPtr meta = std::make_shared<prim::MakeTupleGradient>("make_tuple_gradient");
346 bprop_registry_meta_[prim::kPrimMakeTuple] = meta;
347 return meta;
348 }
349
350 if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
351 MetaFuncGraphPtr meta = std::make_shared<prim::MakeListGradient>("make_list_gradient");
352 bprop_registry_meta_[prim::kPrimMakeList] = meta;
353 return meta;
354 }
355
356 MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
357 }
358
AppendMonadOutput(const FuncGraphPtr & bprop_fg,const AnfNodePtr & monad)359 static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &monad) {
360 const auto &output = bprop_fg->output();
361 MS_EXCEPTION_IF_NULL(output);
362 auto output_cnode = output->cast<CNodePtr>();
363 if (output_cnode != nullptr) {
364 // If output_cnode has the form like (make_tuple, x, y).
365 output_cnode->add_input(monad);
366 return;
367 }
368 // If output is an empty tuple, create a (make_tuple, monad) as the new output.
369 auto make_tuple = NewValueNode(prim::kPrimMakeTuple);
370 output_cnode = bprop_fg->NewCNode({make_tuple, monad});
371 bprop_fg->set_output(output_cnode);
372 }
373
374 // Append U or/and IO monad to output of Bprop funcgraph.
AdjustForAutoMonad(const PrimitivePtr & prim,const FuncGraphPtr & bprop_fg)375 static void AdjustForAutoMonad(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
376 auto effect_info = GetPrimEffectInfo(prim);
377 if (effect_info.memory) {
378 MS_LOG(DEBUG) << "Append U monad for Bprop FuncGraph of Primitive " << prim->ToString();
379 auto u = NewValueNode(kUMonad);
380 u->set_abstract(kUMonad->ToAbstract());
381 AppendMonadOutput(bprop_fg, u);
382 }
383 if (effect_info.io) {
384 MS_LOG(DEBUG) << "Append IO monad for Bprop FuncGraph of Primitive " << prim->ToString();
385 auto io = NewValueNode(kIOMonad);
386 io->set_abstract(kIOMonad->ToAbstract());
387 AppendMonadOutput(bprop_fg, io);
388 }
389 }
390
GetBprop(const CNodePtr & cnode,const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources,const PrimitivePtr & prim)391 FuncGraphPtr KPrim::GetBprop(const CNodePtr &cnode, const ValueNodePtr &value_node,
392 const pipeline::ResourceBasePtr &resources, const PrimitivePtr &prim) {
393 FuncGraphPtr bprop_fg = nullptr;
394 if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
395 if (MsContext::GetInstance()->get_param<int>(MsCtxParam::MS_CTX_EXECUTION_MODE) == kGraphMode) {
396 MS_LOG(EXCEPTION)
397 << "The Primitive 'HookBackward' is not supported in graph mode, which is only supported in pynative mode.\n"
398 << trace::GetDebugInfo(cnode->debug_info());
399 }
400 bprop_fg = BpropCut(value_node, resources);
401 } else {
402 auto iter = bprop_registry_.find(prim);
403 if (iter != bprop_registry_.end()) {
404 bprop_fg = iter->second;
405 }
406
407 if (bprop_fg == nullptr) {
408 bprop_fg = GetBprop(prim, resources);
409 if (bprop_fg != nullptr) {
410 // Set bprop_g graph cache
411 bprop_registry_[prim] = bprop_fg;
412 } else {
413 bprop_fg = FakeBprop(value_node, resources);
414 }
415 }
416 }
417 return bprop_fg;
418 }
419
KPrimitive(const CNodePtr & cnode,const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)420 FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
421 const pipeline::ResourceBasePtr &resources) {
422 if (!IsValueNode<Primitive>(value_node)) {
423 MS_LOG(EXCEPTION) << "Primitive node is not valid.";
424 }
425
426 auto prim = GetValueNode<PrimitivePtr>(value_node);
427 if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
428 auto fprop = GetFprop(prim);
429 fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
430 return fprop;
431 } else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
432 return nullptr;
433 } else if (prim->Hash() == prim::kPrimMakeList->Hash() && prim->name() == prim::kPrimMakeList->name()) {
434 return nullptr;
435 }
436
437 FuncGraphPtr bprop_fg = GetBprop(cnode, value_node, resources, prim);
438 AdjustForAutoMonad(prim, bprop_fg);
439 std::unordered_map<std::string, ValuePtr> primal_attrs;
440 std::vector<NodeDebugInfoPtr> primal_debug_infos;
441 if (resources != nullptr) {
442 auto manager = resources->manager();
443 auto &users = manager->node_users()[value_node];
444 for (auto user_iter = users.begin(); user_iter != users.end(); ++user_iter) {
445 primal_debug_infos.push_back(user_iter->first->debug_info());
446 }
447 }
448 if (cnode != nullptr) {
449 primal_attrs = cnode->primal_attrs();
450 const auto forward_node_primal_attr = prim->name() + "_" + cnode->UniqueId();
451 primal_attrs[kPrimalAttrForwardNodeName] = MakeValue(forward_node_primal_attr);
452 }
453 auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode, primal_attrs, primal_debug_infos);
454 if (expanded_fg == nullptr) {
455 MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
456 << " prim bprop function to J expanded func graph. NodeInfo: "
457 << trace::GetDebugInfo(bprop_fg->debug_info());
458 }
459 if (lift_fv_before_grad && IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
460 // Inline fprop_switch before renormalize;
461 expanded_fg->set_flag(FUNC_GRAPH_FLAG_FORCE_INLINE, true);
462 MS_LOG(DEBUG) << "set force_inline for fg: " << expanded_fg->ToString();
463 }
464
465 return expanded_fg;
466 }
467
BuildOutput(const FuncGraphPtr & bprop_fg,const FuncGraphPtr & current_primal_fg)468 AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg) {
469 // current_primal_fg may have extra parameters like u_monad, io_monad
470 std::vector<AnfNodePtr> extra_args;
471 // caller had checked size() - 2 is greater than 0.
472 auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
473 if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) {
474 auto current_primal_fg_param_size = current_primal_fg->parameters().size();
475 MS_LOG(DEBUG) << "Current Primal FuncGraph may have extra parameters(U or IO monad) which bprop don't define, so "
476 "Insert it. Extra parameters size: "
477 << current_primal_fg_param_size - bprop_fg_param_size;
478 for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) {
479 const auto &primal_node = current_primal_fg->parameters()[i];
480 AnfNodePtr extra_node;
481 // Simplify zeros_like(primal_node) to U or IO, so extra_node in bprop_fg will not refer to primal_node
482 // as a free variable of primal_graph.
483 // Notes: if the implementation of zeros_like changes, here too.
484 if (HasAbstractUMonad(primal_node)) {
485 extra_node = NewValueNode(kUMonad);
486 } else if (HasAbstractIOMonad(primal_node)) {
487 extra_node = NewValueNode(kIOMonad);
488 } else {
489 MS_EXCEPTION(TypeError)
490 << "The params of function 'bprop' of Primitive or Cell requires the forward inputs as well "
491 "as the 'out' and 'dout'.\n"
492 << trace::GetDebugInfo(bprop_fg->debug_info());
493 }
494 extra_args.push_back(extra_node);
495 MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString();
496 }
497 }
498 // bprop_fg has been checked in caller
499 if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
500 // Set bprop output as (env, dx, dy, dz, ...)
501 auto cbprop = bprop_fg->output()->cast<CNodePtr>();
502 auto &inputs = cbprop->inputs();
503
504 std::vector<AnfNodePtr> args;
505 args.push_back(NewValueNode(prim::kPrimMakeTuple));
506 args.push_back(NewValueNode(newenv));
507 (void)args.insert(args.end(), inputs.begin() + 1, inputs.end());
508 if (!extra_args.empty()) {
509 args.insert(args.end(), extra_args.cbegin(), extra_args.cend());
510 }
511 return NewCNode(args, bprop_fg);
512 }
513
514 // Set bprop output as (env, dx)
515 std::string model_name("mindspore.ops.composite.multitype_ops.add_impl");
516 std::string python_ops("_tuple_add");
517 auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg);
518 auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name));
519 if (!extra_args.empty()) {
520 extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple));
521 auto extra_tuple = NewCNode(extra_args, bprop_fg);
522 auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg);
523 return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg);
524 }
525
526 return NewCNode({tuple_add_ops, tuple_env, bprop_fg->output()}, bprop_fg);
527 }
528
TransformNormalArgs(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * const transf_args)529 static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
530 std::vector<AnfNodePtr> *const transf_args) {
531 // bprop_fg has been checked in caller
532 // transform except the last 2 parameters: out, dout.
533 const size_t last_parameter_sizes = 2;
534 auto bprop_fg_param_size = bprop_fg->parameters().size() - last_parameter_sizes;
535 for (size_t i = 0; i < bprop_fg_param_size; ++i) {
536 auto p = bprop_fg->parameters()[i];
537 MS_EXCEPTION_IF_NULL(p);
538
539 TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
540 auto transf_p = outer->add_parameter();
541
542 (void)mng->Replace(p, transf_p);
543 transf_args->push_back(transf_p);
544 }
545 }
TransformArgsForPrimitive(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const PrimitivePtr & primitive,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * const transf_args)546 void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
547 const PrimitivePtr &primitive, const FuncGraphPtr &outer,
548 std::vector<AnfNodePtr> *const transf_args) {
549 MS_EXCEPTION_IF_NULL(mng);
550 TransformNormalArgs(mng, bprop_fg, outer, transf_args);
551 // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter.
552 auto effect_info = GetPrimEffectInfo(primitive);
553 if (effect_info.memory) {
554 MS_LOG(DEBUG) << "Append U monad to Fprop FuncGraph for Primitive " << primitive->ToString();
555 auto transf_p = outer->add_parameter();
556 transf_args->push_back(transf_p);
557 }
558 if (effect_info.io) {
559 MS_LOG(DEBUG) << "Append IO monad to Fprop FuncGraph for Primitive " << primitive->ToString();
560 auto transf_p = outer->add_parameter();
561 transf_args->push_back(transf_p);
562 }
563 }
564
565 template <typename T>
TransformArgsForFuncGraph(const FuncGraphManagerPtr & mng,const FuncGraphPtr & bprop_fg,const T & current_primal_fg,const FuncGraphPtr & outer,std::vector<AnfNodePtr> * const transf_args)566 void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
567 const T ¤t_primal_fg, const FuncGraphPtr &outer,
568 std::vector<AnfNodePtr> *const transf_args) {
569 MS_EXCEPTION_IF_NULL(mng);
570 TransformNormalArgs(mng, bprop_fg, outer, transf_args);
571 constexpr size_t need_filter_size = 2;
572 auto bprop_fg_param_size = bprop_fg->parameters().size() - need_filter_size;
573 // current_primal_fg may have extra parameters after AutoMonad
574 const auto ¤t_primal_fg_params = current_primal_fg->parameters();
575 if (bprop_fg_param_size < current_primal_fg_params.size()) {
576 for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) {
577 auto p = current_primal_fg_params[i];
578 MS_EXCEPTION_IF_NULL(p);
579 // extra parameters should be Monad.
580 if (!HasAbstractMonad(p)) {
581 continue;
582 }
583 MS_LOG(DEBUG) << "Function " << current_primal_fg->ToString()
584 << ", has extra monad parameter: " << p->DebugString()
585 << ", abstract: " << p->abstract()->ToString();
586
587 TraceGuard trace_guard(std::make_shared<TraceGradFprop>(p->debug_info()));
588 auto transf_p = outer->add_parameter();
589 // See also Notes on extra_node of BuildOutput.
590 // Notes: No need to replace p with transf_p as the only use of p is here.
591 // If extra_node in bprop_fg use p as free variable, a replacement of p is required here.
592 // This replacement will make the usage of p in current_primal_fg got replaced with transf_p
593 // of outer. outer will be released after it is being cloned to fprop_fg, so the func_graph_
594 // in transf_p will be nullptr.
595 // So the RULE is DONT tamper the current_primal_fg;
596 transf_args->push_back(transf_p);
597 }
598 }
599 if (transf_args->size() != current_primal_fg_params.size()) {
600 MS_EXCEPTION(TypeError) << "Function " << current_primal_fg->ToString()
601 << ", The number of parameter of this primal function is "
602 << current_primal_fg_params.size() << ", but the number of parameters of bprop is "
603 << bprop_fg_param_size;
604 }
605 }
606
CheckBprop(const FuncGraphPtr & bprop_fg,const string & prim_to_check)607 void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
608 auto context = MsContext::GetInstance();
609 MS_EXCEPTION_IF_NULL(context);
610 bool check_bprop_flag = context->get_param<bool>(MS_CTX_CHECK_BPROP_FLAG);
611 // Skip checking if check_bprop not set
612 if (!check_bprop_flag) {
613 return;
614 }
615
616 // bprop_fg has been checked in caller
617 auto check_bprop_class = prim::GetPythonOps("CheckBprop", "mindspore.ops.operations.other_ops");
618 MS_EXCEPTION_IF_NULL(check_bprop_class);
619 auto check_bprop =
620 bprop_fg->NewCNode({NewValueNode(check_bprop_class), NewValueNode(std::make_shared<StringImm>(prim_to_check))});
621
622 std::vector<AnfNodePtr> inputs;
623 inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
624 constexpr int primitive_size = 1;
625 constexpr int brprop_offset_size = 2;
626 (void)inputs.insert(inputs.begin() + primitive_size, bprop_fg->parameters().begin(),
627 bprop_fg->parameters().end() - brprop_offset_size);
628 AnfNodePtr params = bprop_fg->NewCNode(inputs);
629
630 inputs.clear();
631 inputs.push_back(check_bprop);
632 inputs.push_back(bprop_fg->output());
633 inputs.push_back(params);
634 AnfNodePtr bprop_out = bprop_fg->NewCNode(inputs);
635 bprop_fg->set_output(bprop_out);
636 }
637
// Expand a user-defined Cell bprop graph into its K (fprop) graph with BpropToK.
// `bprop_fg` must carry a "primal" transform pointing at the primal graph as it was right after
// conversion (see ConvertCellObjToFuncGraph). `current_primal_fg` is the specialized, AutoMonad-processed
// version of that primal graph. Raises if the transform is missing or the expansion fails.
FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
  MS_EXCEPTION_IF_NULL(bprop_fg);
  auto &transforms = bprop_fg->transforms();
  auto primal_iter = transforms.find("primal");
  // Dereferencing find() without this check is undefined behavior when the transform is missing.
  if (primal_iter == transforms.end()) {
    MS_LOG(EXCEPTION) << "The Cell bprop function graph " << bprop_fg->ToString()
                      << " has no 'primal' transform, so it cannot be expanded to a K func graph.";
  }
  auto primal_fg = primal_iter->second.func_graph();
  MS_EXCEPTION_IF_NULL(primal_fg);
  auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr, {}, {});
  if (expanded_fg == nullptr) {
    MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString()
                      << " Cell bprop function to K expanded func graph. NodeInfo: "
                      << trace::GetDebugInfo(primal_fg->debug_info());
  }
  return expanded_fg;
}
651
BpropCut(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)652 FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
653 auto prim = GetValueNode<PrimitivePtr>(value_node);
654 MS_EXCEPTION_IF_NULL(prim);
655 auto &node_users = resources->manager()->node_users();
656
657 auto &users = node_users[value_node];
658 auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
659 return IsPrimitiveCNode(user.first, prim);
660 });
661 if (cnode == users.end()) {
662 MS_LOG(EXCEPTION) << "Fail to find cnode.";
663 }
664 auto inputs_num = cnode->first->cast<CNodePtr>()->size() - 1;
665
666 auto func_graph = std::make_shared<FuncGraph>();
667 std::vector<AnfNodePtr> outputs;
668 auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
669 bprop_cut->CopyHookFunction(prim);
670
671 auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
672 if (cell_id != "") {
673 (void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
674 (void)bprop_cut->AddAttr("cell_id", MakeValue(cell_id));
675 }
676
677 outputs.push_back(NewValueNode(bprop_cut));
678 for (size_t i = 0; i < inputs_num; ++i) {
679 auto param = func_graph->add_parameter();
680 outputs.push_back(param);
681 }
682 auto p1 = func_graph->add_parameter();
683 auto p2 = func_graph->add_parameter();
684 outputs.push_back(p1);
685 outputs.push_back(p2);
686
687 func_graph->set_output(func_graph->NewCNode(outputs));
688 return func_graph;
689 }
690
FakeBprop(const ValueNodePtr & value_node,const pipeline::ResourceBasePtr & resources)691 FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources) {
692 auto prim = value_node->value()->cast<PrimitivePtr>();
693 MS_EXCEPTION_IF_NULL(prim);
694 auto &node_users = resources->manager()->node_users();
695
696 auto &users = node_users[value_node];
697 auto cnode = std::find_if(users.begin(), users.end(), [&prim](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
698 return IsPrimitiveCNode(user.first, prim);
699 });
700 if (cnode == users.end()) {
701 MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
702 }
703 auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;
704 auto effect_info = GetPrimEffectInfo(prim);
705 // Don't add U or IO monad parameters as it will be added later.
706 size_t monad_params_size = 0;
707 if (effect_info.memory) {
708 monad_params_size++;
709 }
710 if (effect_info.io) {
711 monad_params_size++;
712 }
713 if (inputs_num < monad_params_size) {
714 MS_LOG(EXCEPTION) << "Arguments number should be greater than or equal to " << monad_params_size
715 << ", but the CNode is: " << cnode->first->DebugString();
716 }
717 inputs_num -= monad_params_size;
718
719 auto func_graph = std::make_shared<FuncGraph>();
720 std::vector<AnfNodePtr> outputs;
721 outputs.push_back(NewValueNode(prim::kPrimMakeTuple));
722
723 auto fake_bprop = std::make_shared<Primitive>("fake_bprop");
724 (void)fake_bprop->AddAttr("info", MakeValue("Primitive " + prim->name() + "'s bprop not defined."));
725
726 for (size_t i = 0; i < inputs_num; ++i) {
727 // Mock params for inputs
728 auto param = func_graph->add_parameter();
729 // Mock derivatives for each inputs
730 outputs.push_back(func_graph->NewCNode({NewValueNode(fake_bprop), param}));
731 }
732 // mock params for out and dout
733 (void)func_graph->add_parameter();
734 (void)func_graph->add_parameter();
735 func_graph->set_output(func_graph->NewCNode(outputs));
736 return func_graph;
737 }
738 } // namespace ad
739 } // namespace mindspore
740