1 /**
2 * Copyright 2022-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "frontend/expander/bprop/bprop.h"
17
18 #include <algorithm>
19 #include <queue>
20 #include <unordered_map>
21
22 #include "ops/sequence_ops.h"
23 #include "ops/array_ops.h"
24 #include "ops/framework_ops.h"
25 #include "abstract/ops/primitive_infer_map.h"
26 #include "include/common/expander/core/infer.h"
27 #include "include/common/profiler.h"
28 #include "include/backend/kernel_graph.h"
29 #include "utils/anf_utils.h"
30 #include "include/common/debug/anf_ir_dump.h"
31 #include "frontend/expander/utils.h"
32
33 namespace mindspore {
34 namespace expander {
35 namespace bprop {
36 class SimpleNode {
37 public:
SimpleNode(const AbstractBasePtr & abs)38 explicit SimpleNode(const AbstractBasePtr &abs) : abs_(abs->Clone()) {}
SimpleNode(const ValuePtr & value,const AbstractBasePtr & abs)39 SimpleNode(const ValuePtr &value, const AbstractBasePtr &abs) : abs_(abs->Clone()), value_(value) {}
SimpleNode(const PrimitivePtr & prim,const AbstractBasePtr & abs,const std::vector<size_t> & input_indexs)40 SimpleNode(const PrimitivePtr &prim, const AbstractBasePtr &abs, const std::vector<size_t> &input_indexs)
41 : input_indexs(std::move(input_indexs)), abs_(abs->Clone()), prim_(prim->Clone()) {}
42 ~SimpleNode() = default;
is_valuenode() const43 bool is_valuenode() const { return value_ != nullptr; }
get_primitive() const44 PrimitivePtr get_primitive() const { return prim_->Clone(); }
get_abstract() const45 AbstractBasePtr get_abstract() const { return abs_->Clone(); }
get_value() const46 ValuePtr get_value() const { return value_; }
47
48 std::vector<size_t> input_indexs;
49
50 protected:
51 AbstractBasePtr abs_;
52 ValuePtr value_;
53 PrimitivePtr prim_;
54 };
55 using SimpleNodePtr = std::shared_ptr<SimpleNode>;
56
// A cached, framework-independent bprop graph: a flat topologically-ordered
// list of SimpleNodes plus the positions of the graph's inputs and outputs.
struct SimpleGraph {
  std::vector<SimpleNodePtr> nodes;    // all nodes; inputs come first, cnodes follow in emit order
  std::vector<size_t> output_indexs;   // positions of the graph outputs inside `nodes`
  std::vector<size_t> input_indexs;    // positions of the graph inputs inside `nodes`
};
62 using SimpleGraphPtr = std::shared_ptr<SimpleGraph>;
63 using BpropGraphCacheMap = std::unordered_map<abstract::AbstractBasePtrList, SimpleGraphPtr,
64 abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
65 using KernelGraph = session::KernelGraph;
66
HasBpropExpander(const std::string & prim_name)67 bool HasBpropExpander(const std::string &prim_name) {
68 const BpropHandle *handle = BpropIRBuilderFactory::Instance().GetBuilder(prim_name);
69 return (handle != nullptr);
70 }
71
// Thrown when bprop expanding meets an op (currently ShapeCalc) that pynative
// mode cannot handle yet; BpropExpander::Run catches it and falls back to the
// python bprop implementation.
class ShapeCalcException : public std::runtime_error {
 public:
  using runtime_error::runtime_error;
};
76
77 class PynativeIRBuilder : public IrBuilder {
78 public:
PynativeIRBuilder(const PrimitivePtr & prim,const FuncGraphPtr & fg,const ExpanderInferPtr & infer,UserMap * users,const AnfNodePtr & dout)79 PynativeIRBuilder(const PrimitivePtr &prim, const FuncGraphPtr &fg, const ExpanderInferPtr &infer, UserMap *users,
80 const AnfNodePtr &dout)
81 : IrBuilder(prim->name(), fg, infer), users_(users), dout_(dout), prim_(prim) {
82 MS_EXCEPTION_IF_NULL(users);
83 }
84 ~PynativeIRBuilder() = default;
85
OutZeros(const NodePtr & node)86 NodePtr OutZeros(const NodePtr &node) override {
87 need_infer_ = false;
88 auto ret = Emit(kZerosLikeOpName, {node});
89 need_infer_ = true;
90 return ret;
91 }
92
Build(const std::vector<NodePtr> & input_nodes,const std::vector<ValuePtr> & input_values,const HashMap<std::string,ValuePtr> & attrs,const BpropHandle & handle)93 virtual NodePtrList Build(const std::vector<NodePtr> &input_nodes, const std::vector<ValuePtr> &input_values,
94 const HashMap<std::string, ValuePtr> &attrs, const BpropHandle &handle) {
95 if (!input_values.empty()) {
96 for (size_t i = 0; i < input_values.size(); ++i) {
97 input_nodes[i]->SetValue(input_values[i]);
98 }
99 }
100 auto output_nodes = Run(input_nodes, attrs, handle, prim_->instance_name());
101 for (size_t i = 0; i < output_nodes.size(); i++) {
102 auto &node = output_nodes[i];
103 // A Value node gradient will loss the trace context in pynative, so emit a node. A example is Eye.
104 if (node->input_type() == InputType::kConstant || IsPrimitiveCNode(node->get(), prim::kPrimZerosLike)) {
105 if (node->input_type() == InputType::kConstant) {
106 auto abs = node->abstract();
107 MS_EXCEPTION_IF_NULL(abs);
108 if (abs->isa<abstract::AbstractScalar>()) {
109 node = OutZeros(Tensor(0, abs->BuildType()));
110 } else {
111 node = OutZeros(node);
112 }
113 }
114 node->get()->set_abstract(input_nodes[i]->abstract()->Broaden());
115 }
116 }
117 return output_nodes;
118 }
119
Conditional(const NodePtr & cond,const BlockFunc & true_case,const BlockFunc & false_case)120 NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) override {
121 has_ctrl_flow_ = true;
122 CtrlFlowBlock cfb(this, this->func_graph(),
123 [this](const FuncGraphPtr &fg, const ExpanderInferPtr &infer) -> EmitterPtr {
124 return std::make_shared<PynativeIRBuilder>(this->prim_, fg, infer, this->users_, this->dout_);
125 });
126 this->func_graph()->set_flag(kFlagIsControlFlow, true);
127 return cfb.IfThenElse(cond, true_case, false_case);
128 }
129
While(const NodePtr & cond,const BlockFunc & body,const NodePtrList & init_list)130 NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) override {
131 has_ctrl_flow_ = true;
132 CtrlFlowBlock cfb(this, this->func_graph(),
133 [this](const FuncGraphPtr &fg, const ExpanderInferPtr &infer) -> EmitterPtr {
134 return std::make_shared<PynativeIRBuilder>(this->prim_, fg, infer, this->users_, this->dout_);
135 });
136 this->func_graph()->set_flag(kFlagIsControlFlow, true);
137 return cfb.While(cond, body, init_list);
138 }
139
140 protected:
EmitGetItemValue(const NodePtrList & inputs)141 NodePtr EmitGetItemValue(const NodePtrList &inputs) {
142 if (inputs[0]->input_type() != InputType::kConstant) {
143 return nullptr;
144 }
145 auto real_input = inputs[0]->get()->cast<ValueNodePtr>();
146 MS_EXCEPTION_IF_NULL(real_input);
147 auto real_input_value = real_input->value()->cast<ValueSequeuePtr>();
148 if (real_input_value != nullptr) {
149 auto item_idx = GetValue<int64_t>(inputs[1]->get()->cast<ValueNodePtr>()->value());
150 auto valuenode = NewValueNode((*real_input_value)[item_idx]);
151 valuenode->set_abstract(valuenode->value()->ToAbstract()->Broaden());
152 return NewIrNode(valuenode);
153 }
154 return nullptr;
155 }
156
EmitOp(const PrimitivePtr & prim,const NodePtrList & inputs)157 NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) override {
158 if (prim->name() == prim::kPrimShapeCalc->name()) {
159 // temporary solution, remove this after input parameter's value is set.
160 throw ShapeCalcException("ShapeCalc is not supported in pynative mode.");
161 }
162 if (prim->name() == kTupleGetItemOpName) {
163 // if the getitem's real input is a ValueSequence, just return the real Value of that.
164 auto getitem_value = EmitGetItemValue(inputs);
165 if (getitem_value != nullptr) {
166 return getitem_value;
167 }
168 }
169 AnfNodePtrList cnode_inputs{NewValueNode(prim)};
170 cnode_inputs.reserve(inputs.size() + 1);
171 (void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs),
172 [](const NodePtr &inp) { return inp->get(); });
173 // PyNative use kernel graph construct bprop graph, which indicate func_graph_ here is kernel graph;
174 // And, use kernel graph create cnode will do PostNewCNode which is not necessary
175 auto cnode = func_graph_->isa<KernelGraph>() ? func_graph_->FuncGraph::NewCNode(cnode_inputs)
176 : func_graph_->NewCNode(cnode_inputs);
177 if (scope_ != nullptr) {
178 cnode->set_scope(scope_);
179 }
180
181 auto node = NewIrNode(cnode->cast<AnfNodePtr>());
182 if (need_infer_) {
183 auto value_depend = abstract::GetValueDependArgIndices(cnode);
184 if (!value_depend.empty()) {
185 for (auto idx : value_depend) {
186 size_t i = LongToSize(idx);
187 if (i < inputs.size() && !inputs[i]->HasAbstractValue()) {
188 auto v = inputs[i]->BuildValue();
189 auto tensor = v->cast<tensor::BaseTensorPtr>();
190 if (tensor != nullptr) {
191 tensor->data_sync();
192 }
193 inputs[i]->abstract()->set_value(v);
194 }
195 }
196 }
197 infer_->Infer(node);
198 }
199 // record the users
200 for (size_t i = 1; i < cnode_inputs.size(); i++) {
201 auto &inp = cnode_inputs[i];
202 if (inp == dout_ || inp->isa<Parameter>()) {
203 (void)users_->dout_user_[inp].emplace_back(cnode, i);
204 } else if (IsPrimitiveCNode(inp, prim::kPrimTupleGetItem)) {
205 // record the dout's successor getitem's users
206 auto getitem = inp->cast<CNodePtr>();
207 auto real_input = getitem->input(kIndex1);
208 if (real_input == dout_) {
209 (void)users_->tuple_getitem_user_[inp].emplace_back(cnode, i);
210 }
211 }
212 }
213 return node;
214 }
215
216 UserMap *users_;
217 AnfNodePtr dout_;
218 bool need_infer_{true};
219 PrimitivePtr prim_;
220 bool has_ctrl_flow_{false};
221 };
222
223 class PynativeIRBuilderWithCache : public PynativeIRBuilder {
224 public:
225 using PynativeIRBuilder::PynativeIRBuilder;
226 ~PynativeIRBuilderWithCache() = default;
227
228 inline static std::unordered_map<PrimitivePtr, BpropGraphCacheMap, PrimitiveHasher, PrimitiveTotalEqual>
229 bprop_op_graph_map;
230
Build(const NodePtrList & input_nodes,const std::vector<ValuePtr> & input_values,const HashMap<std::string,ValuePtr> & attrs,const BpropHandle & handle)231 NodePtrList Build(const NodePtrList &input_nodes, const std::vector<ValuePtr> &input_values,
232 const HashMap<std::string, ValuePtr> &attrs, const BpropHandle &handle) override {
233 AbstractBasePtrList abs_list;
234 NodePtrList output_nodes;
235 abs_list.reserve(input_nodes.size());
236 (void)std::transform(input_nodes.cbegin(), input_nodes.cend(), std::back_insert_iterator(abs_list),
237 [](const NodePtr &no) { return no->abstract(); });
238 std::vector<size_t> value_index(input_nodes.size());
239 for (size_t i = 0; i < input_values.size(); ++i) {
240 if (!input_nodes[i]->HasAbstractValue()) {
241 input_nodes[i]->SetValue(input_values[i]);
242 value_index[i] = true;
243 }
244 }
245 BpropGraphCacheMap &bprop_map = PynativeIRBuilderWithCache::bprop_op_graph_map[prim_];
246 auto it = bprop_map.find(abs_list);
247 if (it == bprop_map.end()) {
248 need_record_nodes_ = true;
249 output_nodes = PynativeIRBuilder::Build(input_nodes, {}, attrs, handle);
250 need_record_nodes_ = false;
251 if (has_ctrl_flow_) {
252 return output_nodes;
253 }
254 // need not grad if grad depend input_values.
255 for (size_t i = 0; i < input_nodes.size(); i++) {
256 if (value_index[i] && input_nodes[i]->is_used_value()) {
257 return output_nodes;
258 }
259 }
260 for (auto &node_pair : bprop_nodes_) {
261 if (IsPrimitiveCNode(node_pair.first->get(), prim::kPrimSwitch)) {
262 return output_nodes;
263 }
264 }
265 bprop_map[abs_list] = BuildBpropOpGraph(input_nodes, output_nodes);
266 } else {
267 need_infer_ = false;
268 SimpleGraphPtr graph = it->second;
269 std::vector<NodePtr> node_map(input_nodes);
270 node_map.reserve(graph->nodes.size());
271 auto SimpleNodeToMsNode = [&graph, &node_map, this](const SimpleNodePtr &node) -> NodePtr {
272 if (node->is_valuenode()) {
273 return EmitValue(node->get_value());
274 }
275 NodePtrList cnode_list;
276 cnode_list.reserve(node->input_indexs.size());
277 for (size_t i : node->input_indexs) {
278 (void)cnode_list.emplace_back(node_map[i]);
279 }
280 NodePtr new_node = EmitOp(node->get_primitive(), cnode_list);
281 AnfNodePtr ms_node = new_node->get();
282 if (ms_node->abstract() == nullptr) {
283 ms_node->set_abstract(node->get_abstract());
284 }
285 return new_node;
286 };
287 for (size_t i = graph->input_indexs.size(); i < graph->nodes.size(); i++) {
288 (void)node_map.emplace_back(SimpleNodeToMsNode(graph->nodes[i]));
289 }
290 output_nodes.reserve(graph->output_indexs.size());
291 for (size_t i : graph->output_indexs) {
292 (void)output_nodes.emplace_back(node_map[i]);
293 }
294 }
295 return output_nodes;
296 }
297
298 protected:
EmitOp(const PrimitivePtr & prim,const NodePtrList & inputs)299 NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) override {
300 auto node = PynativeIRBuilder::EmitOp(prim, inputs);
301 if (need_record_nodes_) {
302 (void)bprop_nodes_.emplace_back(std::make_pair(node, inputs));
303 }
304 return node;
305 }
306
307 private:
BuildBpropOpGraph(const NodePtrList & input_nodes,const NodePtrList & output_nodes)308 SimpleGraphPtr BuildBpropOpGraph(const NodePtrList &input_nodes, const NodePtrList &output_nodes) {
309 std::unordered_map<NodePtr, size_t> node_map;
310 SimpleGraphPtr graph = std::make_shared<SimpleGraph>();
311 for (auto &parm : input_nodes) {
312 node_map[parm] = graph->nodes.size();
313 (void)graph->input_indexs.emplace_back(graph->nodes.size());
314 (void)graph->nodes.emplace_back(std::make_shared<SimpleNode>(parm->abstract()));
315 }
316 for (auto &[node, inputs] : bprop_nodes_) {
317 std::vector<size_t> input_indexs;
318 input_indexs.reserve(inputs.size());
319 for (auto &no : inputs) {
320 auto it = node_map.find(no);
321 if (it == node_map.end()) {
322 auto value = no->BuildValue();
323 node_map[node] = graph->nodes.size();
324 (void)input_indexs.emplace_back(graph->nodes.size());
325 (void)graph->nodes.emplace_back(std::make_shared<SimpleNode>(value, value->ToAbstract()->Broaden()));
326 } else {
327 (void)input_indexs.emplace_back(it->second);
328 }
329 }
330 PrimitivePtr primitive =
331 node->input_type() == InputType::kConstant ? prim::kPrimTupleGetItem : GetCNodePrimitive(node->get());
332 node_map[node] = graph->nodes.size();
333 (void)graph->nodes.emplace_back(std::make_shared<SimpleNode>(primitive, node->abstract(), input_indexs));
334 }
335 graph->output_indexs.reserve(output_nodes.size());
336 for (auto &node : output_nodes) {
337 (void)graph->output_indexs.emplace_back(node_map[node]);
338 }
339 return graph;
340 }
341
342 bool need_record_nodes_{false};
343 std::vector<std::pair<NodePtr, NodePtrList>> bprop_nodes_;
344 };
345
ClearBpropOpGraphMap()346 void ClearBpropOpGraphMap() { PynativeIRBuilderWithCache ::bprop_op_graph_map.clear(); }
347
Run(const CNodePtr & cnode,const std::vector<ValuePtr> & input_values)348 bool BpropExpander::Run(const CNodePtr &cnode, const std::vector<ValuePtr> &input_values) {
349 MS_EXCEPTION_IF_NULL(cnode);
350 MS_LOG(DEBUG) << "Begin building bprop for " << cnode->fullname_with_scope();
351 bool ret = true;
352 if (outputs_ != nullptr) {
353 outputs_->clear();
354 }
355 auto node_name = AnfUtils::GetCNodeName(cnode);
356 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeGradExpander,
357 node_name, true);
358 if (OpEnvManager::UsePyBprop(node_name)) {
359 MS_LOG(DEBUG) << "Python bprop will be used for op " << node_name;
360 return false;
361 }
362 try {
363 ret = RunBprop(cnode, input_values);
364 } catch (const ShapeCalcException &e) {
365 MS_LOG(INFO) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what()
366 << "]. python bprop will be used.";
367 if (outputs_ != nullptr) {
368 outputs_->clear();
369 }
370 ret = false;
371 } catch (const std::exception &e) {
372 MS_LOG(ERROR) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what() << "]";
373 std::rethrow_exception(std::current_exception());
374 }
375 MS_LOG(DEBUG) << "Finish building bprop for " << cnode->fullname_with_scope();
376 return ret;
377 }
378
GetUnusedInputs(const string & op_name)379 const mindspore::HashSet<size_t> &BpropExpander::GetUnusedInputs(const string &op_name) {
380 auto handle = BpropIRBuilderFactory::Instance().GetBuilder(op_name);
381 if (handle == nullptr) {
382 MS_LOG(DEBUG) << "Bprop IRBuilder [" << op_name << "] is not registered in bprop expander.";
383 static const mindspore::HashSet<size_t> no_handle{INT_MAX};
384 return no_handle;
385 }
386 return handle->unused_inputs;
387 }
388
RunBprop(const CNodePtr & cnode,const std::vector<ValuePtr> & input_values)389 bool BpropExpander::RunBprop(const CNodePtr &cnode, const std::vector<ValuePtr> &input_values) {
390 static const bool cache_env = (common::GetEnv("MS_DEV_DISABLE_BPROP_CACHE") != "on");
391 const auto prim = GetCNodePrimitive(cnode);
392 const auto name = prim->name();
393 std::shared_ptr<PynativeIRBuilder> ir_builder;
394 if (cache_env) {
395 ir_builder = std::make_shared<PynativeIRBuilderWithCache>(prim, cnode->func_graph(), std::make_shared<CppInfer>(),
396 users_, cnode->inputs().back());
397 } else {
398 ir_builder = std::make_shared<PynativeIRBuilder>(prim, cnode->func_graph(), std::make_shared<CppInfer>(), users_,
399 cnode->inputs().back());
400 }
401 input_nodes_.reserve(cnode->size());
402 (void)std::transform(
403 cnode->weak_inputs().cbegin() + 1, cnode->weak_inputs().cend(), std::back_inserter(input_nodes_),
404 [&ir_builder](const AnfNodeWeakPtr &no) { return std::make_shared<IrNode>(no.lock(), ir_builder.get()); });
405 mindspore::HashMap<std::string, ValuePtr> attrs;
406 {
407 PrimitiveReadLock read_lock(prim->shared_mutex());
408 attrs = prim->attrs();
409 }
410 auto handle = BpropIRBuilderFactory::Instance().GetBuilder(name);
411 if (handle == nullptr) {
412 MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
413 return false;
414 }
415 output_nodes_ = ir_builder->Build(input_nodes_, input_values, attrs, *handle);
416 if (output_nodes_.empty()) {
417 MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty.";
418 return false;
419 }
420 PostProcess(cnode);
421 DumpResult(name);
422 return true;
423 }
424
PostProcess(const CNodePtr & cnode) const425 void BpropExpander::PostProcess(const CNodePtr &cnode) const {
426 outputs_->reserve(output_nodes_.size());
427 constexpr const size_t num_out_and_dout = 2;
428 if (output_nodes_.size() + num_out_and_dout != input_nodes_.size()) {
429 MS_LOG(EXCEPTION) << "For bprop [" << AnfUtils::GetCNodeName(cnode)
430 << "], the output size should be equal to input size (exclude out and dout), but got "
431 << output_nodes_.size() << " vs " << (input_nodes_.size() - num_out_and_dout);
432 }
433 for (size_t i = 0; i < output_nodes_.size(); i++) {
434 (void)outputs_->emplace_back(output_nodes_[i]->get()->cast<CNodePtr>());
435 }
436 }
437
DumpResult(const std::string & name) const438 void BpropExpander::DumpResult(const std::string &name) const {
439 static const bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
440 if (!dump_result) {
441 return;
442 }
443 auto fg = std::make_shared<FuncGraph>();
444 std::map<AnfNodePtr, AnfNodePtr> node_map;
445 CNodePtrList newcnodes;
446 for (auto &inp : input_nodes_) {
447 auto p = fg->add_parameter();
448 p->set_abstract(inp->get()->abstract());
449 node_map[inp->get()] = p;
450 }
451 std::queue<CNodePtr> que;
452 (void)std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });
453
454 while (!que.empty()) {
455 auto node = que.front();
456 que.pop();
457 if (node_map.count(node) != 0) {
458 continue;
459 }
460 auto new_node = fg->NewCNode(node->inputs());
461 new_node->CloneCNodeInfo(node);
462 new_node->set_fullname_with_scope(node->fullname_with_scope());
463 node_map[node] = new_node;
464 newcnodes.push_back(new_node);
465 for (size_t i = 1; i < node->size(); ++i) {
466 const auto &inp = node->input(i);
467 if (inp->isa<CNode>() && node_map.count(inp) == 0) {
468 que.push(inp->cast<CNodePtr>());
469 }
470 }
471 }
472
473 for (auto &cnode : newcnodes) {
474 for (size_t i = 1; i < cnode->size(); i++) {
475 if (node_map.count(cnode->input(i)) != 0) {
476 cnode->set_input(i, node_map[cnode->input(i)]);
477 }
478 }
479 }
480
481 if (outputs_->size() == 1) {
482 fg->set_output(node_map[(*outputs_)[0]]);
483 } else {
484 AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
485 AbstractBasePtrList abs;
486 (void)std::transform(outputs_->cbegin(), outputs_->cend(), std::back_inserter(new_outputs),
487 [&node_map, &abs](const CNodePtr &node) {
488 abs.push_back(node->abstract());
489 return node_map[node];
490 });
491 auto mt = fg->NewCNode(new_outputs);
492 mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
493 fg->set_output(mt);
494 }
495 DumpIR("bprop/bprop_expander_" + name + ".ir", fg, true);
496
497 if (users_ != nullptr) {
498 for (auto &uiter : users_->dout_user_) {
499 for (auto &iter : uiter.second) {
500 auto user = iter.first.lock();
501 if (user == nullptr) {
502 continue;
503 }
504 MS_LOG(INFO) << "Node " << uiter.first->ToString() << " user: " << user->fullname_with_scope()
505 << " index: " << iter.second;
506 }
507 }
508 }
509 }
510
// Graph-mode infer strategy: do nothing at emit time, and only run inference
// on demand when some node's abstract is actually requested. This avoids
// inferring nodes whose abstract is never needed during expansion.
class LazyInfer : public CppInfer {
 public:
  // Intentionally a no-op: abstracts are computed lazily in GetAbstract.
  void Infer(const NodePtr &) override { return; }

  AbstractBasePtr GetAbstract(const NodePtr &node) override {
    auto anfnode = node->get();
    if (anfnode->abstract() == nullptr) {
      InferNow(anfnode);
    }
    return anfnode->abstract();
  }

 protected:
  // Recursively infer all not-yet-inferred operands first (input 0 is the
  // primitive, skip it), then infer `node` itself.
  void InferNow(const AnfNodePtr &node) {
    if (node->isa<CNode>()) {
      auto cnode = node->cast<CNodePtr>();
      for (size_t i = 1; i < cnode->size(); i++) {
        if (cnode->input(i)->abstract() == nullptr) {
          InferNow(cnode->input(i));
        }
      }
    }
    CppInfer::InferAnfnode(node);
  }
};
536
// IR builder for graph mode: emits python-compatible primitives (PrimitivePy
// when available) and sets the built tuple of gradients as the graph output.
class GraphModeBuilder : public IrBuilder {
 public:
  GraphModeBuilder(const std::string &name, const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer)
      : IrBuilder(name, func_graph, infer) {}

  // Run the bprop handle, make the outputs the graph's (tuple) output, and —
  // when control flow was emitted — wipe all abstracts so the specializer will
  // re-infer the control-flow subgraphs from scratch.
  NodePtrList Build(const NodePtrList &inputs, const mindspore::HashMap<std::string, ValuePtr> &attrs,
                    const BpropHandle &handle, const std::string &instance_name) {
    auto outputs = Run(inputs, attrs, handle, instance_name);
    auto mt = this->MakeTuple(outputs)->get();
    func_graph_->set_output(mt);
    if (has_ctrl_flow_) {
      // clear all abstract, to let the specializer re-infer the subgraph of controlflow graphs.
      auto todos = TopoSort(func_graph_->get_return(), SuccDeeperSimple, AlwaysInclude);
      for (auto &no : todos) {
        no->set_abstract(nullptr);
        if (IsValueNode<FuncGraph>(no)) {
          auto fg = GetValueNode<FuncGraphPtr>(no);
          for (auto &p : fg->parameters()) {
            p->set_abstract(nullptr);
          }
        }
      }
    }
    return outputs;
  }

  // Both control-flow emitters just mark the flag and defer to the base class.
  NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) override {
    has_ctrl_flow_ = true;
    return IrBuilder::Conditional(cond, true_case, false_case);
  }

  NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) override {
    has_ctrl_flow_ = true;
    return IrBuilder::While(cond, body, init_list);
  }

 protected:
  NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) override {
    // Prefer the python primitive wrapper when one exists; fall back to the
    // plain primitive otherwise.
    auto primpy = ConvertPrimToPrimPy(prim);
    AnfNodePtrList cnode_inputs = {NewValueNode(primpy ? primpy : prim)};
    cnode_inputs.reserve(inputs.size() + 1);
    (void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs), [](const NodePtr &no) {
      MS_EXCEPTION_IF_NULL(no);
      return no->get();
    });
    // PyNative use kernel graph construct bprop graph
    auto cnode = func_graph_->isa<KernelGraph>() ? func_graph_->FuncGraph::NewCNode(cnode_inputs)
                                                 : func_graph_->NewCNode(cnode_inputs);
    if (scope_ != nullptr) {
      cnode->set_scope(scope_);
    }
    auto node = NewIrNode(cnode->cast<AnfNodePtr>());
    infer_->Infer(node);
    return node;
  }

  bool has_ctrl_flow_{false};
};
595
ExpandBpropInGraphMode(const BpropHandle * handle,const PrimitivePtr & prim,const FuncGraphPtr & graph)596 bool ExpandBpropInGraphMode(const BpropHandle *handle, const PrimitivePtr &prim, const FuncGraphPtr &graph) {
597 static const bool use_imm_infer = (common::GetEnv("MS_DEV_BPROP_IMM_INFER") == "on");
598 static const bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
599 auto name = prim->name();
600 if (handle == nullptr) {
601 MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
602 return false;
603 }
604 ExpanderInferPtr infer;
605 if (use_imm_infer) {
606 infer = std::make_shared<CppInfer>();
607 } else {
608 infer = std::make_shared<LazyInfer>();
609 }
610 GraphModeBuilder ir_builder(name, graph, infer);
611 auto ¶meters = graph->parameters();
612 NodePtrList inputs;
613 inputs.reserve(parameters.size());
614 (void)std::transform(parameters.cbegin(), parameters.cend(), std::back_inserter(inputs),
615 [&ir_builder](const AnfNodePtr &no) { return std::make_shared<IrNode>(no, &ir_builder); });
616 auto outputs = ir_builder.Build(inputs, prim->attrs(), *handle, prim->instance_name());
617 if (outputs.empty()) {
618 MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty.";
619 return false;
620 }
621 if (dump_result) {
622 DumpIR("bprop/bprop_expander_" + name + ".ir", graph, true);
623 }
624 return true;
625 }
626
627 #ifdef _MSC_VER
628 void RegGradArrayOps();
629 void RegGradClipOps();
630 void RegGradCommOps();
631 void RegGradDebugOps();
632 void RegGradImageOps();
633 void RegGradImplementationsOps();
634 void RegGradInnerOps();
635 void RegGradLinalgOps();
636 void RegGradMathOps();
637 void RegGradNnOps();
638 void RegGradOtherOps();
639 void RegGradQuantOps();
640 void RegGradScipyOps();
641 void RegGradSparseOps();
642 void RegGradSequenceOps();
643 void RegGradScalarOps();
644
WinBpropRegister()645 WinBpropRegister::WinBpropRegister() {
646 RegGradArrayOps();
647 RegGradClipOps();
648 RegGradCommOps();
649 RegGradDebugOps();
650 RegGradImageOps();
651 RegGradImplementationsOps();
652 RegGradInnerOps();
653 RegGradLinalgOps();
654 RegGradMathOps();
655 RegGradNnOps();
656 RegGradOtherOps();
657 RegGradQuantOps();
658 RegGradScipyOps();
659 RegGradSparseOps();
660 RegGradSequenceOps();
661 RegGradScalarOps();
662 }
663 #endif
664 } // namespace bprop
665 } // namespace expander
666 } // namespace mindspore
667