• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "frontend/expander/bprop/bprop.h"
17 
18 #include <algorithm>
19 #include <queue>
20 #include <unordered_map>
21 
22 #include "ops/sequence_ops.h"
23 #include "ops/array_ops.h"
24 #include "ops/framework_ops.h"
25 #include "abstract/ops/primitive_infer_map.h"
26 #include "include/common/expander/core/infer.h"
27 #include "include/common/profiler.h"
28 #include "include/backend/kernel_graph.h"
29 #include "utils/anf_utils.h"
30 #include "include/common/debug/anf_ir_dump.h"
31 #include "frontend/expander/utils.h"
32 
33 namespace mindspore {
34 namespace expander {
35 namespace bprop {
36 class SimpleNode {
37  public:
SimpleNode(const AbstractBasePtr & abs)38   explicit SimpleNode(const AbstractBasePtr &abs) : abs_(abs->Clone()) {}
SimpleNode(const ValuePtr & value,const AbstractBasePtr & abs)39   SimpleNode(const ValuePtr &value, const AbstractBasePtr &abs) : abs_(abs->Clone()), value_(value) {}
SimpleNode(const PrimitivePtr & prim,const AbstractBasePtr & abs,const std::vector<size_t> & input_indexs)40   SimpleNode(const PrimitivePtr &prim, const AbstractBasePtr &abs, const std::vector<size_t> &input_indexs)
41       : input_indexs(std::move(input_indexs)), abs_(abs->Clone()), prim_(prim->Clone()) {}
42   ~SimpleNode() = default;
is_valuenode() const43   bool is_valuenode() const { return value_ != nullptr; }
get_primitive() const44   PrimitivePtr get_primitive() const { return prim_->Clone(); }
get_abstract() const45   AbstractBasePtr get_abstract() const { return abs_->Clone(); }
get_value() const46   ValuePtr get_value() const { return value_; }
47 
48   std::vector<size_t> input_indexs;
49 
50  protected:
51   AbstractBasePtr abs_;
52   ValuePtr value_;
53   PrimitivePtr prim_;
54 };
55 using SimpleNodePtr = std::shared_ptr<SimpleNode>;
56 
// A flattened, framework-independent bprop graph used as the cache payload: nodes are stored
// in emission (topological) order and referenced by index into `nodes`.
struct SimpleGraph {
  std::vector<SimpleNodePtr> nodes;    // all nodes; graph inputs come first
  std::vector<size_t> output_indexs;   // indexes of the graph outputs within `nodes`
  std::vector<size_t> input_indexs;    // indexes of the graph inputs within `nodes`
};
62 using SimpleGraphPtr = std::shared_ptr<SimpleGraph>;
63 using BpropGraphCacheMap = std::unordered_map<abstract::AbstractBasePtrList, SimpleGraphPtr,
64                                               abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
65 using KernelGraph = session::KernelGraph;
66 
HasBpropExpander(const std::string & prim_name)67 bool HasBpropExpander(const std::string &prim_name) {
68   const BpropHandle *handle = BpropIRBuilderFactory::Instance().GetBuilder(prim_name);
69   return (handle != nullptr);
70 }
71 
// Thrown during bprop expansion when an op cannot be expanded in pynative mode (currently
// ShapeCalc); the caller catches it and falls back to the python bprop implementation.
class ShapeCalcException : public std::runtime_error {
 public:
  using runtime_error::runtime_error;
};
76 
// IR builder used in pynative mode. It emits bprop cnodes into the given func graph, runs
// abstract inference eagerly (unless temporarily disabled), and records every user of `dout`
// (and of TupleGetItem nodes on `dout`) into `users_` so the caller can relink them later.
class PynativeIRBuilder : public IrBuilder {
 public:
  // `prim` is the forward primitive whose bprop is being expanded; `users` must not be null.
  PynativeIRBuilder(const PrimitivePtr &prim, const FuncGraphPtr &fg, const ExpanderInferPtr &infer, UserMap *users,
                    const AnfNodePtr &dout)
      : IrBuilder(prim->name(), fg, infer), users_(users), dout_(dout), prim_(prim) {
    MS_EXCEPTION_IF_NULL(users);
  }
  ~PynativeIRBuilder() = default;

  // Emit a ZerosLike node for an unused gradient. Inference is skipped because Build()
  // overwrites the node's abstract with the broadened forward-input abstract anyway.
  NodePtr OutZeros(const NodePtr &node) override {
    need_infer_ = false;
    auto ret = Emit(kZerosLikeOpName, {node});
    need_infer_ = true;
    return ret;
  }

  // Run the registered bprop function and post-process its outputs. When `input_values` is
  // non-empty, each value is attached to the corresponding input node first.
  virtual NodePtrList Build(const std::vector<NodePtr> &input_nodes, const std::vector<ValuePtr> &input_values,
                            const HashMap<std::string, ValuePtr> &attrs, const BpropHandle &handle) {
    if (!input_values.empty()) {
      for (size_t i = 0; i < input_values.size(); ++i) {
        input_nodes[i]->SetValue(input_values[i]);
      }
    }
    auto output_nodes = Run(input_nodes, attrs, handle, prim_->instance_name());
    for (size_t i = 0; i < output_nodes.size(); i++) {
      auto &node = output_nodes[i];
      // A ValueNode gradient would lose the trace context in pynative mode, so emit a real
      // (ZerosLike) node instead. An example is Eye.
      if (node->input_type() == InputType::kConstant || IsPrimitiveCNode(node->get(), prim::kPrimZerosLike)) {
        if (node->input_type() == InputType::kConstant) {
          auto abs = node->abstract();
          MS_EXCEPTION_IF_NULL(abs);
          if (abs->isa<abstract::AbstractScalar>()) {
            // Scalar constants are wrapped into a zero tensor of the same dtype first.
            node = OutZeros(Tensor(0, abs->BuildType()));
          } else {
            node = OutZeros(node);
          }
        }
        // The i-th gradient takes the (broadened) abstract of the i-th forward input.
        node->get()->set_abstract(input_nodes[i]->abstract()->Broaden());
      }
    }
    return output_nodes;
  }

  // Build an if-then-else sub-graph; sub-graphs are emitted by a fresh PynativeIRBuilder that
  // shares this builder's users_ and dout_.
  NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) override {
    has_ctrl_flow_ = true;
    CtrlFlowBlock cfb(this, this->func_graph(),
                      [this](const FuncGraphPtr &fg, const ExpanderInferPtr &infer) -> EmitterPtr {
                        return std::make_shared<PynativeIRBuilder>(this->prim_, fg, infer, this->users_, this->dout_);
                      });
    this->func_graph()->set_flag(kFlagIsControlFlow, true);
    return cfb.IfThenElse(cond, true_case, false_case);
  }

  // Build a while-loop sub-graph, same sharing scheme as Conditional.
  NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) override {
    has_ctrl_flow_ = true;
    CtrlFlowBlock cfb(this, this->func_graph(),
                      [this](const FuncGraphPtr &fg, const ExpanderInferPtr &infer) -> EmitterPtr {
                        return std::make_shared<PynativeIRBuilder>(this->prim_, fg, infer, this->users_, this->dout_);
                      });
    this->func_graph()->set_flag(kFlagIsControlFlow, true);
    return cfb.While(cond, body, init_list);
  }

 protected:
  // Constant-fold TupleGetItem: if inputs[0] is a constant ValueSequence, return the indexed
  // element as a fresh value node; otherwise return nullptr and let EmitOp build a cnode.
  NodePtr EmitGetItemValue(const NodePtrList &inputs) {
    if (inputs[0]->input_type() != InputType::kConstant) {
      return nullptr;
    }
    auto real_input = inputs[0]->get()->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(real_input);
    auto real_input_value = real_input->value()->cast<ValueSequeuePtr>();
    if (real_input_value != nullptr) {
      auto item_idx = GetValue<int64_t>(inputs[1]->get()->cast<ValueNodePtr>()->value());
      auto valuenode = NewValueNode((*real_input_value)[item_idx]);
      valuenode->set_abstract(valuenode->value()->ToAbstract()->Broaden());
      return NewIrNode(valuenode);
    }
    return nullptr;
  }

  NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) override {
    if (prim->name() == prim::kPrimShapeCalc->name()) {
      // temporary solution, remove this after input parameter's value is set.
      throw ShapeCalcException("ShapeCalc is not supported in pynative mode.");
    }
    if (prim->name() == kTupleGetItemOpName) {
      // if the getitem's real input is a ValueSequence, just return the real Value of that.
      auto getitem_value = EmitGetItemValue(inputs);
      if (getitem_value != nullptr) {
        return getitem_value;
      }
    }
    AnfNodePtrList cnode_inputs{NewValueNode(prim)};
    cnode_inputs.reserve(inputs.size() + 1);
    (void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs),
                         [](const NodePtr &inp) { return inp->get(); });
    // PyNative use kernel graph construct bprop graph, which indicate func_graph_ here is kernel graph;
    // And, use kernel graph create cnode will do PostNewCNode which is not necessary
    auto cnode = func_graph_->isa<KernelGraph>() ? func_graph_->FuncGraph::NewCNode(cnode_inputs)
                                                 : func_graph_->NewCNode(cnode_inputs);
    if (scope_ != nullptr) {
      cnode->set_scope(scope_);
    }

    auto node = NewIrNode(cnode->cast<AnfNodePtr>());
    if (need_infer_) {
      // Some primitives infer from input *values*: sync those tensors from device first so the
      // value is available on host, then stash it on the input's abstract.
      auto value_depend = abstract::GetValueDependArgIndices(cnode);
      if (!value_depend.empty()) {
        for (auto idx : value_depend) {
          size_t i = LongToSize(idx);
          if (i < inputs.size() && !inputs[i]->HasAbstractValue()) {
            auto v = inputs[i]->BuildValue();
            auto tensor = v->cast<tensor::BaseTensorPtr>();
            if (tensor != nullptr) {
              tensor->data_sync();
            }
            inputs[i]->abstract()->set_value(v);
          }
        }
      }
      infer_->Infer(node);
    }
    // record the users of dout/parameters (and of getitem nodes whose real input is dout)
    for (size_t i = 1; i < cnode_inputs.size(); i++) {
      auto &inp = cnode_inputs[i];
      if (inp == dout_ || inp->isa<Parameter>()) {
        (void)users_->dout_user_[inp].emplace_back(cnode, i);
      } else if (IsPrimitiveCNode(inp, prim::kPrimTupleGetItem)) {
        // record the dout's successor getitem's users
        auto getitem = inp->cast<CNodePtr>();
        auto real_input = getitem->input(kIndex1);
        if (real_input == dout_) {
          (void)users_->tuple_getitem_user_[inp].emplace_back(cnode, i);
        }
      }
    }
    return node;
  }

  UserMap *users_;            // not owned; collects consumers of dout
  AnfNodePtr dout_;           // the dout argument of the bprop
  bool need_infer_{true};     // false only while emitting ZerosLike in OutZeros
  PrimitivePtr prim_;         // forward primitive being expanded
  bool has_ctrl_flow_{false}; // set once Conditional/While has been emitted
};
222 
223 class PynativeIRBuilderWithCache : public PynativeIRBuilder {
224  public:
225   using PynativeIRBuilder::PynativeIRBuilder;
226   ~PynativeIRBuilderWithCache() = default;
227 
228   inline static std::unordered_map<PrimitivePtr, BpropGraphCacheMap, PrimitiveHasher, PrimitiveTotalEqual>
229     bprop_op_graph_map;
230 
Build(const NodePtrList & input_nodes,const std::vector<ValuePtr> & input_values,const HashMap<std::string,ValuePtr> & attrs,const BpropHandle & handle)231   NodePtrList Build(const NodePtrList &input_nodes, const std::vector<ValuePtr> &input_values,
232                     const HashMap<std::string, ValuePtr> &attrs, const BpropHandle &handle) override {
233     AbstractBasePtrList abs_list;
234     NodePtrList output_nodes;
235     abs_list.reserve(input_nodes.size());
236     (void)std::transform(input_nodes.cbegin(), input_nodes.cend(), std::back_insert_iterator(abs_list),
237                          [](const NodePtr &no) { return no->abstract(); });
238     std::vector<size_t> value_index(input_nodes.size());
239     for (size_t i = 0; i < input_values.size(); ++i) {
240       if (!input_nodes[i]->HasAbstractValue()) {
241         input_nodes[i]->SetValue(input_values[i]);
242         value_index[i] = true;
243       }
244     }
245     BpropGraphCacheMap &bprop_map = PynativeIRBuilderWithCache::bprop_op_graph_map[prim_];
246     auto it = bprop_map.find(abs_list);
247     if (it == bprop_map.end()) {
248       need_record_nodes_ = true;
249       output_nodes = PynativeIRBuilder::Build(input_nodes, {}, attrs, handle);
250       need_record_nodes_ = false;
251       if (has_ctrl_flow_) {
252         return output_nodes;
253       }
254       // need not grad if grad depend input_values.
255       for (size_t i = 0; i < input_nodes.size(); i++) {
256         if (value_index[i] && input_nodes[i]->is_used_value()) {
257           return output_nodes;
258         }
259       }
260       for (auto &node_pair : bprop_nodes_) {
261         if (IsPrimitiveCNode(node_pair.first->get(), prim::kPrimSwitch)) {
262           return output_nodes;
263         }
264       }
265       bprop_map[abs_list] = BuildBpropOpGraph(input_nodes, output_nodes);
266     } else {
267       need_infer_ = false;
268       SimpleGraphPtr graph = it->second;
269       std::vector<NodePtr> node_map(input_nodes);
270       node_map.reserve(graph->nodes.size());
271       auto SimpleNodeToMsNode = [&graph, &node_map, this](const SimpleNodePtr &node) -> NodePtr {
272         if (node->is_valuenode()) {
273           return EmitValue(node->get_value());
274         }
275         NodePtrList cnode_list;
276         cnode_list.reserve(node->input_indexs.size());
277         for (size_t i : node->input_indexs) {
278           (void)cnode_list.emplace_back(node_map[i]);
279         }
280         NodePtr new_node = EmitOp(node->get_primitive(), cnode_list);
281         AnfNodePtr ms_node = new_node->get();
282         if (ms_node->abstract() == nullptr) {
283           ms_node->set_abstract(node->get_abstract());
284         }
285         return new_node;
286       };
287       for (size_t i = graph->input_indexs.size(); i < graph->nodes.size(); i++) {
288         (void)node_map.emplace_back(SimpleNodeToMsNode(graph->nodes[i]));
289       }
290       output_nodes.reserve(graph->output_indexs.size());
291       for (size_t i : graph->output_indexs) {
292         (void)output_nodes.emplace_back(node_map[i]);
293       }
294     }
295     return output_nodes;
296   }
297 
298  protected:
EmitOp(const PrimitivePtr & prim,const NodePtrList & inputs)299   NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) override {
300     auto node = PynativeIRBuilder::EmitOp(prim, inputs);
301     if (need_record_nodes_) {
302       (void)bprop_nodes_.emplace_back(std::make_pair(node, inputs));
303     }
304     return node;
305   }
306 
307  private:
BuildBpropOpGraph(const NodePtrList & input_nodes,const NodePtrList & output_nodes)308   SimpleGraphPtr BuildBpropOpGraph(const NodePtrList &input_nodes, const NodePtrList &output_nodes) {
309     std::unordered_map<NodePtr, size_t> node_map;
310     SimpleGraphPtr graph = std::make_shared<SimpleGraph>();
311     for (auto &parm : input_nodes) {
312       node_map[parm] = graph->nodes.size();
313       (void)graph->input_indexs.emplace_back(graph->nodes.size());
314       (void)graph->nodes.emplace_back(std::make_shared<SimpleNode>(parm->abstract()));
315     }
316     for (auto &[node, inputs] : bprop_nodes_) {
317       std::vector<size_t> input_indexs;
318       input_indexs.reserve(inputs.size());
319       for (auto &no : inputs) {
320         auto it = node_map.find(no);
321         if (it == node_map.end()) {
322           auto value = no->BuildValue();
323           node_map[node] = graph->nodes.size();
324           (void)input_indexs.emplace_back(graph->nodes.size());
325           (void)graph->nodes.emplace_back(std::make_shared<SimpleNode>(value, value->ToAbstract()->Broaden()));
326         } else {
327           (void)input_indexs.emplace_back(it->second);
328         }
329       }
330       PrimitivePtr primitive =
331         node->input_type() == InputType::kConstant ? prim::kPrimTupleGetItem : GetCNodePrimitive(node->get());
332       node_map[node] = graph->nodes.size();
333       (void)graph->nodes.emplace_back(std::make_shared<SimpleNode>(primitive, node->abstract(), input_indexs));
334     }
335     graph->output_indexs.reserve(output_nodes.size());
336     for (auto &node : output_nodes) {
337       (void)graph->output_indexs.emplace_back(node_map[node]);
338     }
339     return graph;
340   }
341 
342   bool need_record_nodes_{false};
343   std::vector<std::pair<NodePtr, NodePtrList>> bprop_nodes_;
344 };
345 
ClearBpropOpGraphMap()346 void ClearBpropOpGraphMap() { PynativeIRBuilderWithCache ::bprop_op_graph_map.clear(); }
347 
Run(const CNodePtr & cnode,const std::vector<ValuePtr> & input_values)348 bool BpropExpander::Run(const CNodePtr &cnode, const std::vector<ValuePtr> &input_values) {
349   MS_EXCEPTION_IF_NULL(cnode);
350   MS_LOG(DEBUG) << "Begin building bprop for " << cnode->fullname_with_scope();
351   bool ret = true;
352   if (outputs_ != nullptr) {
353     outputs_->clear();
354   }
355   auto node_name = AnfUtils::GetCNodeName(cnode);
356   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeGradExpander,
357                                      node_name, true);
358   if (OpEnvManager::UsePyBprop(node_name)) {
359     MS_LOG(DEBUG) << "Python bprop will be used for op " << node_name;
360     return false;
361   }
362   try {
363     ret = RunBprop(cnode, input_values);
364   } catch (const ShapeCalcException &e) {
365     MS_LOG(INFO) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what()
366                  << "]. python bprop will be used.";
367     if (outputs_ != nullptr) {
368       outputs_->clear();
369     }
370     ret = false;
371   } catch (const std::exception &e) {
372     MS_LOG(ERROR) << "Bprop \"" << node_name << "\" encounter a problem: [" << e.what() << "]";
373     std::rethrow_exception(std::current_exception());
374   }
375   MS_LOG(DEBUG) << "Finish building bprop for " << cnode->fullname_with_scope();
376   return ret;
377 }
378 
GetUnusedInputs(const string & op_name)379 const mindspore::HashSet<size_t> &BpropExpander::GetUnusedInputs(const string &op_name) {
380   auto handle = BpropIRBuilderFactory::Instance().GetBuilder(op_name);
381   if (handle == nullptr) {
382     MS_LOG(DEBUG) << "Bprop IRBuilder [" << op_name << "] is not registered in bprop expander.";
383     static const mindspore::HashSet<size_t> no_handle{INT_MAX};
384     return no_handle;
385   }
386   return handle->unused_inputs;
387 }
388 
RunBprop(const CNodePtr & cnode,const std::vector<ValuePtr> & input_values)389 bool BpropExpander::RunBprop(const CNodePtr &cnode, const std::vector<ValuePtr> &input_values) {
390   static const bool cache_env = (common::GetEnv("MS_DEV_DISABLE_BPROP_CACHE") != "on");
391   const auto prim = GetCNodePrimitive(cnode);
392   const auto name = prim->name();
393   std::shared_ptr<PynativeIRBuilder> ir_builder;
394   if (cache_env) {
395     ir_builder = std::make_shared<PynativeIRBuilderWithCache>(prim, cnode->func_graph(), std::make_shared<CppInfer>(),
396                                                               users_, cnode->inputs().back());
397   } else {
398     ir_builder = std::make_shared<PynativeIRBuilder>(prim, cnode->func_graph(), std::make_shared<CppInfer>(), users_,
399                                                      cnode->inputs().back());
400   }
401   input_nodes_.reserve(cnode->size());
402   (void)std::transform(
403     cnode->weak_inputs().cbegin() + 1, cnode->weak_inputs().cend(), std::back_inserter(input_nodes_),
404     [&ir_builder](const AnfNodeWeakPtr &no) { return std::make_shared<IrNode>(no.lock(), ir_builder.get()); });
405   mindspore::HashMap<std::string, ValuePtr> attrs;
406   {
407     PrimitiveReadLock read_lock(prim->shared_mutex());
408     attrs = prim->attrs();
409   }
410   auto handle = BpropIRBuilderFactory::Instance().GetBuilder(name);
411   if (handle == nullptr) {
412     MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
413     return false;
414   }
415   output_nodes_ = ir_builder->Build(input_nodes_, input_values, attrs, *handle);
416   if (output_nodes_.empty()) {
417     MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty.";
418     return false;
419   }
420   PostProcess(cnode);
421   DumpResult(name);
422   return true;
423 }
424 
PostProcess(const CNodePtr & cnode) const425 void BpropExpander::PostProcess(const CNodePtr &cnode) const {
426   outputs_->reserve(output_nodes_.size());
427   constexpr const size_t num_out_and_dout = 2;
428   if (output_nodes_.size() + num_out_and_dout != input_nodes_.size()) {
429     MS_LOG(EXCEPTION) << "For bprop [" << AnfUtils::GetCNodeName(cnode)
430                       << "], the output size should be equal to input size (exclude out and dout), but got "
431                       << output_nodes_.size() << " vs " << (input_nodes_.size() - num_out_and_dout);
432   }
433   for (size_t i = 0; i < output_nodes_.size(); i++) {
434     (void)outputs_->emplace_back(output_nodes_[i]->get()->cast<CNodePtr>());
435   }
436 }
437 
// Debug helper (enabled by MS_DEV_DUMP_BPROP=on): clone the built bprop nodes into a fresh
// FuncGraph and dump it as IR, then log all recorded dout users.
void BpropExpander::DumpResult(const std::string &name) const {
  static const bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
  if (!dump_result) {
    return;
  }
  auto fg = std::make_shared<FuncGraph>();
  // node_map: original node -> its clone inside fg.
  std::map<AnfNodePtr, AnfNodePtr> node_map;
  CNodePtrList newcnodes;
  // Mirror each builder input as a parameter of the dump graph.
  for (auto &inp : input_nodes_) {
    auto p = fg->add_parameter();
    p->set_abstract(inp->get()->abstract());
    node_map[inp->get()] = p;
  }
  std::queue<CNodePtr> que;
  (void)std::for_each(outputs_->cbegin(), outputs_->cend(), [&que](const CNodePtr &cnode) { que.push(cnode); });

  // BFS from the outputs: clone every reachable cnode. The clones still point at the
  // original inputs; they are relinked in the fix-up pass below.
  while (!que.empty()) {
    auto node = que.front();
    que.pop();
    if (node_map.count(node) != 0) {
      continue;
    }
    auto new_node = fg->NewCNode(node->inputs());
    new_node->CloneCNodeInfo(node);
    new_node->set_fullname_with_scope(node->fullname_with_scope());
    node_map[node] = new_node;
    newcnodes.push_back(new_node);
    for (size_t i = 1; i < node->size(); ++i) {
      const auto &inp = node->input(i);
      if (inp->isa<CNode>() && node_map.count(inp) == 0) {
        que.push(inp->cast<CNodePtr>());
      }
    }
  }

  // Fix-up pass: redirect every cloned edge to the cloned counterpart of its input.
  for (auto &cnode : newcnodes) {
    for (size_t i = 1; i < cnode->size(); i++) {
      if (node_map.count(cnode->input(i)) != 0) {
        cnode->set_input(i, node_map[cnode->input(i)]);
      }
    }
  }

  if (outputs_->size() == 1) {
    fg->set_output(node_map[(*outputs_)[0]]);
  } else {
    // Multiple gradients: bundle them into a MakeTuple output.
    AnfNodePtrList new_outputs{NewValueNode(prim::kPrimMakeTuple)};
    AbstractBasePtrList abs;
    (void)std::transform(outputs_->cbegin(), outputs_->cend(), std::back_inserter(new_outputs),
                         [&node_map, &abs](const CNodePtr &node) {
                           abs.push_back(node->abstract());
                           return node_map[node];
                         });
    auto mt = fg->NewCNode(new_outputs);
    mt->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
    fg->set_output(mt);
  }
  DumpIR("bprop/bprop_expander_" + name + ".ir", fg, true);

  // Also log the recorded users of dout for debugging the later relink step.
  if (users_ != nullptr) {
    for (auto &uiter : users_->dout_user_) {
      for (auto &iter : uiter.second) {
        auto user = iter.first.lock();
        if (user == nullptr) {
          continue;
        }
        MS_LOG(INFO) << "Node " << uiter.first->ToString() << " user: " << user->fullname_with_scope()
                     << "  index: " << iter.second;
      }
    }
  }
}
510 
511 class LazyInfer : public CppInfer {
512  public:
Infer(const NodePtr &)513   void Infer(const NodePtr &) override { return; }
514 
GetAbstract(const NodePtr & node)515   AbstractBasePtr GetAbstract(const NodePtr &node) override {
516     auto anfnode = node->get();
517     if (anfnode->abstract() == nullptr) {
518       InferNow(anfnode);
519     }
520     return anfnode->abstract();
521   }
522 
523  protected:
InferNow(const AnfNodePtr & node)524   void InferNow(const AnfNodePtr &node) {
525     if (node->isa<CNode>()) {
526       auto cnode = node->cast<CNodePtr>();
527       for (size_t i = 1; i < cnode->size(); i++) {
528         if (cnode->input(i)->abstract() == nullptr) {
529           InferNow(cnode->input(i));
530         }
531       }
532     }
533     CppInfer::InferAnfnode(node);
534   }
535 };
536 
// IR builder used in graph mode: emits cnodes with python-converted primitives, wraps the
// gradients into a MakeTuple graph output, and — for control-flow graphs — clears all
// abstracts so the specializer re-infers the sub-graphs later.
class GraphModeBuilder : public IrBuilder {
 public:
  GraphModeBuilder(const std::string &name, const FuncGraphPtr &func_graph, const ExpanderInferPtr &infer)
      : IrBuilder(name, func_graph, infer) {}

