• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2022 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/anf.h"
20 
21 #include <algorithm>
22 #include <sstream>
23 #include <vector>
24 #include <queue>
25 
26 #include "mindspore/core/ops/structure_ops.h"
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "utils/hash_map.h"
30 #include "ir/func_graph.h"
31 #include "ir/primitive.h"
32 #include "utils/ms_context.h"
33 #include "utils/anf_utils.h"
34 #include "utils/compile_config.h"
35 #include "ops/op_def.h"
36 
37 namespace mindspore {
38 namespace {
IsNeedCheckPrimitiveNode(const AnfNodePtr & prim_node)39 std::pair<bool, PrimitivePtr> IsNeedCheckPrimitiveNode(const AnfNodePtr &prim_node) {
40   if (!IsValueNode<Primitive>(prim_node)) {
41     return {false, nullptr};
42   }
43   auto prim = GetValueNode<PrimitivePtr>(prim_node);
44   MS_EXCEPTION_IF_NULL(prim);
45   if (prim->IsPythonPrim() || prim->HasAttr(kSkipCheckInputNum)) {
46     return {false, nullptr};
47   }
48   if (prim->GetAttr("primitive_function") == nullptr) {
49     return {false, nullptr};
50   }
51   auto op_def = mindspore::ops::GetOpDef(prim->name());
52   if (op_def == nullptr) {
53     return {false, nullptr};
54   }
55 
56   return {true, prim};
57 }
58 
PrintErrorInfo(const AnfNodeWeakPtrList & inputs,const PrimitivePtr & prim,size_t input_tensor_num,const ops::OpDefPtr & op_def)59 void PrintErrorInfo(const AnfNodeWeakPtrList &inputs, const PrimitivePtr &prim, size_t input_tensor_num,
60                     const ops::OpDefPtr &op_def) {
61   std::stringstream ss;
62   size_t i = 0;
63   ss << "Inputs are as follows: \n";
64   for (const auto &input : inputs) {
65     ss << "Input[" << i++ << "]: " << input.lock()->DebugString() << "\n";
66   }
67   MS_LOG(DEBUG) << "Primitive<" << prim->name() << "> inputs num: " << input_tensor_num
68                 << " is not equal to expect input num: " << op_def->args_.size() << "\n"
69                 << ss.str();
70 }
71 
CheckCNodeInputsNum(const AnfNodeWeakPtrList & inputs)72 void CheckCNodeInputsNum(const AnfNodeWeakPtrList &inputs) {
73   if (!IS_OUTPUT_ON(mindspore::kDebug) || inputs.empty()) {
74     return;
75   }
76 
77   auto [need_check, prim] = IsNeedCheckPrimitiveNode(inputs[0].lock());
78   if (!need_check) {
79     return;
80   }
81 
82   auto op_def = mindspore::ops::GetOpDef(prim->name());
83   if (op_def == nullptr) {
84     return;
85   }
86   bool input_num_err = false;
87   constexpr size_t prim_num = 1;
88   size_t input_tensor_num = inputs.size() - prim_num;
89   if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
90     size_t monad_num =
91       static_cast<size_t>(std::count_if(inputs.cbegin() + 1, inputs.end(), [](const AnfNodeWeakPtr &weak_input) {
92         const auto &input = weak_input.lock();
93         return HasAbstractMonad(input) || (IsPrimitiveCNode(input, prim::kPrimUpdateState) ||
94                                            IsValueNode<UMonad>(input) || IsValueNode<IOMonad>(input));
95       }));
96     // If monad input is parameter_monad, monad num is 0, actual monad num should be 1. And monad num is 0 if monad
97     // pass has not been executed.
98     if (monad_num == 0) {
99       auto monad_pass_not_executed_check_failed = op_def->args_.size() != input_tensor_num;
100       constexpr auto parameter_monad_num = 1;
101       auto exist_parameter_monad_check_failed = op_def->args_.size() != input_tensor_num - parameter_monad_num;
102       if (monad_pass_not_executed_check_failed && exist_parameter_monad_check_failed) {
103         input_num_err = true;
104       }
105     } else if (monad_num == 1) {
106       if (op_def->args_.size() != input_tensor_num - monad_num) {
107         input_num_err = true;
108       }
109     } else {
110       MS_LOG(INTERNAL_EXCEPTION) << "Get unexpected monad num: " << monad_num;
111     }
112   } else {
113     if (op_def->args_.size() != input_tensor_num) {
114       input_num_err = true;
115     }
116   }
117   if (input_num_err) {
118     PrintErrorInfo(inputs, prim, input_tensor_num, op_def);
119   }
120 }
121 }  // namespace
AnfNode(const FuncGraphPtr & func_graph,NodeDebugInfoPtr && debug_info)122 AnfNode::AnfNode(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info)
123     : func_graph_(FuncGraphWeakPtr(func_graph)),
124       abstract_(nullptr),
125       debug_info_(std::move(debug_info)),
126       fullname_with_scope_(""),
127       scope_(ScopeManager::GetInstance().GetCurrentScope()) {}
128 
AnfNode(const FuncGraphPtr & func_graph)129 AnfNode::AnfNode(const FuncGraphPtr &func_graph) : AnfNode(func_graph, std::make_shared<NodeDebugInfo>()) {}
130 
accept(AnfIrVisitor *)131 void AnfNode::accept(AnfIrVisitor *) {}
132 
func_graph() const133 FuncGraphPtr AnfNode::func_graph() const { return func_graph_.lock(); }
134 
set_func_graph(const FuncGraphPtr & func_graph)135 void AnfNode::set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
136 
scope()137 ScopePtr AnfNode::scope() { return scope_; }
138 
set_scope(const ScopePtr & scope)139 void AnfNode::set_scope(const ScopePtr &scope) { scope_ = scope; }
140 
kernel_info() const141 const KernelInfoDevice *AnfNode::kernel_info() const { return kernel_info_ptr().get(); }
142 
kernel_info()143 KernelInfoDevice *AnfNode::kernel_info() { return kernel_info_ptr().get(); }
144 
kernel_info_ptr() const145 KernelInfoDevicePtr AnfNode::kernel_info_ptr() const { return user_data<KernelInfoDevice>(kKernelInfoKey); }
146 
set_kernel_info(const KernelInfoDevicePtr & kernel_info)147 void AnfNode::set_kernel_info(const KernelInfoDevicePtr &kernel_info) { set_user_data(kKernelInfoKey, kernel_info); }
148 
debug_info()149 NodeDebugInfoPtr AnfNode::debug_info() {
150   MS_EXCEPTION_IF_NULL(debug_info_);
151   return debug_info_;
152 }
153 
set_debug_info(const NodeDebugInfoPtr & debug_info)154 void AnfNode::set_debug_info(const NodeDebugInfoPtr &debug_info) {
155   MS_EXCEPTION_IF_NULL(debug_info);
156   debug_info->set_type_name(type_name());
157   debug_info->set_debug_name(std::string());  // Clear cached debug name.
158   debug_info_ = debug_info;
159 }
160 
hash() const161 std::size_t AnfNode::hash() const { return PointerHash<AnfNode>{}(this); }
162 
fullname_with_scope()163 std::string AnfNode::fullname_with_scope() { return ""; }
164 
UniqueName()165 std::string AnfNode::UniqueName() { return fullname_with_scope() + "_" + UniqueId(); }
166 
DebugString(int recursive_level) const167 std::string AnfNode::DebugString(int recursive_level) const { return ToString(); }
168 
DebugString(bool recursive) const169 std::string AnfNode::DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); }
170 
dump() const171 void AnfNode::dump() const { std::cout << DebugString() << std::endl; }
172 
UniqueId()173 std::string AnfNode::UniqueId() { return std::to_string(debug_info()->unique_id()); }
174 
UniqueIdThroughCopy()175 std::string AnfNode::UniqueIdThroughCopy() { return std::to_string(debug_info()->unique_id_through_copy()); }
176 
operator ==(const AnfNode & other) const177 bool AnfNode::operator==(const AnfNode &other) const { return &other == this; }
178 
operator <<(std::ostream & os,const AnfNode & node)179 std::ostream &operator<<(std::ostream &os, const AnfNode &node) {
180   os << node.ToString();
181   return os;
182 }
183 
interpret() const184 bool AnfNode::interpret() const { return interpret_flags_[kInterpret]; }
185 
set_interpret(const bool & interpret)186 void AnfNode::set_interpret(const bool &interpret) { interpret_flags_[kInterpret] = interpret; }
187 
interpret_internal_type()188 bool AnfNode::interpret_internal_type() { return interpret_flags_[kInterpretInternalType]; }
189 
set_interpret_internal_type(const bool & interpret_internal_type)190 void AnfNode::set_interpret_internal_type(const bool &interpret_internal_type) {
191   interpret_flags_[kInterpretInternalType] = interpret_internal_type;
192 }
193 
abstract() const194 const AbstractBasePtr &AnfNode::abstract() const {
195   // cppcheck-suppress unreadVariable
196   auto lock = AnfUtils::GetAbstractLock(this);
197   return abstract_;
198 }
199 
set_abstract(const AbstractBasePtr & abs)200 void AnfNode::set_abstract(const AbstractBasePtr &abs) {
201   // cppcheck-suppress unreadVariable
202   auto lock = AnfUtils::GetAbstractLock(this);
203   abstract_ = abs;
204 }
205 
CheckCNodeWeakInput()206 void CNode::CheckCNodeWeakInput() {
207 #if defined(DEBUG_CNODE_WEAK_INPUT)
208   for (size_t i = 0; i < weak_inputs_.size(); ++i) {
209     if (weak_inputs_[i].lock() == nullptr) {
210       MS_LOG(INTERNAL_EXCEPTION) << "The " << i << "th input is released.";
211     }
212   }
213 #endif  // DEBUG_CNODE_WEAK_INPUT
214 }
215 
216 // Namespace to support intermediate representation definition
CNode(const AnfNodePtrList & inputs,const FuncGraphPtr & func_graph)217 CNode::CNode(const AnfNodePtrList &inputs, const FuncGraphPtr &func_graph)
218     : AnfNode(func_graph),
219       primal_attrs_(PrimalAttrManager::GetInstance().GetCurrentPrimalAttr()),
220       primal_debug_infos_(PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo()) {
221   weak_inputs_.reserve(inputs.size());
222   std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs_),
223                  [](const AnfNodePtr &node) -> AnfNodeWeakPtr { return AnfNodeWeakPtr(node); });
224   if (func_graph != nullptr) {
225     MS_LOG(DEBUG) << "Create new CNode, " << this << "@" << func_graph->ToString();
226     func_graph->AddOwnNode(inputs);
227   } else {
228     MS_LOG(WARNING) << "The func graph should not be null.";
229     inputs_bound_ = true;
230     inputs_ = inputs;
231   }
232   CheckCNodeInputsNum(weak_inputs_);
233   Init();
234 }
235 
CNode(const AnfNodePtrList & inputs,const VarPtr & func_graph_as_var)236 CNode::CNode(const AnfNodePtrList &inputs, const VarPtr &func_graph_as_var) : AnfNode(nullptr), inputs_(inputs) {
237   inputs_bound_ = true;
238   weak_inputs_.reserve(inputs.size());
239   std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs_),
240                  [](const AnfNodePtr &node) -> AnfNodeWeakPtr { return AnfNodeWeakPtr(node); });
241   primal_attrs_ = PrimalAttrManager::GetInstance().GetCurrentPrimalAttr();
242   primal_debug_infos_ = PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo();
243   set_user_data(kFuncGraphVarKey, func_graph_as_var);
244   Init();
245 }
246 
CNode(AnfNodeWeakPtrList && weak_inputs,const FuncGraphPtr & func_graph)247 CNode::CNode(AnfNodeWeakPtrList &&weak_inputs, const FuncGraphPtr &func_graph)
248     : AnfNode(func_graph),
249       weak_inputs_(std::move(weak_inputs)),
250       primal_attrs_(PrimalAttrManager::GetInstance().GetCurrentPrimalAttr()),
251       primal_debug_infos_(PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo()) {
252   if (func_graph != nullptr) {
253     MS_LOG(DEBUG) << "Create new CNode, " << this << "@" << func_graph->ToString();
254     func_graph->AddOwnNode(weak_inputs_);
255   } else {
256     MS_LOG(INTERNAL_EXCEPTION) << "The func graph should not be null.";
257   }
258   CheckCNodeInputsNum(weak_inputs_);
259   Init();
260 }
261 
CNode(AnfNodeWeakPtrList && weak_inputs,const FuncGraphPtr & func_graph,NodeDebugInfoPtr && debug_info)262 CNode::CNode(AnfNodeWeakPtrList &&weak_inputs, const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info)
263     : AnfNode(func_graph, std::move(debug_info)),
264       weak_inputs_(std::move(weak_inputs)),
265       primal_attrs_(PrimalAttrManager::GetInstance().GetCurrentPrimalAttr()),
266       primal_debug_infos_(PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo()) {
267   if (func_graph != nullptr) {
268     MS_LOG(DEBUG) << "Create new CNode, " << this << "@" << func_graph->ToString();
269     func_graph->AddOwnNode(weak_inputs_);
270   } else {
271     MS_LOG(INTERNAL_EXCEPTION) << "The func graph should not be null.";
272   }
273   CheckCNodeInputsNum(weak_inputs_);
274   Init();
275 }
276 
CNode(const AnfNodeWeakPtrList & weak_inputs,const FuncGraphPtr & func_graph)277 CNode::CNode(const AnfNodeWeakPtrList &weak_inputs, const FuncGraphPtr &func_graph)
278     : AnfNode(func_graph),
279       weak_inputs_(weak_inputs),
280       primal_attrs_(PrimalAttrManager::GetInstance().GetCurrentPrimalAttr()),
281       primal_debug_infos_(PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo()) {
282   if (func_graph != nullptr) {
283     MS_LOG(DEBUG) << "Create new CNode, " << this << "@" << func_graph->ToString();
284     func_graph->AddOwnNode(weak_inputs_);
285   } else {
286     MS_LOG(INTERNAL_EXCEPTION) << "The func graph should not be null.";
287   }
288   CheckCNodeInputsNum(weak_inputs_);
289   Init();
290 }
291 
Init()292 void CNode::Init() {
293   CheckCNodeWeakInput();
294   debug_info()->set_type_name(type_name());
295 }
296 
set_debug_info(const NodeDebugInfoPtr & debug_info)297 void CNode::set_debug_info(const NodeDebugInfoPtr &debug_info) {
298   MS_EXCEPTION_IF_NULL(debug_info);
299   debug_info->set_type_name(type_name());
300   debug_info->set_debug_name(std::string());  // Clear cached debug name.
301   debug_info_ = debug_info;
302 }
303 
size() const304 const size_t CNode::size() const { return weak_inputs_.size(); }
305 
empty() const306 const bool CNode::empty() const { return size() == 0; }
307 
inputs()308 const AnfNodePtrList &CNode::inputs() {
309   // Check for 'CNode(const AnfNodePtrList &, const VarPtr &)' for compatible.
310   if (inputs_bound_) {
311     return inputs_;
312   }
313 
314   inputs_.clear();
315   inputs_.reserve(weak_inputs_.size());
316   std::transform(weak_inputs_.cbegin(), weak_inputs_.cend(), std::back_inserter(inputs_),
317                  [this](const AnfNodeWeakPtr &weak_node) -> AnfNodePtr {
318                    auto node = weak_node.lock();
319 #if defined(DEBUG_CNODE_WEAK_INPUT)
320                    if (node == nullptr) {
321                      MS_LOG(INTERNAL_EXCEPTION) << "Weak input lock failed, " << DebugString();
322                    }
323 #endif  // DEBUG_CNODE_WEAK_INPUT
324                    return node;
325                  });
326   inputs_bound_ = true;
327   return inputs_;
328 }
329 
weak_inputs() const330 const AnfNodeWeakPtrList &CNode::weak_inputs() const { return weak_inputs_; }
331 
332 // Check if CNode is an apply with the specific Primitive.
IsApply(const PrimitivePtr & value) const333 bool CNode::IsApply(const PrimitivePtr &value) const {
334   if (value == nullptr || weak_inputs_.empty()) {
335     return false;
336   }
337   auto prim = GetValuePtr<Primitive>(weak_inputs_[0].lock());
338   return (prim != nullptr) && (prim->Hash() == value->Hash()) && (prim->name() == value->name());
339 }
340 
add_input(const AnfNodePtr & input)341 void CNode::add_input(const AnfNodePtr &input) {
342   if (inputs_bound_) {
343     inputs_.emplace_back(input);
344   }
345 
346   (void)weak_inputs_.emplace_back(AnfNodeWeakPtr(input));
347   if (func_graph() != nullptr) {
348     func_graph()->AddOwnNode(input);
349   }
350 
351   input_tensor_num_ = -1;
352 }
353 
set_input(size_t i,const AnfNodePtr & new_input)354 void CNode::set_input(size_t i, const AnfNodePtr &new_input) {
355   if (inputs_bound_) {
356     if (i >= inputs_.size()) {
357       MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << weak_inputs_.size()
358                                  << ", cnode: " << DebugString();
359     }
360     inputs_[i] = new_input;
361   }
362 
363   if (i >= weak_inputs_.size()) {
364     MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << weak_inputs_.size()
365                                << ", cnode: " << DebugString();
366   }
367   weak_inputs_[i] = AnfNodeWeakPtr(new_input);
368   if (func_graph() != nullptr) {
369     func_graph()->AddOwnNode(new_input);
370   }
371 
372   input_tensor_num_ = -1;
373 }
374 
set_inputs(const AnfNodePtrList & inputs)375 void CNode::set_inputs(const AnfNodePtrList &inputs) {
376   if (inputs_bound_) {
377     inputs_ = inputs;
378   }
379 
380   weak_inputs_.clear();
381   weak_inputs_.reserve(inputs.size());
382   std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs_),
383                  [](const AnfNodePtr &node) -> AnfNodeWeakPtr { return AnfNodeWeakPtr(node); });
384   if (func_graph() != nullptr) {
385     func_graph()->AddOwnNode(inputs);
386   }
387   input_tensor_num_ = -1;
388   CheckCNodeInputsNum(weak_inputs_);
389 }
390 
set_weak_inputs(const AnfNodeWeakPtrList & weak_inputs)391 void CNode::set_weak_inputs(const AnfNodeWeakPtrList &weak_inputs) {
392   if (inputs_bound_) {
393     inputs_.clear();
394     inputs_.reserve(weak_inputs.size());
395     std::transform(weak_inputs.cbegin(), weak_inputs.cend(), std::back_inserter(inputs_),
396                    [this](const AnfNodeWeakPtr &weak_node) -> AnfNodePtr {
397                      auto node = weak_node.lock();
398                      if (node == nullptr) {
399                        MS_LOG(INTERNAL_EXCEPTION) << "Weak input lock failed, " << DebugString();
400                      }
401                      return node;
402                    });
403   }
404 
405   weak_inputs_ = weak_inputs;
406   if (func_graph() != nullptr) {
407     func_graph()->AddOwnNode(weak_inputs_);
408   }
409   input_tensor_num_ = -1;
410   CheckCNodeInputsNum(weak_inputs_);
411 }
412 
input(size_t i) const413 const AnfNodePtr CNode::input(size_t i) const {
414   if (inputs_bound_) {
415     if (i >= inputs_.size()) {
416       MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << inputs_.size() << ", cnode: " << DebugString();
417     }
418     return inputs_.at(i);
419   }
420 
421   if (i >= weak_inputs_.size()) {
422     MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << weak_inputs_.size()
423                                << ", cnode: " << DebugString();
424   }
425   auto res = weak_inputs_.at(i).lock();
426   if (res == nullptr) {
427     MS_LOG(ERROR) << "The input[" << i << "] is released, cnode: " << DebugString();
428   }
429   return res;
430 }
431 
weak_input(size_t i) const432 const AnfNodeWeakPtr &CNode::weak_input(size_t i) const {
433   if (i >= weak_inputs_.size()) {
434     MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << weak_inputs_.size()
435                                << ", cnode: " << DebugString();
436   }
437   return weak_inputs_.at(i);
438 }
439 
DebugString(int recursive_level) const440 std::string CNode::DebugString(int recursive_level) const {
441   std::ostringstream buffer;
442   if (recursive_level > 0) {
443     if (func_graph() != nullptr) {
444       buffer << "@" << func_graph()->ToString() << ":";
445     }
446     buffer << ToString() << "{";
447     bool is_first_node = true;
448     int idx = 0;
449     for (const auto &weak_node : weak_inputs_) {
450       if (is_first_node) {
451         is_first_node = false;
452       } else {
453         buffer << ", ";
454       }
455       const auto &node = weak_node.lock();
456       if (node == nullptr) {
457         buffer << "[" << idx << "]: (released_node)";
458         ++idx;
459         continue;
460       }
461       buffer << "[" << idx << "]: " << node->DebugString(recursive_level - 1);
462       ++idx;
463     }
464     buffer << "}";
465   } else {
466     buffer << ToString();
467   }
468   return buffer.str();
469 }
470 
AddFusedDebugInfo(const AnfNodePtr & node)471 void CNode::AddFusedDebugInfo(const AnfNodePtr &node) {
472   if (node == nullptr || !node->isa<CNode>()) {
473     return;
474   }
475   if (shared_from_this() == node) {
476     this->AddFusedDebugInfo(node->debug_info());
477     return;
478   }
479   auto cnode = node->cast_ptr<CNode>();
480   auto node_fused_debug_infos = cnode->fused_debug_infos();
481   if (!node_fused_debug_infos.empty()) {
482     (void)std::for_each(node_fused_debug_infos.begin(), node_fused_debug_infos.end(),
483                         [this](const NodeDebugInfoPtr &debug_info) { this->AddFusedDebugInfo(debug_info); });
484   } else {
485     this->AddFusedDebugInfo(cnode->debug_info());
486   }
487 
488   auto primal_debug_infos = cnode->primal_debug_infos();
489   if (!primal_debug_infos.empty()) {
490     (void)std::for_each(primal_debug_infos.begin(), primal_debug_infos.end(),
491                         [this](const NodeDebugInfoPtr &debug_info) { this->AddPrimalDebugInfo(debug_info); });
492   }
493 }
494 
AddFusedDebugInfoList(const AnfNodePtrList & nodes)495 void CNode::AddFusedDebugInfoList(const AnfNodePtrList &nodes) {
496   (void)std::for_each(nodes.begin(), nodes.end(), [this](const AnfNodePtr &node) { this->AddFusedDebugInfo(node); });
497 }
498 
AddFusedDebugInfo(const NodeDebugInfoPtr & debug_info)499 void CNode::AddFusedDebugInfo(const NodeDebugInfoPtr &debug_info) {
500   if (debug_info == nullptr) {
501     return;
502   }
503   (void)fused_debug_infos_.emplace(debug_info);
504 }
505 
AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> & debug_infos)506 void CNode::AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos) {
507   (void)std::for_each(debug_infos.begin(), debug_infos.end(),
508                       [this](const NodeDebugInfoPtr &debug_info) { this->AddFusedDebugInfo(debug_info); });
509 }
510 
primal_debug_infos() const511 NodeDebugInfoSet CNode::primal_debug_infos() const { return primal_debug_infos_; }
512 
set_primal_debug_infos(const NodeDebugInfoSet & debug_infos)513 void CNode::set_primal_debug_infos(const NodeDebugInfoSet &debug_infos) {
514   (void)std::for_each(debug_infos.begin(), debug_infos.end(),
515                       [this](const NodeDebugInfoPtr &debug_info) { this->AddPrimalDebugInfo(debug_info); });
516 }
517 
AddPrimalDebugInfo(const NodeDebugInfoPtr & debug_info)518 void CNode::AddPrimalDebugInfo(const NodeDebugInfoPtr &debug_info) { (void)primal_debug_infos_.emplace(debug_info); }
519 
set_forward(const ValueNodePtr & forward,const std::string & id)520 void CNode::set_forward(const ValueNodePtr &forward, const std::string &id) {
521   set_user_data(kOutputValueKey, std::make_shared<OutputValue>(forward, id));
522 }
523 
forward() const524 const CNode::OutputValue &CNode::forward() const {
525   static const CNode::OutputValue empty_value;
526   auto ptr = user_data<CNode::OutputValue>(kOutputValueKey);
527   if (ptr == nullptr) {
528     return empty_value;
529   }
530   return *ptr;
531 }
532 
stop_gradient() const533 bool CNode::stop_gradient() const { return flags_[kStopGradient]; }
534 
set_stop_gradient(bool stop_gradient)535 void CNode::set_stop_gradient(bool stop_gradient) { flags_[kStopGradient] = stop_gradient; }
536 
set_fullname_with_scope(const std::string full_name)537 void CNode::set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; }
538 
DebugString(bool recursive) const539 std::string CNode::DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); }
540 
set_in_forward_flag(bool flag)541 void CNode::set_in_forward_flag(bool flag) { flags_[kInForwardFlag] = flag; }
542 
in_forward_flag() const543 bool CNode::in_forward_flag() const { return flags_[kInForwardFlag]; }
544 
set_load_flag(bool is_load)545 void CNode::set_load_flag(bool is_load) { flags_[kIsLoad] = is_load; }
546 
get_load_flag() const547 bool CNode::get_load_flag() const { return flags_[kIsLoad]; }
548 
func_graph_as_var() const549 VarPtr CNode::func_graph_as_var() const { return user_data<Var>(kFuncGraphVarKey); }
550 
attrs() const551 const mindspore::HashMap<std::string, ValuePtr> &CNode::attrs() const { return attrs_; }
552 
set_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)553 void CNode::set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
554   attrs_.insert(attrs.cbegin(), attrs.cend());
555 }
556 
AddAttr(const std::string & name,const ValuePtr & attr)557 void CNode::AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
558 
EraseAttr(const std::string & name)559 void CNode::EraseAttr(const std::string &name) { (void)attrs_.erase(name); }
560 
GetAttr(const std::string & name) const561 ValuePtr CNode::GetAttr(const std::string &name) const {
562   auto iter = attrs_.find(name);
563   return iter == attrs_.cend() ? nullptr : iter->second;
564 }
565 
HasAttr(const std::string & name) const566 bool CNode::HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); }
567 
input_tensor_num() const568 ssize_t CNode::input_tensor_num() const { return input_tensor_num_; }
569 
primal_attrs() const570 const mindspore::HashMap<std::string, ValuePtr> &CNode::primal_attrs() const { return primal_attrs_; }
571 
set_primal_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)572 void CNode::set_primal_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
573   primal_attrs_.insert(attrs.cbegin(), attrs.cend());
574 }
575 
AddPrimalAttr(const std::string & name,const ValuePtr & attr)576 void CNode::AddPrimalAttr(const std::string &name, const ValuePtr &attr) { primal_attrs_[name] = attr; }
577 
ErasePrimalAttr(const std::string & name)578 void CNode::ErasePrimalAttr(const std::string &name) { (void)primal_attrs_.erase(name); }
579 
GetPrimalAttr(const std::string & name) const580 ValuePtr CNode::GetPrimalAttr(const std::string &name) const {
581   auto iter = primal_attrs_.find(name);
582   return iter == primal_attrs_.cend() ? nullptr : iter->second;
583 }
584 
HasPrimalAttr(const std::string & name) const585 bool CNode::HasPrimalAttr(const std::string &name) const { return primal_attrs_.find(name) != primal_attrs_.end(); }
586 
CloneCNodeInfo(const CNodePtr & node)587 void CNode::CloneCNodeInfo(const CNodePtr &node) {
588   MS_EXCEPTION_IF_NULL(node);
589   set_abstract(node->abstract());
590   set_forward(node->forward().first, node->forward().second);
591   set_attrs(node->attrs());
592   set_primal_attrs(node->primal_attrs());
593   set_load_flag(node->get_load_flag());
594   CloneUserData(node);
595   set_kernel_info(node->kernel_info_ptr());
596   set_primal_debug_infos(node->primal_debug_infos());
597   set_fused_debug_infos(node->fused_debug_infos());
598 }
599 
set_input_tensor_num(ssize_t input_tensor_num)600 void CNode::set_input_tensor_num(ssize_t input_tensor_num) { input_tensor_num_ = input_tensor_num; }
601 
IsEffectHandled() const602 bool CNode::IsEffectHandled() const { return flags_[kEffectHandled]; }
603 
SetEffectHandled(bool handled)604 void CNode::SetEffectHandled(bool handled) { flags_[kEffectHandled] = handled; }
605 
fused_debug_infos() const606 NodeDebugInfoSet CNode::fused_debug_infos() const { return fused_debug_infos_; }
607 
set_fused_debug_infos(const NodeDebugInfoSet & fused_debug_infos)608 void CNode::set_fused_debug_infos(const NodeDebugInfoSet &fused_debug_infos) { fused_debug_infos_ = fused_debug_infos; }
609 
has_side_effect_node() const610 bool CNode::has_side_effect_node() const { return has_side_effect_node_; }
611 
set_has_side_effect_node(bool has_side_effect_node)612 void CNode::set_has_side_effect_node(bool has_side_effect_node) { has_side_effect_node_ = has_side_effect_node; }
613 
Parameter(const FuncGraphPtr & func_graph)614 Parameter::Parameter(const FuncGraphPtr &func_graph) : ANode(func_graph) { Init(); }
615 
Parameter(const FuncGraphPtr & func_graph,NodeDebugInfoPtr && debug_info)616 Parameter::Parameter(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info)
617     : ANode(func_graph, std::move(debug_info)) {
618   Init();
619 }
620 
Init()621 void Parameter::Init() { debug_info()->set_type_name(type_name()); }
622 
set_debug_info(const NodeDebugInfoPtr & debug_info)623 void Parameter::set_debug_info(const NodeDebugInfoPtr &debug_info) {
624   MS_EXCEPTION_IF_NULL(debug_info);
625   debug_info->set_type_name(type_name());
626   debug_info->set_debug_name(std::string());  // Clear cached debug name.
627   debug_info_ = debug_info;
628 }
629 
name() const630 std::string Parameter::name() const { return name_; }
631 
set_name(const std::string & name)632 void Parameter::set_name(const std::string &name) { name_ = name; }
633 
fullname_with_scope()634 std::string Parameter::fullname_with_scope() { return name(); }
635 
has_default() const636 bool Parameter::has_default() const { return has_default_; }
637 
set_default_param(const ValuePtr & param)638 void Parameter::set_default_param(const ValuePtr &param) {
639   default_param_ = param;
640   has_default_ = true;
641 }
642 
default_param() const643 const ValuePtr &Parameter::default_param() const { return default_param_; }
644 
IncreaseUsedGraphCount()645 void Parameter::IncreaseUsedGraphCount() { used_graph_count_++; }
646 
DecreaseUsedGraphCount()647 void Parameter::DecreaseUsedGraphCount() { used_graph_count_--; }
648 
used_graph_count() const649 int Parameter::used_graph_count() const { return used_graph_count_; }
650 
is_top_graph_param() const651 bool Parameter::is_top_graph_param() const { return is_top_graph_param_; }
652 
set_is_top_graph_param(bool flag)653 void Parameter::set_is_top_graph_param(bool flag) { is_top_graph_param_ = flag; }
654 
operator ==(const AnfNode & other) const655 bool Parameter::operator==(const AnfNode &other) const {
656   if (!other.isa<Parameter>()) {
657     return false;
658   }
659   auto &p = static_cast<const Parameter &>(other);
660   if (name_.length() > 0 && p.name_.length() > 0) {
661     return p.name_ == name_;
662   }
663   return shared_from_this() == other.shared_from_this();
664 }
665 
SetNotUsedByRealKernelInGraph(uint32_t graph_id)666 void Parameter::SetNotUsedByRealKernelInGraph(uint32_t graph_id) { (void)not_used_in_graphs_.insert(graph_id); }
667 
IsUsedByRealKernelInGraph(uint32_t graph_id) const668 bool Parameter::IsUsedByRealKernelInGraph(uint32_t graph_id) const {
669   return not_used_in_graphs_.find(graph_id) == not_used_in_graphs_.end();
670 }
671 
set_has_dynamic_shape(bool flag)672 void Parameter::set_has_dynamic_shape(bool flag) { has_dynamic_shape_ = flag; }
673 
has_dynamic_shape() const674 bool Parameter::has_dynamic_shape() const { return has_dynamic_shape_; }
675 
set_dynamic_len(bool flag)676 void Parameter::set_dynamic_len(bool flag) { is_dynamic_len_ = flag; }
677 
dynamic_len() const678 bool Parameter::dynamic_len() const { return is_dynamic_len_; }
679 
set_fracz_group(int64_t fracz_group)680 void Parameter::set_fracz_group(int64_t fracz_group) { format_attrs_.fracz_group = fracz_group; }
681 
fracz_group() const682 int64_t Parameter::fracz_group() const { return format_attrs_.fracz_group; }
683 
set_input_size(int64_t input_size)684 void Parameter::set_input_size(int64_t input_size) { format_attrs_.input_size = input_size; }
685 
input_size() const686 int64_t Parameter::input_size() const { return format_attrs_.input_size; }
687 
set_hidden_size(int64_t hidden_size)688 void Parameter::set_hidden_size(int64_t hidden_size) { format_attrs_.hidden_size = hidden_size; }
689 
hidden_size() const690 int64_t Parameter::hidden_size() const { return format_attrs_.hidden_size; }
691 
DebugString(int recursive_level) const692 std::string Parameter::DebugString(int recursive_level) const {
693   std::ostringstream buffer;
694   if (recursive_level > 0) {
695     if (func_graph() != nullptr) {
696       buffer << "@" << func_graph()->ToString() << ":";
697     }
698   }
699   buffer << "param_" << ToString();
700   return buffer.str();
701 }
702 
param_info() const703 ParamInfoPtr Parameter::param_info() const {
704   if (!has_default()) {
705     return nullptr;
706   }
707   auto tensor = default_param()->cast_ptr<tensor::MetaTensor>();
708   if (tensor == nullptr || !tensor->is_parameter()) {
709     return nullptr;
710   }
711   return tensor->param_info();
712 }
713 
Value(const TypePtr t)714 Value::Value(const TypePtr t) : type_(t) {}
715 
Value(const Value & other)716 Value::Value(const Value &other) : Base(other) { this->type_ = other.type_; }
717 
type() const718 TypePtr Value::type() const { return type_; }
719 
ToAbstract()720 abstract::AbstractBasePtr Value::ToAbstract() {
721   MS_LOG(INTERNAL_EXCEPTION) << "ToAbstract error : The class " << type_name() << "has no implement ToAbstract yet.";
722 }
723 
// Copy-assigns the type tag only; self-assignment leaves the object untouched.
Value &Value::operator=(const Value &other) {
  if (this != &other) {
    type_ = other.type_;
  }
  return *this;
}
731 
ContainsValueAny() const732 bool Value::ContainsValueAny() const { return false; }
733 
ValueNode(const ValuePtr & value)734 ValueNode::ValueNode(const ValuePtr &value) : value_(value) {
735   if (value->ContainsValueAny()) {
736     MS_LOG(EXCEPTION) << "Value of value node cannot be ValueAny. Value: " << value->ToString();
737   }
738 #ifdef DEBUG
739   // Check if the func graph is release too early.
740   const auto &fg = dyn_cast<FuncGraph>(value);
741   if (fg != nullptr && fg->get_return() == nullptr) {
742     MS_LOG(WARNING) << fg->ToString() << "(ptr: " << fg << ") maybe released early, or not set return yet."
743   }
744 #endif
745   Init();
746 }
747 
ValueNode(const ValuePtr & value,NodeDebugInfoPtr && debug_info)748 ValueNode::ValueNode(const ValuePtr &value, NodeDebugInfoPtr &&debug_info)
749     : ANode(nullptr, std::move(debug_info)), value_(value) {
750   MS_EXCEPTION_IF_NULL(value);
751   if (value->ContainsValueAny()) {
752     MS_LOG(EXCEPTION) << "Value of value node cannot be ValueAny. Value: " << value->ToString();
753   }
754 #ifdef DEBUG
755   // Check if the func graph is release too early.
756   const auto &fg = dyn_cast<FuncGraph>(value);
757   if (fg != nullptr && fg->get_return() == nullptr) {
758     MS_LOG(WARNING) << fg->ToString() << "(ptr: " << fg << ") maybe released early, or not set return yet."
759   }
760 #endif
761   Init();
762 }
763 
Init()764 void ValueNode::Init() { debug_info()->set_type_name(type_name()); }
765 
set_debug_info(const NodeDebugInfoPtr & debug_info)766 void ValueNode::set_debug_info(const NodeDebugInfoPtr &debug_info) {
767   MS_EXCEPTION_IF_NULL(debug_info);
768   debug_info->set_type_name(type_name());
769   debug_info->set_debug_name(std::string());  // Clear cached debug name.
770   debug_info_ = debug_info;
771 }
772 
set_func_graph(const FuncGraphPtr &)773 void ValueNode::set_func_graph(const FuncGraphPtr &) {
774   MS_INTERNAL_EXCEPTION(ValueError) << "ValueNode should not set its func_graph.";
775 }
776 
set_value(const ValuePtr & value)777 void ValueNode::set_value(const ValuePtr &value) {
778   MS_EXCEPTION_IF_NULL(value);
779   if (value->ContainsValueAny()) {
780     MS_LOG(INTERNAL_EXCEPTION) << "Value of value node cannot be ValueAny. Value: " << value->ToString();
781   }
782   value_ = value;
783 }
784 
value() const785 const ValuePtr &ValueNode::value() const { return value_; }
786 
set_has_new_value(bool flag)787 void ValueNode::set_has_new_value(bool flag) { has_new_value_ = flag; }
788 
has_new_value() const789 bool ValueNode::has_new_value() const { return has_new_value_; }
790 
used_graph_count() const791 size_t ValueNode::used_graph_count() const { return used_graph_count_; }
792 
set_fracz_group(int64_t group)793 void ValueNode::set_fracz_group(int64_t group) { format_attr_.fracz_group = group; }
794 
fracz_group() const795 int64_t ValueNode::fracz_group() const { return format_attr_.fracz_group; }
796 
set_used_graph_count(size_t used_graph_count)797 void ValueNode::set_used_graph_count(size_t used_graph_count) { used_graph_count_ = used_graph_count; }
798 
DebugString(bool recursive) const799 std::string ValueNode::DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); }
800 
operator ==(const AnfNode & other) const801 bool ValueNode::operator==(const AnfNode &other) const {
802   if (!other.isa<ValueNode>()) {
803     return false;
804   }
805   auto &v = static_cast<const ValueNode &>(other);
806   return *v.value() == *value();
807 }
808 
ToString() const809 std::string ValueNode::ToString() const {
810   MS_EXCEPTION_IF_NULL(value_);
811   if (value_->isa<FuncGraph>()) {
812     return value_->ToString();
813   }
814   std::ostringstream buffer;
815   buffer << AnfNode::ToString();
816   buffer << "(" << value_->ToString() << ")";
817   return buffer.str();
818 }
819 
DebugString(int) const820 std::string ValueNode::DebugString(int) const {
821   MS_EXCEPTION_IF_NULL(value_);
822   std::ostringstream buffer;
823   buffer << "ValueNode<" << value_->type_name() << "> " << value_->ToString();
824   return buffer.str();
825 }
826 
fullname_with_scope()827 std::string ValueNode::fullname_with_scope() {
828   if (!fullname_with_scope_.empty()) {
829     return fullname_with_scope_;
830   }
831 
832   MS_EXCEPTION_IF_NULL(scope());
833   fullname_with_scope_ = scope()->name() + "/" + "data-";
834   fullname_with_scope_ += id_generator::get_id(fullname_with_scope_);
835   return fullname_with_scope_;
836 }
837 
IsPrimitiveCNode(const AnfNodePtr & node,const PrimitivePtr & value)838 bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
839   auto cnode = dyn_cast_ptr<CNode>(node);
840   if (cnode == nullptr || cnode->size() == 0) {
841     return false;
842   }
843   const auto &input = cnode->input(0);
844   MS_EXCEPTION_IF_NULL(input);
845   auto prim = GetValuePtr<Primitive>(input);
846   if (prim == nullptr) {
847     return false;
848   }
849   return (value == nullptr) || ((prim->Hash() == value->Hash()) && (prim->name() == value->name()));
850 }
851 
IsPrimitiveCNodeWithoutDoSignature(const AnfNodePtr & node,const PrimitivePtr & value)852 bool IsPrimitiveCNodeWithoutDoSignature(const AnfNodePtr &node, const PrimitivePtr &value) {
853   auto prim = GetCNodePrimitiveWithoutDoSignature(node);
854   if (prim == nullptr) {
855     return false;
856   }
857   return (value == nullptr) || ((prim->Hash() == value->Hash()) && (prim->name() == value->name()));
858 }
859 
GetCNodePrimitive(const AnfNodePtr & node)860 PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
861   auto cnode = dyn_cast_ptr<CNode>(node);
862   if (cnode == nullptr || cnode->size() == 0) {
863     return nullptr;
864   }
865   return GetValueNode<PrimitivePtr>(cnode->input(0));
866 }
867 
868 // Return the function Primitive if DoSignaturePrimitive,
869 // otherwise return the Primitive directly.
GetPrimitiveWithoutDoSignature(const AnfNodePtr & node)870 PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node) {
871   auto do_signature_prim = GetValuePtr<prim::DoSignaturePrimitive>(node);
872   if (do_signature_prim != nullptr) {
873     return dyn_cast<Primitive>(do_signature_prim->function());
874   }
875   return GetValueNode<PrimitivePtr>(node);
876 }
877 
878 // Check the first input of CNode.
879 // Return the function Primitive if DoSignaturePrimitive,
880 // otherwise return the Primitive directly.
GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr & node)881 PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node) {
882   auto cnode = dyn_cast_ptr<CNode>(node);
883   if (cnode == nullptr || cnode->size() == 0) {
884     return nullptr;
885   }
886   return GetPrimitiveWithoutDoSignature(cnode->input(0));
887 }
888 
889 // Return the function value if DoSignaturePrimitive,
890 // otherwise return the value directly.
GetValueWithoutDoSignature(const ValuePtr & value)891 ValuePtr GetValueWithoutDoSignature(const ValuePtr &value) {
892   auto do_signature_prim = dyn_cast_ptr<prim::DoSignaturePrimitive>(value);
893   if (do_signature_prim != nullptr) {
894     return do_signature_prim->function();
895   }
896   return value;
897 }
898 
899 // Return the function value if DoSignaturePrimitive,
900 // otherwise return the value directly.
GetValueWithoutDoSignature(const AnfNodePtr & node)901 ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node) {
902   auto value = GetValueNode(node);
903   if (value == nullptr) {
904     return nullptr;
905   }
906   return GetValueWithoutDoSignature(value);
907 }
908 
909 // Check the first input of CNode.
910 // Return the function value if DoSignaturePrimitive,
911 // otherwise return the value directly.
GetCNodeValueWithoutDoSignature(const AnfNodePtr & node)912 ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node) {
913   auto cnode = dyn_cast_ptr<CNode>(node);
914   if (cnode == nullptr || cnode->size() == 0) {
915     return nullptr;
916   }
917   return GetValueWithoutDoSignature(cnode->input(0));
918 }
919 
GetCNodeFuncName(const CNodePtr & cnode)920 std::string GetCNodeFuncName(const CNodePtr &cnode) {
921   MS_EXCEPTION_IF_NULL(cnode);
922   if (cnode->inputs().empty()) {
923     return "";
924   }
925 
926   AnfNodePtr valuenode = cnode->input(0);
927   auto value = GetValuePtr(valuenode);
928   if (value == nullptr) {
929     return "";
930   }
931   auto prim = value->cast_ptr<Primitive>();
932   if (prim != nullptr) {
933     return prim->name();
934   }
935   return value->ToString();
936 }
937 
GetCNodeFuncGraph(const AnfNodePtr & node)938 FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node) {
939   auto cnode = dyn_cast_ptr<CNode>(node);
940   if (cnode != nullptr && cnode->size() > 0) {
941     return GetValueNode<FuncGraphPtr>(cnode->input(0));
942   }
943   return nullptr;
944 }
945 
IsPrimitive(const AnfNodePtr & node,const PrimitivePtr & value)946 bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
947   if (IsValueNode<Primitive>(node)) {
948     auto prim = GetValuePtr<Primitive>(node);
949     MS_EXCEPTION_IF_NULL(value);
950     if (prim->Hash() == value->Hash() && prim->name() == value->name()) {
951       return true;
952     }
953   }
954   return false;
955 }
956 
IsPrimitiveEquals(const PrimitivePtr & prim1,const PrimitivePtr & prim2)957 bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2) {
958   if (prim1 == nullptr || prim2 == nullptr) {
959     return false;
960   }
961   return (prim1 == prim2) || (prim1->Hash() == prim2->Hash() && prim1->name() == prim2->name());
962 }
963 
GetAbstractMonadNum(const AbstractBasePtrList & args)964 size_t GetAbstractMonadNum(const AbstractBasePtrList &args) {
965   size_t num = 0;
966   for (auto &arg : args) {
967     if (arg->isa<abstract::AbstractMonad>()) {
968       ++num;
969     }
970   }
971   return num;
972 }
973 
974 template <typename T>
HasAbstract(const AnfNodePtr & node)975 bool HasAbstract(const AnfNodePtr &node) {
976   if (node == nullptr) {
977     return false;
978   }
979   const auto &abs = node->abstract();
980   return (abs != nullptr && abs->isa<T>());
981 }
982 
HasAbstractMonad(const AnfNodePtr & node)983 bool HasAbstractMonad(const AnfNodePtr &node) { return HasAbstract<abstract::AbstractMonad>(node); }
984 
HasAbstractUMonad(const AnfNodePtr & node)985 bool HasAbstractUMonad(const AnfNodePtr &node) { return HasAbstract<abstract::AbstractUMonad>(node); }
986 
HasAbstractIOMonad(const AnfNodePtr & node)987 bool HasAbstractIOMonad(const AnfNodePtr &node) { return HasAbstract<abstract::AbstractIOMonad>(node); }
988 
GetPrimitiveFlag(const PrimitivePtr & prim,const std::string & attr)989 bool GetPrimitiveFlag(const PrimitivePtr &prim, const std::string &attr) {
990   if (prim != nullptr) {
991     auto flag = prim->GetAttr(attr);
992     if (flag && flag->isa<BoolImm>()) {
993       return GetValue<bool>(flag);
994     }
995   }
996   return false;
997 }
998 
GetPrimEffectInfo(const PrimitivePtr & prim)999 EffectInfo GetPrimEffectInfo(const PrimitivePtr &prim) {
1000   bool mem = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_MEM);
1001   bool io = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_IO);
1002   bool back_mem = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP_MEM);
1003   return {EffectInfo::kDetected, mem, io, false, back_mem};
1004 }
1005 
GetLoadInputs(const AnfNodePtr & node)1006 std::set<CNodePtr> GetLoadInputs(const AnfNodePtr &node) {
1007   std::set<CNodePtr> loads;
1008   auto cnode = dyn_cast_ptr<CNode>(node);
1009   if (cnode == nullptr) {
1010     return loads;
1011   }
1012   auto &inputs = cnode->weak_inputs();
1013   for (size_t i = 1; i < inputs.size(); ++i) {
1014     const auto &input = inputs.at(i).lock();
1015     if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
1016       loads.insert(input->cast<CNodePtr>());
1017     } else if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
1018       loads.merge(GetLoadInputs(input));
1019     }
1020   }
1021   return loads;
1022 }
1023 
IsStateEquivalent(const AnfNodePtr & outer,const AnfNodePtr & inner)1024 bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) {
1025   constexpr size_t kMonadInput = 2;
1026   auto outer_loads = GetLoadInputs(outer);
1027   if (outer_loads.empty()) {
1028     return true;
1029   }
1030   auto inner_loads = GetLoadInputs(inner);
1031   if (inner_loads.empty()) {
1032     return true;
1033   }
1034   outer_loads.merge(inner_loads);
1035   const auto &monad = (*outer_loads.begin())->weak_inputs().at(kMonadInput).lock();
1036   return std::all_of(++outer_loads.begin(), outer_loads.end(), [&monad, kMonadInput](const CNodePtr &load) {
1037     return load->weak_inputs().at(kMonadInput).lock() == monad;
1038   });
1039 }
1040 
1041 // Check if the node is DeadNode.
IsDeadNode(const AnfNodePtr & node)1042 bool IsDeadNode(const AnfNodePtr &node) {
1043   auto value = GetValuePtr<ValueProblem>(node);
1044   return (value != nullptr) && (value->IsDead());
1045 }
1046 
1047 // Check if the node is PolyNode.
IsPolyNode(const AnfNodePtr & node)1048 bool IsPolyNode(const AnfNodePtr &node) {
1049   auto value = GetValuePtr<ValueProblem>(node);
1050   return (value != nullptr) && (value->IsPoly());
1051 }
1052 
NewSeenGeneration()1053 SeenNum NewSeenGeneration() {
1054   static SeenNum seen_generation = 0;
1055   ++seen_generation;
1056   // 0 is invalid number.
1057   if (seen_generation == 0) {
1058     ++seen_generation;
1059   }
1060   return seen_generation;
1061 }
1062 
1063 namespace id_generator {
1064 static mindspore::HashMap<std::string, int> node_ids;
1065 static int offset = 0;
get_id(const string & scope_front_info)1066 std::string get_id(const string &scope_front_info) {
1067   if (node_ids.find(scope_front_info) == node_ids.end()) {
1068     node_ids[scope_front_info] = 0;
1069   } else {
1070     node_ids[scope_front_info]++;
1071   }
1072   std::string base_id = std::to_string(node_ids[scope_front_info]);
1073   // The id with offset means the user called reset_id_with_offset() and expect the operated id generated from 0 with an
1074   // identified offset.
1075   if (offset != 0) {
1076     return base_id + '_' + std::to_string(offset);
1077   }
1078   return base_id;
1079 }
1080 
reset_id()1081 void reset_id() { node_ids.clear(); }
1082 
reset_id_with_offset()1083 void reset_id_with_offset() {
1084   node_ids.clear();
1085   offset++;
1086 }
1087 }  // namespace id_generator
// Attribute / user-data key under which a node's device target is stored.
constexpr auto kPrimitiveTarget = "primitive_target";
1089 namespace {
GetPrimitiveFromValueNode(const AnfNodePtr & node)1090 PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
1091   auto value_node = dyn_cast_ptr<ValueNode>(node);
1092   if (value_node == nullptr) {
1093     return nullptr;
1094   }
1095   return dyn_cast<Primitive>(value_node->value());
1096 }
1097 
GetNodeTargetForVarInputNode(const CNodePtr & cnode)1098 static std::string GetNodeTargetForVarInputNode(const CNodePtr &cnode) {
1099   auto &inputs = cnode->inputs();
1100   AnfNodeWeakPtrList real_inputs;
1101   const size_t update_state_valid_input_index = 2;
1102   const size_t make_tuple_valid_input_index = 1;
1103   if (cnode->IsApply(prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) {
1104     (void)std::copy(inputs.begin() + SizeToLong(update_state_valid_input_index), inputs.end(),
1105                     std::back_inserter(real_inputs));
1106   } else if (cnode->IsApply(prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) {
1107     (void)std::copy(inputs.begin() + SizeToLong(make_tuple_valid_input_index), inputs.end(),
1108                     std::back_inserter(real_inputs));
1109   }
1110   std::string first_input_target = kDeviceUnDefined;
1111   bool has_diff_target =
1112     std::any_of(std::rbegin(real_inputs), std::rend(real_inputs), [&first_input_target](const AnfNodeWeakPtr &n) {
1113       auto target = GetOriginNodeTarget(n.lock());
1114       if (target == kDeviceUnDefined) {
1115         return false;
1116       }
1117       if (first_input_target == kDeviceUnDefined) {
1118         first_input_target = target;
1119       }
1120       return target != first_input_target;
1121     });
1122   if (!has_diff_target) {
1123     return first_input_target;
1124   }
1125   return kDeviceUnDefined;
1126 }
1127 
IsSummaryPrimitiveCNode(const AnfNodePtr & node)1128 static inline bool IsSummaryPrimitiveCNode(const AnfNodePtr &node) {
1129   return IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
1130          IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary);
1131 }
1132 
GetVirtualNodeTargetFromInputs(const AnfNodePtr & node)1133 std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
1134   MS_EXCEPTION_IF_NULL(node);
1135   auto cnode = node->cast_ptr<CNode>();
1136   MS_EXCEPTION_IF_NULL(cnode);
1137   auto &weak_inputs = cnode->weak_inputs();
1138 #ifndef ENABLE_SECURITY
1139   if (IsSummaryPrimitiveCNode(node)) {
1140     if (weak_inputs.size() > 1) {
1141       return GetOriginNodeTarget(weak_inputs[1].lock());
1142     }
1143     return kDeviceUnDefined;
1144   }
1145 #endif
1146   if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
1147     const size_t node_inputs_num = 3;
1148     if (weak_inputs.size() >= node_inputs_num) {
1149       size_t use_index = 1;
1150       auto use_node = weak_inputs[use_index].lock();
1151       MS_EXCEPTION_IF_NULL(use_node);
1152       if (!use_node->isa<CNode>()) {
1153         use_index = 2;
1154         use_node = weak_inputs[use_index].lock();
1155         MS_EXCEPTION_IF_NULL(use_node);
1156       }
1157       return GetOriginNodeTarget(use_node);
1158     }
1159   } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
1160     return GetNodeTargetForVarInputNode(node->cast<CNodePtr>());
1161   } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1162     return GetOriginNodeTarget(cnode->input(1));
1163   }
1164   return kDeviceUnDefined;
1165 }
1166 
GetVirtualNodeTargetFromUsers(const AnfNodePtr & node)1167 std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) {
1168   MS_EXCEPTION_IF_NULL(node);
1169   auto cnode = node->cast<CNodePtr>();
1170   MS_EXCEPTION_IF_NULL(cnode);
1171   auto func_graph = cnode->func_graph();
1172   if (func_graph == nullptr) {
1173     return kDeviceUnDefined;
1174   }
1175   auto manager = func_graph->manager();
1176   if (manager == nullptr) {
1177     return kDeviceUnDefined;
1178   }
1179   auto users = manager->node_users()[cnode];
1180   std::string first_user_target = kDeviceUnDefined;
1181   bool has_diff_target =
1182     std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
1183       auto target = GetOriginNodeTarget(u.first);
1184       if (target == kDeviceUnDefined) {
1185         return false;
1186       }
1187       if (first_user_target == kDeviceUnDefined) {
1188         first_user_target = target;
1189       }
1190       return target != first_user_target;
1191     });
1192   if (!has_diff_target) {
1193     return first_user_target;
1194   }
1195   return kDeviceUnDefined;
1196 }
1197 
GetVirtualNodeTarget(const AnfNodePtr & node)1198 std::string GetVirtualNodeTarget(const AnfNodePtr &node) {
1199   MS_EXCEPTION_IF_NULL(node);
1200   node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(kDeviceUnDefined));
1201   auto target = GetVirtualNodeTargetFromInputs(node);
1202   node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
1203   if (target != kDeviceUnDefined) {
1204     return target;
1205   }
1206   target = GetVirtualNodeTargetFromUsers(node);
1207   node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
1208   return target;
1209 }
1210 
GetTargetFromAttr(const AnfNodePtr & node)1211 std::string GetTargetFromAttr(const AnfNodePtr &node) {
1212   MS_EXCEPTION_IF_NULL(node);
1213   auto cnode = node->cast_ptr<CNode>();
1214   MS_EXCEPTION_IF_NULL(cnode);
1215   auto attr_input = cnode->input(0);
1216   auto primitive = GetPrimitiveFromValueNode(attr_input);
1217   if (primitive == nullptr) {
1218     return kDeviceUnDefined;
1219   }
1220   auto att_target = primitive->GetAttr(kPrimitiveTarget);
1221   if (att_target != nullptr) {
1222     if (!att_target->isa<StringImm>()) {
1223       MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
1224     }
1225     auto target = GetValue<std::string>(att_target);
1226     if (kTargetSet.find(target) == kTargetSet.end()) {
1227       MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
1228     }
1229     return target;
1230   }
1231   return kDeviceUnDefined;
1232 }
1233 }  // namespace
1234 
GetOriginNodeTarget(const AnfNodePtr & node)1235 std::string GetOriginNodeTarget(const AnfNodePtr &node) {
1236   MS_EXCEPTION_IF_NULL(node);
1237   if (!node->isa<CNode>()) {
1238     return kDeviceUnDefined;
1239   }
1240   auto cnode = node->cast_ptr<CNode>();
1241   MS_EXCEPTION_IF_NULL(cnode);
1242   auto ud_target = cnode->user_data<std::string>(kPrimitiveTarget);
1243   if (ud_target != nullptr) {
1244     return *ud_target.get();
1245   }
1246   auto target = GetTargetFromAttr(node);
1247   if (target != kDeviceUnDefined) {
1248     return target;
1249   }
1250 #ifndef ENABLE_SECURITY
1251   if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
1252       IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
1253       IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
1254       IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
1255       IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1256     return GetVirtualNodeTarget(node);
1257   }
1258 #else
1259   if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
1260       IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
1261       IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1262     return GetVirtualNodeTarget(node);
1263   }
1264 #endif
1265   auto context_ptr = MsContext::GetInstance();
1266   MS_EXCEPTION_IF_NULL(context_ptr);
1267   return context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1268 }
1269 
GetCNodeTarget(const AnfNodePtr & node)1270 std::string GetCNodeTarget(const AnfNodePtr &node) {
1271   auto kernel_info = node->kernel_info();
1272   if (kernel_info != nullptr) {
1273     auto runtime_cache = kernel_info->runtime_cache();
1274     if (runtime_cache.runtime_cache().is_valid()) {
1275       auto tmp_target = runtime_cache.runtime_cache().device_target();
1276       if (!tmp_target.empty()) {
1277         return tmp_target;
1278       }
1279     }
1280   }
1281 
1282   std::string target;
1283   auto ori_target = GetOriginNodeTarget(node);
1284   if (ori_target != kDeviceUnDefined) {
1285     target = ori_target;
1286   } else {
1287     auto context_ptr = MsContext::GetInstance();
1288     MS_EXCEPTION_IF_NULL(context_ptr);
1289     target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1290   }
1291 
1292   if (kernel_info != nullptr) {
1293     auto runtime_cache = kernel_info->runtime_cache();
1294     if (runtime_cache.runtime_cache().is_valid()) {
1295       runtime_cache.runtime_cache().set_device_target(target);
1296     }
1297   }
1298   return target;
1299 }
1300 
ContainMultiTarget(const AnfNodePtrList & nodes)1301 bool ContainMultiTarget(const AnfNodePtrList &nodes) {
1302   auto context_ptr = MsContext::GetInstance();
1303   MS_EXCEPTION_IF_NULL(context_ptr);
1304   std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1305   for (auto &node : nodes) {
1306     if (node->isa<CNode>()) {
1307       std::string cur_target = GetCNodeTarget(node);
1308       if (last_target != cur_target) {
1309         return true;
1310       }
1311       last_target = cur_target;
1312     }
1313   }
1314   return false;
1315 }
1316 
IsOneOfPrimitive(const AnfNodePtr & node,const PrimitiveSet & prim_set)1317 bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
1318   PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
1319   return (prim && prim_set.find(prim) != prim_set.end());
1320 }
1321 
IsOneOfPrimitiveCNode(const AnfNodePtr & node,const PrimitiveSet & prim_set)1322 bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
1323   MS_EXCEPTION_IF_NULL(node);
1324   auto cnode = node->cast_ptr<CNode>();
1325   if (cnode == nullptr || cnode->size() == 0) {
1326     return false;
1327   }
1328   return IsOneOfPrimitive(cnode->input(0), prim_set);
1329 }
1330 
1331 // Set the sequence nodes' elements use flags to 'new_flag' at specific 'index' position.
SetSequenceElementsUseFlags(const AbstractBasePtr & abs,std::size_t index,bool new_flag)1332 void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, std::size_t index, bool new_flag) {
1333   static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1334   if (!enable_eliminate_unused_element) {
1335     return;
1336   }
1337 
1338   auto sequence_abs = dyn_cast_ptr<abstract::AbstractSequence>(abs);
1339   if (sequence_abs == nullptr) {
1340     return;
1341   }
1342   if (sequence_abs->sequence_nodes() == nullptr || sequence_abs->sequence_nodes()->empty()) {
1343     return;
1344   }
1345   for (auto &node : *sequence_abs->sequence_nodes()) {
1346     auto sequence_node = node.lock();
1347     if (sequence_node == nullptr) {
1348       MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
1349       continue;
1350     }
1351     auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
1352     if (flags == nullptr) {
1353       continue;
1354     }
1355     if (index >= flags->size()) {
1356       MS_LOG(ERROR) << "The index " << index << " is out of range, size is " << flags->size() << ", for "
1357                     << sequence_node->DebugString();
1358       return;
1359     }
1360     (*flags)[index] = new_flag;
1361     MS_LOG(DEBUG) << "Set item[" << index << "] use flag as " << new_flag << ", for " << sequence_node->DebugString();
1362   }
1363 }
1364 
1365 // Set the sequence nodes' elements use flags all to 'new_flag'.
SetSequenceElementsUseFlags(const AbstractBasePtr & abs,bool new_flag)1366 void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag) {
1367   static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1368   if (!enable_eliminate_unused_element) {
1369     return;
1370   }
1371 
1372   auto sequence_abs = dyn_cast_ptr<abstract::AbstractSequence>(abs);
1373   if (sequence_abs == nullptr) {
1374     return;
1375   }
1376   if (sequence_abs->sequence_nodes() == nullptr || sequence_abs->sequence_nodes()->empty()) {
1377     return;
1378   }
1379   for (auto &weak_node : *sequence_abs->sequence_nodes()) {
1380     auto sequence_node = weak_node.lock();
1381     if (sequence_node == nullptr) {
1382       MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
1383       continue;
1384     }
1385     auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
1386     if (flags != nullptr) {
1387       auto &all_flags = (*flags);
1388       (void)std::transform(all_flags.begin(), all_flags.end(), all_flags.begin(),
1389                            [&new_flag](bool) -> bool { return new_flag; });
1390     }
1391   }
1392 }
1393 
1394 // Set the sequence nodes' elements use flags all to 'new_flag' recursively.
SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr & abs,bool new_flag)1395 void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new_flag) {
1396   static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1397   if (!enable_eliminate_unused_element) {
1398     return;
1399   }
1400 
1401   SetSequenceElementsUseFlags(abs, new_flag);
1402 
1403   // Check its elements if it's a sequence node.
1404   auto sequence_abs = dyn_cast_ptr<abstract::AbstractSequence>(abs);
1405   if (sequence_abs != nullptr) {
1406     for (auto &element : sequence_abs->elements()) {
1407       SetSequenceElementsUseFlagsRecursively(element, new_flag);
1408     }
1409     return;
1410   }
1411 
1412   // Check its elements if it's a dictionary node.
1413   auto dictionary_abs = dyn_cast_ptr<abstract::AbstractDictionary>(abs);
1414   if (dictionary_abs != nullptr) {
1415     for (auto &element : dictionary_abs->elements()) {
1416       SetSequenceElementsUseFlagsRecursively(element.second, new_flag);
1417     }
1418   }
1419 }
1420 }  // namespace mindspore
1421