  // Run the registered bprop function and set the func graph's output to the gradient tuple.
  NodePtrList Build(const NodePtrList &inputs, const mindspore::HashMap<std::string, ValuePtr> &attrs,
                    const BpropHandle &handle, const std::string &instance_name) {
    auto outputs = Run(inputs, attrs, handle, instance_name);
    auto mt = this->MakeTuple(outputs)->get();
    func_graph_->set_output(mt);
    if (has_ctrl_flow_) {
      // clear all abstract, to let the specializer re-infer the subgraph of controlflow graphs.
      auto todos = TopoSort(func_graph_->get_return(), SuccDeeperSimple, AlwaysInclude);
      for (auto &no : todos) {
        no->set_abstract(nullptr);
        if (IsValueNode<FuncGraph>(no)) {
          auto fg = GetValueNode<FuncGraphPtr>(no);
          for (auto &p : fg->parameters()) {
            p->set_abstract(nullptr);
          }
        }
      }
    }
    return outputs;
  }

  NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) override {
    has_ctrl_flow_ = true;
    return IrBuilder::Conditional(cond, true_case, false_case);
  }

  NodePtr While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) override {
    has_ctrl_flow_ = true;
    return IrBuilder::While(cond, body, init_list);
  }

 protected:
  // Emit a cnode, preferring the python primitive when the conversion succeeds; falls back to
  // the raw primitive otherwise.
  NodePtr EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) override {
    auto primpy = ConvertPrimToPrimPy(prim);
    AnfNodePtrList cnode_inputs = {NewValueNode(primpy ? primpy : prim)};
    cnode_inputs.reserve(inputs.size() + 1);
    (void)std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(cnode_inputs), [](const NodePtr &no) {
      MS_EXCEPTION_IF_NULL(no);
      return no->get();
    });
    // PyNative use kernel graph construct bprop graph
    auto cnode = func_graph_->isa<KernelGraph>() ? func_graph_->FuncGraph::NewCNode(cnode_inputs)
                                                 : func_graph_->NewCNode(cnode_inputs);
    if (scope_ != nullptr) {
      cnode->set_scope(scope_);
    }
    auto node = NewIrNode(cnode->cast<AnfNodePtr>());
    infer_->Infer(node);
    return node;
  }

  bool has_ctrl_flow_{false};  // set once Conditional/While has been emitted
};
595 
ExpandBpropInGraphMode(const BpropHandle * handle,const PrimitivePtr & prim,const FuncGraphPtr & graph)596 bool ExpandBpropInGraphMode(const BpropHandle *handle, const PrimitivePtr &prim, const FuncGraphPtr &graph) {
597   static const bool use_imm_infer = (common::GetEnv("MS_DEV_BPROP_IMM_INFER") == "on");
598   static const bool dump_result = (common::GetEnv("MS_DEV_DUMP_BPROP") == "on");
599   auto name = prim->name();
600   if (handle == nullptr) {
601     MS_LOG(DEBUG) << "Bprop IRBuilder [" << name << "] is not registered in bprop expander.";
602     return false;
603   }
604   ExpanderInferPtr infer;
605   if (use_imm_infer) {
606     infer = std::make_shared<CppInfer>();
607   } else {
608     infer = std::make_shared<LazyInfer>();
609   }
610   GraphModeBuilder ir_builder(name, graph, infer);
611   auto &parameters = graph->parameters();
612   NodePtrList inputs;
613   inputs.reserve(parameters.size());
614   (void)std::transform(parameters.cbegin(), parameters.cend(), std::back_inserter(inputs),
615                        [&ir_builder](const AnfNodePtr &no) { return std::make_shared<IrNode>(no, &ir_builder); });
616   auto outputs = ir_builder.Build(inputs, prim->attrs(), *handle, prim->instance_name());
617   if (outputs.empty()) {
618     MS_LOG(DEBUG) << "The output nodes of bprop function [" << name << "] is empty.";
619     return false;
620   }
621   if (dump_result) {
622     DumpIR("bprop/bprop_expander_" + name + ".ir", graph, true);
623   }
624   return true;
625 }
626 
#ifdef _MSC_VER
// Forward declarations of the per-category gradient registration entry points.
// NOTE(review): presumably an MSVC-only workaround for self-registration statics being
// dropped by the linker — confirm before removing.
void RegGradArrayOps();
void RegGradClipOps();
void RegGradCommOps();
void RegGradDebugOps();
void RegGradImageOps();
void RegGradImplementationsOps();
void RegGradInnerOps();
void RegGradLinalgOps();
void RegGradMathOps();
void RegGradNnOps();
void RegGradOtherOps();
void RegGradQuantOps();
void RegGradScipyOps();
void RegGradSparseOps();
void RegGradSequenceOps();
void RegGradScalarOps();

// Explicitly register every gradient-builder category (MSVC only).
WinBpropRegister::WinBpropRegister() {
  RegGradArrayOps();
  RegGradClipOps();
  RegGradCommOps();
  RegGradDebugOps();
  RegGradImageOps();
  RegGradImplementationsOps();
  RegGradInnerOps();
  RegGradLinalgOps();
  RegGradMathOps();
  RegGradNnOps();
  RegGradOtherOps();
  RegGradQuantOps();
  RegGradScipyOps();
  RegGradSparseOps();
  RegGradSequenceOps();
  RegGradScalarOps();
}
#endif
664 }  // namespace bprop
665 }  // namespace expander
666 }  // namespace mindspore
667