1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2022 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "ir/anf.h"
20
21 #include <algorithm>
22 #include <sstream>
23 #include <vector>
24 #include <queue>
25
26 #include "mindspore/core/ops/structure_ops.h"
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "utils/hash_map.h"
30 #include "ir/func_graph.h"
31 #include "ir/primitive.h"
32 #include "utils/ms_context.h"
33 #include "utils/anf_utils.h"
34 #include "utils/compile_config.h"
35 #include "ops/op_def.h"
36
37 namespace mindspore {
38 namespace {
IsNeedCheckPrimitiveNode(const AnfNodePtr & prim_node)39 std::pair<bool, PrimitivePtr> IsNeedCheckPrimitiveNode(const AnfNodePtr &prim_node) {
40 if (!IsValueNode<Primitive>(prim_node)) {
41 return {false, nullptr};
42 }
43 auto prim = GetValueNode<PrimitivePtr>(prim_node);
44 MS_EXCEPTION_IF_NULL(prim);
45 if (prim->IsPythonPrim() || prim->HasAttr(kSkipCheckInputNum)) {
46 return {false, nullptr};
47 }
48 if (prim->GetAttr("primitive_function") == nullptr) {
49 return {false, nullptr};
50 }
51 auto op_def = mindspore::ops::GetOpDef(prim->name());
52 if (op_def == nullptr) {
53 return {false, nullptr};
54 }
55
56 return {true, prim};
57 }
58
PrintErrorInfo(const AnfNodeWeakPtrList & inputs,const PrimitivePtr & prim,size_t input_tensor_num,const ops::OpDefPtr & op_def)59 void PrintErrorInfo(const AnfNodeWeakPtrList &inputs, const PrimitivePtr &prim, size_t input_tensor_num,
60 const ops::OpDefPtr &op_def) {
61 std::stringstream ss;
62 size_t i = 0;
63 ss << "Inputs are as follows: \n";
64 for (const auto &input : inputs) {
65 ss << "Input[" << i++ << "]: " << input.lock()->DebugString() << "\n";
66 }
67 MS_LOG(DEBUG) << "Primitive<" << prim->name() << "> inputs num: " << input_tensor_num
68 << " is not equal to expect input num: " << op_def->args_.size() << "\n"
69 << ss.str();
70 }
71
CheckCNodeInputsNum(const AnfNodeWeakPtrList & inputs)72 void CheckCNodeInputsNum(const AnfNodeWeakPtrList &inputs) {
73 if (!IS_OUTPUT_ON(mindspore::kDebug) || inputs.empty()) {
74 return;
75 }
76
77 auto [need_check, prim] = IsNeedCheckPrimitiveNode(inputs[0].lock());
78 if (!need_check) {
79 return;
80 }
81
82 auto op_def = mindspore::ops::GetOpDef(prim->name());
83 if (op_def == nullptr) {
84 return;
85 }
86 bool input_num_err = false;
87 constexpr size_t prim_num = 1;
88 size_t input_tensor_num = inputs.size() - prim_num;
89 if (prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_MEM) || prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_IO)) {
90 size_t monad_num =
91 static_cast<size_t>(std::count_if(inputs.cbegin() + 1, inputs.end(), [](const AnfNodeWeakPtr &weak_input) {
92 const auto &input = weak_input.lock();
93 return HasAbstractMonad(input) || (IsPrimitiveCNode(input, prim::kPrimUpdateState) ||
94 IsValueNode<UMonad>(input) || IsValueNode<IOMonad>(input));
95 }));
96 // If monad input is parameter_monad, monad num is 0, actual monad num should be 1. And monad num is 0 if monad
97 // pass has not been executed.
98 if (monad_num == 0) {
99 auto monad_pass_not_executed_check_failed = op_def->args_.size() != input_tensor_num;
100 constexpr auto parameter_monad_num = 1;
101 auto exist_parameter_monad_check_failed = op_def->args_.size() != input_tensor_num - parameter_monad_num;
102 if (monad_pass_not_executed_check_failed && exist_parameter_monad_check_failed) {
103 input_num_err = true;
104 }
105 } else if (monad_num == 1) {
106 if (op_def->args_.size() != input_tensor_num - monad_num) {
107 input_num_err = true;
108 }
109 } else {
110 MS_LOG(INTERNAL_EXCEPTION) << "Get unexpected monad num: " << monad_num;
111 }
112 } else {
113 if (op_def->args_.size() != input_tensor_num) {
114 input_num_err = true;
115 }
116 }
117 if (input_num_err) {
118 PrintErrorInfo(inputs, prim, input_tensor_num, op_def);
119 }
120 }
121 } // namespace
AnfNode(const FuncGraphPtr & func_graph,NodeDebugInfoPtr && debug_info)122 AnfNode::AnfNode(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info)
123 : func_graph_(FuncGraphWeakPtr(func_graph)),
124 abstract_(nullptr),
125 debug_info_(std::move(debug_info)),
126 fullname_with_scope_(""),
127 scope_(ScopeManager::GetInstance().GetCurrentScope()) {}
128
AnfNode(const FuncGraphPtr & func_graph)129 AnfNode::AnfNode(const FuncGraphPtr &func_graph) : AnfNode(func_graph, std::make_shared<NodeDebugInfo>()) {}
130
accept(AnfIrVisitor *)131 void AnfNode::accept(AnfIrVisitor *) {}
132
func_graph() const133 FuncGraphPtr AnfNode::func_graph() const { return func_graph_.lock(); }
134
set_func_graph(const FuncGraphPtr & func_graph)135 void AnfNode::set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
136
scope()137 ScopePtr AnfNode::scope() { return scope_; }
138
set_scope(const ScopePtr & scope)139 void AnfNode::set_scope(const ScopePtr &scope) { scope_ = scope; }
140
kernel_info() const141 const KernelInfoDevice *AnfNode::kernel_info() const { return kernel_info_ptr().get(); }
142
kernel_info()143 KernelInfoDevice *AnfNode::kernel_info() { return kernel_info_ptr().get(); }
144
kernel_info_ptr() const145 KernelInfoDevicePtr AnfNode::kernel_info_ptr() const { return user_data<KernelInfoDevice>(kKernelInfoKey); }
146
set_kernel_info(const KernelInfoDevicePtr & kernel_info)147 void AnfNode::set_kernel_info(const KernelInfoDevicePtr &kernel_info) { set_user_data(kKernelInfoKey, kernel_info); }
148
debug_info()149 NodeDebugInfoPtr AnfNode::debug_info() {
150 MS_EXCEPTION_IF_NULL(debug_info_);
151 return debug_info_;
152 }
153
set_debug_info(const NodeDebugInfoPtr & debug_info)154 void AnfNode::set_debug_info(const NodeDebugInfoPtr &debug_info) {
155 MS_EXCEPTION_IF_NULL(debug_info);
156 debug_info->set_type_name(type_name());
157 debug_info->set_debug_name(std::string()); // Clear cached debug name.
158 debug_info_ = debug_info;
159 }
160
hash() const161 std::size_t AnfNode::hash() const { return PointerHash<AnfNode>{}(this); }
162
fullname_with_scope()163 std::string AnfNode::fullname_with_scope() { return ""; }
164
UniqueName()165 std::string AnfNode::UniqueName() { return fullname_with_scope() + "_" + UniqueId(); }
166
DebugString(int recursive_level) const167 std::string AnfNode::DebugString(int recursive_level) const { return ToString(); }
168
DebugString(bool recursive) const169 std::string AnfNode::DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); }
170
dump() const171 void AnfNode::dump() const { std::cout << DebugString() << std::endl; }
172
UniqueId()173 std::string AnfNode::UniqueId() { return std::to_string(debug_info()->unique_id()); }
174
UniqueIdThroughCopy()175 std::string AnfNode::UniqueIdThroughCopy() { return std::to_string(debug_info()->unique_id_through_copy()); }
176
operator ==(const AnfNode & other) const177 bool AnfNode::operator==(const AnfNode &other) const { return &other == this; }
178
operator <<(std::ostream & os,const AnfNode & node)179 std::ostream &operator<<(std::ostream &os, const AnfNode &node) {
180 os << node.ToString();
181 return os;
182 }
183
interpret() const184 bool AnfNode::interpret() const { return interpret_flags_[kInterpret]; }
185
set_interpret(const bool & interpret)186 void AnfNode::set_interpret(const bool &interpret) { interpret_flags_[kInterpret] = interpret; }
187
interpret_internal_type()188 bool AnfNode::interpret_internal_type() { return interpret_flags_[kInterpretInternalType]; }
189
set_interpret_internal_type(const bool & interpret_internal_type)190 void AnfNode::set_interpret_internal_type(const bool &interpret_internal_type) {
191 interpret_flags_[kInterpretInternalType] = interpret_internal_type;
192 }
193
abstract() const194 const AbstractBasePtr &AnfNode::abstract() const {
195 // cppcheck-suppress unreadVariable
196 auto lock = AnfUtils::GetAbstractLock(this);
197 return abstract_;
198 }
199
set_abstract(const AbstractBasePtr & abs)200 void AnfNode::set_abstract(const AbstractBasePtr &abs) {
201 // cppcheck-suppress unreadVariable
202 auto lock = AnfUtils::GetAbstractLock(this);
203 abstract_ = abs;
204 }
205
CheckCNodeWeakInput()206 void CNode::CheckCNodeWeakInput() {
207 #if defined(DEBUG_CNODE_WEAK_INPUT)
208 for (size_t i = 0; i < weak_inputs_.size(); ++i) {
209 if (weak_inputs_[i].lock() == nullptr) {
210 MS_LOG(INTERNAL_EXCEPTION) << "The " << i << "th input is released.";
211 }
212 }
213 #endif // DEBUG_CNODE_WEAK_INPUT
214 }
215
216 // Namespace to support intermediate representation definition
CNode(const AnfNodePtrList & inputs,const FuncGraphPtr & func_graph)217 CNode::CNode(const AnfNodePtrList &inputs, const FuncGraphPtr &func_graph)
218 : AnfNode(func_graph),
219 primal_attrs_(PrimalAttrManager::GetInstance().GetCurrentPrimalAttr()),
220 primal_debug_infos_(PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo()) {
221 weak_inputs_.reserve(inputs.size());
222 std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs_),
223 [](const AnfNodePtr &node) -> AnfNodeWeakPtr { return AnfNodeWeakPtr(node); });
224 if (func_graph != nullptr) {
225 MS_LOG(DEBUG) << "Create new CNode, " << this << "@" << func_graph->ToString();
226 func_graph->AddOwnNode(inputs);
227 } else {
228 MS_LOG(WARNING) << "The func graph should not be null.";
229 inputs_bound_ = true;
230 inputs_ = inputs;
231 }
232 CheckCNodeInputsNum(weak_inputs_);
233 Init();
234 }
235
CNode(const AnfNodePtrList & inputs,const VarPtr & func_graph_as_var)236 CNode::CNode(const AnfNodePtrList &inputs, const VarPtr &func_graph_as_var) : AnfNode(nullptr), inputs_(inputs) {
237 inputs_bound_ = true;
238 weak_inputs_.reserve(inputs.size());
239 std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs_),
240 [](const AnfNodePtr &node) -> AnfNodeWeakPtr { return AnfNodeWeakPtr(node); });
241 primal_attrs_ = PrimalAttrManager::GetInstance().GetCurrentPrimalAttr();
242 primal_debug_infos_ = PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo();
243 set_user_data(kFuncGraphVarKey, func_graph_as_var);
244 Init();
245 }
246
CNode(AnfNodeWeakPtrList && weak_inputs,const FuncGraphPtr & func_graph)247 CNode::CNode(AnfNodeWeakPtrList &&weak_inputs, const FuncGraphPtr &func_graph)
248 : AnfNode(func_graph),
249 weak_inputs_(std::move(weak_inputs)),
250 primal_attrs_(PrimalAttrManager::GetInstance().GetCurrentPrimalAttr()),
251 primal_debug_infos_(PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo()) {
252 if (func_graph != nullptr) {
253 MS_LOG(DEBUG) << "Create new CNode, " << this << "@" << func_graph->ToString();
254 func_graph->AddOwnNode(weak_inputs_);
255 } else {
256 MS_LOG(INTERNAL_EXCEPTION) << "The func graph should not be null.";
257 }
258 CheckCNodeInputsNum(weak_inputs_);
259 Init();
260 }
261
CNode(AnfNodeWeakPtrList && weak_inputs,const FuncGraphPtr & func_graph,NodeDebugInfoPtr && debug_info)262 CNode::CNode(AnfNodeWeakPtrList &&weak_inputs, const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info)
263 : AnfNode(func_graph, std::move(debug_info)),
264 weak_inputs_(std::move(weak_inputs)),
265 primal_attrs_(PrimalAttrManager::GetInstance().GetCurrentPrimalAttr()),
266 primal_debug_infos_(PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo()) {
267 if (func_graph != nullptr) {
268 MS_LOG(DEBUG) << "Create new CNode, " << this << "@" << func_graph->ToString();
269 func_graph->AddOwnNode(weak_inputs_);
270 } else {
271 MS_LOG(INTERNAL_EXCEPTION) << "The func graph should not be null.";
272 }
273 CheckCNodeInputsNum(weak_inputs_);
274 Init();
275 }
276
CNode(const AnfNodeWeakPtrList & weak_inputs,const FuncGraphPtr & func_graph)277 CNode::CNode(const AnfNodeWeakPtrList &weak_inputs, const FuncGraphPtr &func_graph)
278 : AnfNode(func_graph),
279 weak_inputs_(weak_inputs),
280 primal_attrs_(PrimalAttrManager::GetInstance().GetCurrentPrimalAttr()),
281 primal_debug_infos_(PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo()) {
282 if (func_graph != nullptr) {
283 MS_LOG(DEBUG) << "Create new CNode, " << this << "@" << func_graph->ToString();
284 func_graph->AddOwnNode(weak_inputs_);
285 } else {
286 MS_LOG(INTERNAL_EXCEPTION) << "The func graph should not be null.";
287 }
288 CheckCNodeInputsNum(weak_inputs_);
289 Init();
290 }
291
Init()292 void CNode::Init() {
293 CheckCNodeWeakInput();
294 debug_info()->set_type_name(type_name());
295 }
296
set_debug_info(const NodeDebugInfoPtr & debug_info)297 void CNode::set_debug_info(const NodeDebugInfoPtr &debug_info) {
298 MS_EXCEPTION_IF_NULL(debug_info);
299 debug_info->set_type_name(type_name());
300 debug_info->set_debug_name(std::string()); // Clear cached debug name.
301 debug_info_ = debug_info;
302 }
303
size() const304 const size_t CNode::size() const { return weak_inputs_.size(); }
305
empty() const306 const bool CNode::empty() const { return size() == 0; }
307
inputs()308 const AnfNodePtrList &CNode::inputs() {
309 // Check for 'CNode(const AnfNodePtrList &, const VarPtr &)' for compatible.
310 if (inputs_bound_) {
311 return inputs_;
312 }
313
314 inputs_.clear();
315 inputs_.reserve(weak_inputs_.size());
316 std::transform(weak_inputs_.cbegin(), weak_inputs_.cend(), std::back_inserter(inputs_),
317 [this](const AnfNodeWeakPtr &weak_node) -> AnfNodePtr {
318 auto node = weak_node.lock();
319 #if defined(DEBUG_CNODE_WEAK_INPUT)
320 if (node == nullptr) {
321 MS_LOG(INTERNAL_EXCEPTION) << "Weak input lock failed, " << DebugString();
322 }
323 #endif // DEBUG_CNODE_WEAK_INPUT
324 return node;
325 });
326 inputs_bound_ = true;
327 return inputs_;
328 }
329
weak_inputs() const330 const AnfNodeWeakPtrList &CNode::weak_inputs() const { return weak_inputs_; }
331
332 // Check if CNode is an apply with the specific Primitive.
IsApply(const PrimitivePtr & value) const333 bool CNode::IsApply(const PrimitivePtr &value) const {
334 if (value == nullptr || weak_inputs_.empty()) {
335 return false;
336 }
337 auto prim = GetValuePtr<Primitive>(weak_inputs_[0].lock());
338 return (prim != nullptr) && (prim->Hash() == value->Hash()) && (prim->name() == value->name());
339 }
340
add_input(const AnfNodePtr & input)341 void CNode::add_input(const AnfNodePtr &input) {
342 if (inputs_bound_) {
343 inputs_.emplace_back(input);
344 }
345
346 (void)weak_inputs_.emplace_back(AnfNodeWeakPtr(input));
347 if (func_graph() != nullptr) {
348 func_graph()->AddOwnNode(input);
349 }
350
351 input_tensor_num_ = -1;
352 }
353
set_input(size_t i,const AnfNodePtr & new_input)354 void CNode::set_input(size_t i, const AnfNodePtr &new_input) {
355 if (inputs_bound_) {
356 if (i >= inputs_.size()) {
357 MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << weak_inputs_.size()
358 << ", cnode: " << DebugString();
359 }
360 inputs_[i] = new_input;
361 }
362
363 if (i >= weak_inputs_.size()) {
364 MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << weak_inputs_.size()
365 << ", cnode: " << DebugString();
366 }
367 weak_inputs_[i] = AnfNodeWeakPtr(new_input);
368 if (func_graph() != nullptr) {
369 func_graph()->AddOwnNode(new_input);
370 }
371
372 input_tensor_num_ = -1;
373 }
374
set_inputs(const AnfNodePtrList & inputs)375 void CNode::set_inputs(const AnfNodePtrList &inputs) {
376 if (inputs_bound_) {
377 inputs_ = inputs;
378 }
379
380 weak_inputs_.clear();
381 weak_inputs_.reserve(inputs.size());
382 std::transform(inputs.cbegin(), inputs.cend(), std::back_inserter(weak_inputs_),
383 [](const AnfNodePtr &node) -> AnfNodeWeakPtr { return AnfNodeWeakPtr(node); });
384 if (func_graph() != nullptr) {
385 func_graph()->AddOwnNode(inputs);
386 }
387 input_tensor_num_ = -1;
388 CheckCNodeInputsNum(weak_inputs_);
389 }
390
set_weak_inputs(const AnfNodeWeakPtrList & weak_inputs)391 void CNode::set_weak_inputs(const AnfNodeWeakPtrList &weak_inputs) {
392 if (inputs_bound_) {
393 inputs_.clear();
394 inputs_.reserve(weak_inputs.size());
395 std::transform(weak_inputs.cbegin(), weak_inputs.cend(), std::back_inserter(inputs_),
396 [this](const AnfNodeWeakPtr &weak_node) -> AnfNodePtr {
397 auto node = weak_node.lock();
398 if (node == nullptr) {
399 MS_LOG(INTERNAL_EXCEPTION) << "Weak input lock failed, " << DebugString();
400 }
401 return node;
402 });
403 }
404
405 weak_inputs_ = weak_inputs;
406 if (func_graph() != nullptr) {
407 func_graph()->AddOwnNode(weak_inputs_);
408 }
409 input_tensor_num_ = -1;
410 CheckCNodeInputsNum(weak_inputs_);
411 }
412
input(size_t i) const413 const AnfNodePtr CNode::input(size_t i) const {
414 if (inputs_bound_) {
415 if (i >= inputs_.size()) {
416 MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << inputs_.size() << ", cnode: " << DebugString();
417 }
418 return inputs_.at(i);
419 }
420
421 if (i >= weak_inputs_.size()) {
422 MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << weak_inputs_.size()
423 << ", cnode: " << DebugString();
424 }
425 auto res = weak_inputs_.at(i).lock();
426 if (res == nullptr) {
427 MS_LOG(ERROR) << "The input[" << i << "] is released, cnode: " << DebugString();
428 }
429 return res;
430 }
431
weak_input(size_t i) const432 const AnfNodeWeakPtr &CNode::weak_input(size_t i) const {
433 if (i >= weak_inputs_.size()) {
434 MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " out of range: " << weak_inputs_.size()
435 << ", cnode: " << DebugString();
436 }
437 return weak_inputs_.at(i);
438 }
439
DebugString(int recursive_level) const440 std::string CNode::DebugString(int recursive_level) const {
441 std::ostringstream buffer;
442 if (recursive_level > 0) {
443 if (func_graph() != nullptr) {
444 buffer << "@" << func_graph()->ToString() << ":";
445 }
446 buffer << ToString() << "{";
447 bool is_first_node = true;
448 int idx = 0;
449 for (const auto &weak_node : weak_inputs_) {
450 if (is_first_node) {
451 is_first_node = false;
452 } else {
453 buffer << ", ";
454 }
455 const auto &node = weak_node.lock();
456 if (node == nullptr) {
457 buffer << "[" << idx << "]: (released_node)";
458 ++idx;
459 continue;
460 }
461 buffer << "[" << idx << "]: " << node->DebugString(recursive_level - 1);
462 ++idx;
463 }
464 buffer << "}";
465 } else {
466 buffer << ToString();
467 }
468 return buffer.str();
469 }
470
AddFusedDebugInfo(const AnfNodePtr & node)471 void CNode::AddFusedDebugInfo(const AnfNodePtr &node) {
472 if (node == nullptr || !node->isa<CNode>()) {
473 return;
474 }
475 if (shared_from_this() == node) {
476 this->AddFusedDebugInfo(node->debug_info());
477 return;
478 }
479 auto cnode = node->cast_ptr<CNode>();
480 auto node_fused_debug_infos = cnode->fused_debug_infos();
481 if (!node_fused_debug_infos.empty()) {
482 (void)std::for_each(node_fused_debug_infos.begin(), node_fused_debug_infos.end(),
483 [this](const NodeDebugInfoPtr &debug_info) { this->AddFusedDebugInfo(debug_info); });
484 } else {
485 this->AddFusedDebugInfo(cnode->debug_info());
486 }
487
488 auto primal_debug_infos = cnode->primal_debug_infos();
489 if (!primal_debug_infos.empty()) {
490 (void)std::for_each(primal_debug_infos.begin(), primal_debug_infos.end(),
491 [this](const NodeDebugInfoPtr &debug_info) { this->AddPrimalDebugInfo(debug_info); });
492 }
493 }
494
AddFusedDebugInfoList(const AnfNodePtrList & nodes)495 void CNode::AddFusedDebugInfoList(const AnfNodePtrList &nodes) {
496 (void)std::for_each(nodes.begin(), nodes.end(), [this](const AnfNodePtr &node) { this->AddFusedDebugInfo(node); });
497 }
498
AddFusedDebugInfo(const NodeDebugInfoPtr & debug_info)499 void CNode::AddFusedDebugInfo(const NodeDebugInfoPtr &debug_info) {
500 if (debug_info == nullptr) {
501 return;
502 }
503 (void)fused_debug_infos_.emplace(debug_info);
504 }
505
AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> & debug_infos)506 void CNode::AddFusedDebugInfoList(const std::vector<NodeDebugInfoPtr> &debug_infos) {
507 (void)std::for_each(debug_infos.begin(), debug_infos.end(),
508 [this](const NodeDebugInfoPtr &debug_info) { this->AddFusedDebugInfo(debug_info); });
509 }
510
primal_debug_infos() const511 NodeDebugInfoSet CNode::primal_debug_infos() const { return primal_debug_infos_; }
512
set_primal_debug_infos(const NodeDebugInfoSet & debug_infos)513 void CNode::set_primal_debug_infos(const NodeDebugInfoSet &debug_infos) {
514 (void)std::for_each(debug_infos.begin(), debug_infos.end(),
515 [this](const NodeDebugInfoPtr &debug_info) { this->AddPrimalDebugInfo(debug_info); });
516 }
517
AddPrimalDebugInfo(const NodeDebugInfoPtr & debug_info)518 void CNode::AddPrimalDebugInfo(const NodeDebugInfoPtr &debug_info) { (void)primal_debug_infos_.emplace(debug_info); }
519
set_forward(const ValueNodePtr & forward,const std::string & id)520 void CNode::set_forward(const ValueNodePtr &forward, const std::string &id) {
521 set_user_data(kOutputValueKey, std::make_shared<OutputValue>(forward, id));
522 }
523
forward() const524 const CNode::OutputValue &CNode::forward() const {
525 static const CNode::OutputValue empty_value;
526 auto ptr = user_data<CNode::OutputValue>(kOutputValueKey);
527 if (ptr == nullptr) {
528 return empty_value;
529 }
530 return *ptr;
531 }
532
stop_gradient() const533 bool CNode::stop_gradient() const { return flags_[kStopGradient]; }
534
set_stop_gradient(bool stop_gradient)535 void CNode::set_stop_gradient(bool stop_gradient) { flags_[kStopGradient] = stop_gradient; }
536
set_fullname_with_scope(const std::string full_name)537 void CNode::set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; }
538
DebugString(bool recursive) const539 std::string CNode::DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); }
540
set_in_forward_flag(bool flag)541 void CNode::set_in_forward_flag(bool flag) { flags_[kInForwardFlag] = flag; }
542
in_forward_flag() const543 bool CNode::in_forward_flag() const { return flags_[kInForwardFlag]; }
544
set_load_flag(bool is_load)545 void CNode::set_load_flag(bool is_load) { flags_[kIsLoad] = is_load; }
546
get_load_flag() const547 bool CNode::get_load_flag() const { return flags_[kIsLoad]; }
548
func_graph_as_var() const549 VarPtr CNode::func_graph_as_var() const { return user_data<Var>(kFuncGraphVarKey); }
550
attrs() const551 const mindspore::HashMap<std::string, ValuePtr> &CNode::attrs() const { return attrs_; }
552
set_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)553 void CNode::set_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
554 attrs_.insert(attrs.cbegin(), attrs.cend());
555 }
556
AddAttr(const std::string & name,const ValuePtr & attr)557 void CNode::AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
558
EraseAttr(const std::string & name)559 void CNode::EraseAttr(const std::string &name) { (void)attrs_.erase(name); }
560
GetAttr(const std::string & name) const561 ValuePtr CNode::GetAttr(const std::string &name) const {
562 auto iter = attrs_.find(name);
563 return iter == attrs_.cend() ? nullptr : iter->second;
564 }
565
HasAttr(const std::string & name) const566 bool CNode::HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); }
567
input_tensor_num() const568 ssize_t CNode::input_tensor_num() const { return input_tensor_num_; }
569
primal_attrs() const570 const mindspore::HashMap<std::string, ValuePtr> &CNode::primal_attrs() const { return primal_attrs_; }
571
set_primal_attrs(const mindspore::HashMap<std::string,ValuePtr> & attrs)572 void CNode::set_primal_attrs(const mindspore::HashMap<std::string, ValuePtr> &attrs) {
573 primal_attrs_.insert(attrs.cbegin(), attrs.cend());
574 }
575
AddPrimalAttr(const std::string & name,const ValuePtr & attr)576 void CNode::AddPrimalAttr(const std::string &name, const ValuePtr &attr) { primal_attrs_[name] = attr; }
577
ErasePrimalAttr(const std::string & name)578 void CNode::ErasePrimalAttr(const std::string &name) { (void)primal_attrs_.erase(name); }
579
GetPrimalAttr(const std::string & name) const580 ValuePtr CNode::GetPrimalAttr(const std::string &name) const {
581 auto iter = primal_attrs_.find(name);
582 return iter == primal_attrs_.cend() ? nullptr : iter->second;
583 }
584
HasPrimalAttr(const std::string & name) const585 bool CNode::HasPrimalAttr(const std::string &name) const { return primal_attrs_.find(name) != primal_attrs_.end(); }
586
CloneCNodeInfo(const CNodePtr & node)587 void CNode::CloneCNodeInfo(const CNodePtr &node) {
588 MS_EXCEPTION_IF_NULL(node);
589 set_abstract(node->abstract());
590 set_forward(node->forward().first, node->forward().second);
591 set_attrs(node->attrs());
592 set_primal_attrs(node->primal_attrs());
593 set_load_flag(node->get_load_flag());
594 CloneUserData(node);
595 set_kernel_info(node->kernel_info_ptr());
596 set_primal_debug_infos(node->primal_debug_infos());
597 set_fused_debug_infos(node->fused_debug_infos());
598 }
599
set_input_tensor_num(ssize_t input_tensor_num)600 void CNode::set_input_tensor_num(ssize_t input_tensor_num) { input_tensor_num_ = input_tensor_num; }
601
IsEffectHandled() const602 bool CNode::IsEffectHandled() const { return flags_[kEffectHandled]; }
603
SetEffectHandled(bool handled)604 void CNode::SetEffectHandled(bool handled) { flags_[kEffectHandled] = handled; }
605
fused_debug_infos() const606 NodeDebugInfoSet CNode::fused_debug_infos() const { return fused_debug_infos_; }
607
set_fused_debug_infos(const NodeDebugInfoSet & fused_debug_infos)608 void CNode::set_fused_debug_infos(const NodeDebugInfoSet &fused_debug_infos) { fused_debug_infos_ = fused_debug_infos; }
609
has_side_effect_node() const610 bool CNode::has_side_effect_node() const { return has_side_effect_node_; }
611
set_has_side_effect_node(bool has_side_effect_node)612 void CNode::set_has_side_effect_node(bool has_side_effect_node) { has_side_effect_node_ = has_side_effect_node; }
613
Parameter(const FuncGraphPtr & func_graph)614 Parameter::Parameter(const FuncGraphPtr &func_graph) : ANode(func_graph) { Init(); }
615
Parameter(const FuncGraphPtr & func_graph,NodeDebugInfoPtr && debug_info)616 Parameter::Parameter(const FuncGraphPtr &func_graph, NodeDebugInfoPtr &&debug_info)
617 : ANode(func_graph, std::move(debug_info)) {
618 Init();
619 }
620
Init()621 void Parameter::Init() { debug_info()->set_type_name(type_name()); }
622
set_debug_info(const NodeDebugInfoPtr & debug_info)623 void Parameter::set_debug_info(const NodeDebugInfoPtr &debug_info) {
624 MS_EXCEPTION_IF_NULL(debug_info);
625 debug_info->set_type_name(type_name());
626 debug_info->set_debug_name(std::string()); // Clear cached debug name.
627 debug_info_ = debug_info;
628 }
629
name() const630 std::string Parameter::name() const { return name_; }
631
set_name(const std::string & name)632 void Parameter::set_name(const std::string &name) { name_ = name; }
633
fullname_with_scope()634 std::string Parameter::fullname_with_scope() { return name(); }
635
has_default() const636 bool Parameter::has_default() const { return has_default_; }
637
set_default_param(const ValuePtr & param)638 void Parameter::set_default_param(const ValuePtr ¶m) {
639 default_param_ = param;
640 has_default_ = true;
641 }
642
default_param() const643 const ValuePtr &Parameter::default_param() const { return default_param_; }
644
IncreaseUsedGraphCount()645 void Parameter::IncreaseUsedGraphCount() { used_graph_count_++; }
646
DecreaseUsedGraphCount()647 void Parameter::DecreaseUsedGraphCount() { used_graph_count_--; }
648
used_graph_count() const649 int Parameter::used_graph_count() const { return used_graph_count_; }
650
is_top_graph_param() const651 bool Parameter::is_top_graph_param() const { return is_top_graph_param_; }
652
set_is_top_graph_param(bool flag)653 void Parameter::set_is_top_graph_param(bool flag) { is_top_graph_param_ = flag; }
654
operator ==(const AnfNode & other) const655 bool Parameter::operator==(const AnfNode &other) const {
656 if (!other.isa<Parameter>()) {
657 return false;
658 }
659 auto &p = static_cast<const Parameter &>(other);
660 if (name_.length() > 0 && p.name_.length() > 0) {
661 return p.name_ == name_;
662 }
663 return shared_from_this() == other.shared_from_this();
664 }
665
SetNotUsedByRealKernelInGraph(uint32_t graph_id)666 void Parameter::SetNotUsedByRealKernelInGraph(uint32_t graph_id) { (void)not_used_in_graphs_.insert(graph_id); }
667
IsUsedByRealKernelInGraph(uint32_t graph_id) const668 bool Parameter::IsUsedByRealKernelInGraph(uint32_t graph_id) const {
669 return not_used_in_graphs_.find(graph_id) == not_used_in_graphs_.end();
670 }
671
set_has_dynamic_shape(bool flag)672 void Parameter::set_has_dynamic_shape(bool flag) { has_dynamic_shape_ = flag; }
673
has_dynamic_shape() const674 bool Parameter::has_dynamic_shape() const { return has_dynamic_shape_; }
675
set_dynamic_len(bool flag)676 void Parameter::set_dynamic_len(bool flag) { is_dynamic_len_ = flag; }
677
dynamic_len() const678 bool Parameter::dynamic_len() const { return is_dynamic_len_; }
679
set_fracz_group(int64_t fracz_group)680 void Parameter::set_fracz_group(int64_t fracz_group) { format_attrs_.fracz_group = fracz_group; }
681
fracz_group() const682 int64_t Parameter::fracz_group() const { return format_attrs_.fracz_group; }
683
set_input_size(int64_t input_size)684 void Parameter::set_input_size(int64_t input_size) { format_attrs_.input_size = input_size; }
685
input_size() const686 int64_t Parameter::input_size() const { return format_attrs_.input_size; }
687
set_hidden_size(int64_t hidden_size)688 void Parameter::set_hidden_size(int64_t hidden_size) { format_attrs_.hidden_size = hidden_size; }
689
hidden_size() const690 int64_t Parameter::hidden_size() const { return format_attrs_.hidden_size; }
691
DebugString(int recursive_level) const692 std::string Parameter::DebugString(int recursive_level) const {
693 std::ostringstream buffer;
694 if (recursive_level > 0) {
695 if (func_graph() != nullptr) {
696 buffer << "@" << func_graph()->ToString() << ":";
697 }
698 }
699 buffer << "param_" << ToString();
700 return buffer.str();
701 }
702
param_info() const703 ParamInfoPtr Parameter::param_info() const {
704 if (!has_default()) {
705 return nullptr;
706 }
707 auto tensor = default_param()->cast_ptr<tensor::MetaTensor>();
708 if (tensor == nullptr || !tensor->is_parameter()) {
709 return nullptr;
710 }
711 return tensor->param_info();
712 }
713
Value(const TypePtr t)714 Value::Value(const TypePtr t) : type_(t) {}
715
// Copy constructor: copies the Base part and the type of `other`.
Value::Value(const Value &other) : Base(other), type_(other.type_) {
}
717
type() const718 TypePtr Value::type() const { return type_; }
719
ToAbstract()720 abstract::AbstractBasePtr Value::ToAbstract() {
721 MS_LOG(INTERNAL_EXCEPTION) << "ToAbstract error : The class " << type_name() << "has no implement ToAbstract yet.";
722 }
723
// Copy assignment: takes over the type of `other`; self-assignment is a no-op.
Value &Value::operator=(const Value &other) {
  if (this != &other) {
    type_ = other.type_;
  }
  return *this;
}
731
ContainsValueAny() const732 bool Value::ContainsValueAny() const { return false; }
733
ValueNode(const ValuePtr & value)734 ValueNode::ValueNode(const ValuePtr &value) : value_(value) {
735 if (value->ContainsValueAny()) {
736 MS_LOG(EXCEPTION) << "Value of value node cannot be ValueAny. Value: " << value->ToString();
737 }
738 #ifdef DEBUG
739 // Check if the func graph is release too early.
740 const auto &fg = dyn_cast<FuncGraph>(value);
741 if (fg != nullptr && fg->get_return() == nullptr) {
742 MS_LOG(WARNING) << fg->ToString() << "(ptr: " << fg << ") maybe released early, or not set return yet."
743 }
744 #endif
745 Init();
746 }
747
ValueNode(const ValuePtr & value,NodeDebugInfoPtr && debug_info)748 ValueNode::ValueNode(const ValuePtr &value, NodeDebugInfoPtr &&debug_info)
749 : ANode(nullptr, std::move(debug_info)), value_(value) {
750 MS_EXCEPTION_IF_NULL(value);
751 if (value->ContainsValueAny()) {
752 MS_LOG(EXCEPTION) << "Value of value node cannot be ValueAny. Value: " << value->ToString();
753 }
754 #ifdef DEBUG
755 // Check if the func graph is release too early.
756 const auto &fg = dyn_cast<FuncGraph>(value);
757 if (fg != nullptr && fg->get_return() == nullptr) {
758 MS_LOG(WARNING) << fg->ToString() << "(ptr: " << fg << ") maybe released early, or not set return yet."
759 }
760 #endif
761 Init();
762 }
763
Init()764 void ValueNode::Init() { debug_info()->set_type_name(type_name()); }
765
set_debug_info(const NodeDebugInfoPtr & debug_info)766 void ValueNode::set_debug_info(const NodeDebugInfoPtr &debug_info) {
767 MS_EXCEPTION_IF_NULL(debug_info);
768 debug_info->set_type_name(type_name());
769 debug_info->set_debug_name(std::string()); // Clear cached debug name.
770 debug_info_ = debug_info;
771 }
772
set_func_graph(const FuncGraphPtr &)773 void ValueNode::set_func_graph(const FuncGraphPtr &) {
774 MS_INTERNAL_EXCEPTION(ValueError) << "ValueNode should not set its func_graph.";
775 }
776
set_value(const ValuePtr & value)777 void ValueNode::set_value(const ValuePtr &value) {
778 MS_EXCEPTION_IF_NULL(value);
779 if (value->ContainsValueAny()) {
780 MS_LOG(INTERNAL_EXCEPTION) << "Value of value node cannot be ValueAny. Value: " << value->ToString();
781 }
782 value_ = value;
783 }
784
value() const785 const ValuePtr &ValueNode::value() const { return value_; }
786
set_has_new_value(bool flag)787 void ValueNode::set_has_new_value(bool flag) { has_new_value_ = flag; }
788
has_new_value() const789 bool ValueNode::has_new_value() const { return has_new_value_; }
790
used_graph_count() const791 size_t ValueNode::used_graph_count() const { return used_graph_count_; }
792
set_fracz_group(int64_t group)793 void ValueNode::set_fracz_group(int64_t group) { format_attr_.fracz_group = group; }
794
fracz_group() const795 int64_t ValueNode::fracz_group() const { return format_attr_.fracz_group; }
796
set_used_graph_count(size_t used_graph_count)797 void ValueNode::set_used_graph_count(size_t used_graph_count) { used_graph_count_ = used_graph_count; }
798
DebugString(bool recursive) const799 std::string ValueNode::DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); }
800
operator ==(const AnfNode & other) const801 bool ValueNode::operator==(const AnfNode &other) const {
802 if (!other.isa<ValueNode>()) {
803 return false;
804 }
805 auto &v = static_cast<const ValueNode &>(other);
806 return *v.value() == *value();
807 }
808
ToString() const809 std::string ValueNode::ToString() const {
810 MS_EXCEPTION_IF_NULL(value_);
811 if (value_->isa<FuncGraph>()) {
812 return value_->ToString();
813 }
814 std::ostringstream buffer;
815 buffer << AnfNode::ToString();
816 buffer << "(" << value_->ToString() << ")";
817 return buffer.str();
818 }
819
DebugString(int) const820 std::string ValueNode::DebugString(int) const {
821 MS_EXCEPTION_IF_NULL(value_);
822 std::ostringstream buffer;
823 buffer << "ValueNode<" << value_->type_name() << "> " << value_->ToString();
824 return buffer.str();
825 }
826
fullname_with_scope()827 std::string ValueNode::fullname_with_scope() {
828 if (!fullname_with_scope_.empty()) {
829 return fullname_with_scope_;
830 }
831
832 MS_EXCEPTION_IF_NULL(scope());
833 fullname_with_scope_ = scope()->name() + "/" + "data-";
834 fullname_with_scope_ += id_generator::get_id(fullname_with_scope_);
835 return fullname_with_scope_;
836 }
837
IsPrimitiveCNode(const AnfNodePtr & node,const PrimitivePtr & value)838 bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
839 auto cnode = dyn_cast_ptr<CNode>(node);
840 if (cnode == nullptr || cnode->size() == 0) {
841 return false;
842 }
843 const auto &input = cnode->input(0);
844 MS_EXCEPTION_IF_NULL(input);
845 auto prim = GetValuePtr<Primitive>(input);
846 if (prim == nullptr) {
847 return false;
848 }
849 return (value == nullptr) || ((prim->Hash() == value->Hash()) && (prim->name() == value->name()));
850 }
851
IsPrimitiveCNodeWithoutDoSignature(const AnfNodePtr & node,const PrimitivePtr & value)852 bool IsPrimitiveCNodeWithoutDoSignature(const AnfNodePtr &node, const PrimitivePtr &value) {
853 auto prim = GetCNodePrimitiveWithoutDoSignature(node);
854 if (prim == nullptr) {
855 return false;
856 }
857 return (value == nullptr) || ((prim->Hash() == value->Hash()) && (prim->name() == value->name()));
858 }
859
GetCNodePrimitive(const AnfNodePtr & node)860 PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
861 auto cnode = dyn_cast_ptr<CNode>(node);
862 if (cnode == nullptr || cnode->size() == 0) {
863 return nullptr;
864 }
865 return GetValueNode<PrimitivePtr>(cnode->input(0));
866 }
867
868 // Return the function Primitive if DoSignaturePrimitive,
869 // otherwise return the Primitive directly.
GetPrimitiveWithoutDoSignature(const AnfNodePtr & node)870 PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node) {
871 auto do_signature_prim = GetValuePtr<prim::DoSignaturePrimitive>(node);
872 if (do_signature_prim != nullptr) {
873 return dyn_cast<Primitive>(do_signature_prim->function());
874 }
875 return GetValueNode<PrimitivePtr>(node);
876 }
877
878 // Check the first input of CNode.
879 // Return the function Primitive if DoSignaturePrimitive,
880 // otherwise return the Primitive directly.
GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr & node)881 PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node) {
882 auto cnode = dyn_cast_ptr<CNode>(node);
883 if (cnode == nullptr || cnode->size() == 0) {
884 return nullptr;
885 }
886 return GetPrimitiveWithoutDoSignature(cnode->input(0));
887 }
888
889 // Return the function value if DoSignaturePrimitive,
890 // otherwise return the value directly.
GetValueWithoutDoSignature(const ValuePtr & value)891 ValuePtr GetValueWithoutDoSignature(const ValuePtr &value) {
892 auto do_signature_prim = dyn_cast_ptr<prim::DoSignaturePrimitive>(value);
893 if (do_signature_prim != nullptr) {
894 return do_signature_prim->function();
895 }
896 return value;
897 }
898
899 // Return the function value if DoSignaturePrimitive,
900 // otherwise return the value directly.
GetValueWithoutDoSignature(const AnfNodePtr & node)901 ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node) {
902 auto value = GetValueNode(node);
903 if (value == nullptr) {
904 return nullptr;
905 }
906 return GetValueWithoutDoSignature(value);
907 }
908
909 // Check the first input of CNode.
910 // Return the function value if DoSignaturePrimitive,
911 // otherwise return the value directly.
GetCNodeValueWithoutDoSignature(const AnfNodePtr & node)912 ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node) {
913 auto cnode = dyn_cast_ptr<CNode>(node);
914 if (cnode == nullptr || cnode->size() == 0) {
915 return nullptr;
916 }
917 return GetValueWithoutDoSignature(cnode->input(0));
918 }
919
GetCNodeFuncName(const CNodePtr & cnode)920 std::string GetCNodeFuncName(const CNodePtr &cnode) {
921 MS_EXCEPTION_IF_NULL(cnode);
922 if (cnode->inputs().empty()) {
923 return "";
924 }
925
926 AnfNodePtr valuenode = cnode->input(0);
927 auto value = GetValuePtr(valuenode);
928 if (value == nullptr) {
929 return "";
930 }
931 auto prim = value->cast_ptr<Primitive>();
932 if (prim != nullptr) {
933 return prim->name();
934 }
935 return value->ToString();
936 }
937
GetCNodeFuncGraph(const AnfNodePtr & node)938 FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node) {
939 auto cnode = dyn_cast_ptr<CNode>(node);
940 if (cnode != nullptr && cnode->size() > 0) {
941 return GetValueNode<FuncGraphPtr>(cnode->input(0));
942 }
943 return nullptr;
944 }
945
IsPrimitive(const AnfNodePtr & node,const PrimitivePtr & value)946 bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
947 if (IsValueNode<Primitive>(node)) {
948 auto prim = GetValuePtr<Primitive>(node);
949 MS_EXCEPTION_IF_NULL(value);
950 if (prim->Hash() == value->Hash() && prim->name() == value->name()) {
951 return true;
952 }
953 }
954 return false;
955 }
956
IsPrimitiveEquals(const PrimitivePtr & prim1,const PrimitivePtr & prim2)957 bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2) {
958 if (prim1 == nullptr || prim2 == nullptr) {
959 return false;
960 }
961 return (prim1 == prim2) || (prim1->Hash() == prim2->Hash() && prim1->name() == prim2->name());
962 }
963
GetAbstractMonadNum(const AbstractBasePtrList & args)964 size_t GetAbstractMonadNum(const AbstractBasePtrList &args) {
965 size_t num = 0;
966 for (auto &arg : args) {
967 if (arg->isa<abstract::AbstractMonad>()) {
968 ++num;
969 }
970 }
971 return num;
972 }
973
974 template <typename T>
HasAbstract(const AnfNodePtr & node)975 bool HasAbstract(const AnfNodePtr &node) {
976 if (node == nullptr) {
977 return false;
978 }
979 const auto &abs = node->abstract();
980 return (abs != nullptr && abs->isa<T>());
981 }
982
HasAbstractMonad(const AnfNodePtr & node)983 bool HasAbstractMonad(const AnfNodePtr &node) { return HasAbstract<abstract::AbstractMonad>(node); }
984
HasAbstractUMonad(const AnfNodePtr & node)985 bool HasAbstractUMonad(const AnfNodePtr &node) { return HasAbstract<abstract::AbstractUMonad>(node); }
986
HasAbstractIOMonad(const AnfNodePtr & node)987 bool HasAbstractIOMonad(const AnfNodePtr &node) { return HasAbstract<abstract::AbstractIOMonad>(node); }
988
GetPrimitiveFlag(const PrimitivePtr & prim,const std::string & attr)989 bool GetPrimitiveFlag(const PrimitivePtr &prim, const std::string &attr) {
990 if (prim != nullptr) {
991 auto flag = prim->GetAttr(attr);
992 if (flag && flag->isa<BoolImm>()) {
993 return GetValue<bool>(flag);
994 }
995 }
996 return false;
997 }
998
GetPrimEffectInfo(const PrimitivePtr & prim)999 EffectInfo GetPrimEffectInfo(const PrimitivePtr &prim) {
1000 bool mem = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_MEM);
1001 bool io = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_IO);
1002 bool back_mem = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP_MEM);
1003 return {EffectInfo::kDetected, mem, io, false, back_mem};
1004 }
1005
GetLoadInputs(const AnfNodePtr & node)1006 std::set<CNodePtr> GetLoadInputs(const AnfNodePtr &node) {
1007 std::set<CNodePtr> loads;
1008 auto cnode = dyn_cast_ptr<CNode>(node);
1009 if (cnode == nullptr) {
1010 return loads;
1011 }
1012 auto &inputs = cnode->weak_inputs();
1013 for (size_t i = 1; i < inputs.size(); ++i) {
1014 const auto &input = inputs.at(i).lock();
1015 if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
1016 loads.insert(input->cast<CNodePtr>());
1017 } else if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
1018 loads.merge(GetLoadInputs(input));
1019 }
1020 }
1021 return loads;
1022 }
1023
IsStateEquivalent(const AnfNodePtr & outer,const AnfNodePtr & inner)1024 bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) {
1025 constexpr size_t kMonadInput = 2;
1026 auto outer_loads = GetLoadInputs(outer);
1027 if (outer_loads.empty()) {
1028 return true;
1029 }
1030 auto inner_loads = GetLoadInputs(inner);
1031 if (inner_loads.empty()) {
1032 return true;
1033 }
1034 outer_loads.merge(inner_loads);
1035 const auto &monad = (*outer_loads.begin())->weak_inputs().at(kMonadInput).lock();
1036 return std::all_of(++outer_loads.begin(), outer_loads.end(), [&monad, kMonadInput](const CNodePtr &load) {
1037 return load->weak_inputs().at(kMonadInput).lock() == monad;
1038 });
1039 }
1040
1041 // Check if the node is DeadNode.
IsDeadNode(const AnfNodePtr & node)1042 bool IsDeadNode(const AnfNodePtr &node) {
1043 auto value = GetValuePtr<ValueProblem>(node);
1044 return (value != nullptr) && (value->IsDead());
1045 }
1046
1047 // Check if the node is PolyNode.
IsPolyNode(const AnfNodePtr & node)1048 bool IsPolyNode(const AnfNodePtr &node) {
1049 auto value = GetValuePtr<ValueProblem>(node);
1050 return (value != nullptr) && (value->IsPoly());
1051 }
1052
NewSeenGeneration()1053 SeenNum NewSeenGeneration() {
1054 static SeenNum seen_generation = 0;
1055 ++seen_generation;
1056 // 0 is invalid number.
1057 if (seen_generation == 0) {
1058 ++seen_generation;
1059 }
1060 return seen_generation;
1061 }
1062
1063 namespace id_generator {
1064 static mindspore::HashMap<std::string, int> node_ids;
1065 static int offset = 0;
get_id(const string & scope_front_info)1066 std::string get_id(const string &scope_front_info) {
1067 if (node_ids.find(scope_front_info) == node_ids.end()) {
1068 node_ids[scope_front_info] = 0;
1069 } else {
1070 node_ids[scope_front_info]++;
1071 }
1072 std::string base_id = std::to_string(node_ids[scope_front_info]);
1073 // The id with offset means the user called reset_id_with_offset() and expect the operated id generated from 0 with an
1074 // identified offset.
1075 if (offset != 0) {
1076 return base_id + '_' + std::to_string(offset);
1077 }
1078 return base_id;
1079 }
1080
reset_id()1081 void reset_id() { node_ids.clear(); }
1082
reset_id_with_offset()1083 void reset_id_with_offset() {
1084 node_ids.clear();
1085 offset++;
1086 }
1087 } // namespace id_generator
// Key used in this file both to read a primitive's target attribute and to cache a node's target as user data.
auto constexpr kPrimitiveTarget = "primitive_target";
1089 namespace {
GetPrimitiveFromValueNode(const AnfNodePtr & node)1090 PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
1091 auto value_node = dyn_cast_ptr<ValueNode>(node);
1092 if (value_node == nullptr) {
1093 return nullptr;
1094 }
1095 return dyn_cast<Primitive>(value_node->value());
1096 }
1097
GetNodeTargetForVarInputNode(const CNodePtr & cnode)1098 static std::string GetNodeTargetForVarInputNode(const CNodePtr &cnode) {
1099 auto &inputs = cnode->inputs();
1100 AnfNodeWeakPtrList real_inputs;
1101 const size_t update_state_valid_input_index = 2;
1102 const size_t make_tuple_valid_input_index = 1;
1103 if (cnode->IsApply(prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) {
1104 (void)std::copy(inputs.begin() + SizeToLong(update_state_valid_input_index), inputs.end(),
1105 std::back_inserter(real_inputs));
1106 } else if (cnode->IsApply(prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) {
1107 (void)std::copy(inputs.begin() + SizeToLong(make_tuple_valid_input_index), inputs.end(),
1108 std::back_inserter(real_inputs));
1109 }
1110 std::string first_input_target = kDeviceUnDefined;
1111 bool has_diff_target =
1112 std::any_of(std::rbegin(real_inputs), std::rend(real_inputs), [&first_input_target](const AnfNodeWeakPtr &n) {
1113 auto target = GetOriginNodeTarget(n.lock());
1114 if (target == kDeviceUnDefined) {
1115 return false;
1116 }
1117 if (first_input_target == kDeviceUnDefined) {
1118 first_input_target = target;
1119 }
1120 return target != first_input_target;
1121 });
1122 if (!has_diff_target) {
1123 return first_input_target;
1124 }
1125 return kDeviceUnDefined;
1126 }
1127
IsSummaryPrimitiveCNode(const AnfNodePtr & node)1128 static inline bool IsSummaryPrimitiveCNode(const AnfNodePtr &node) {
1129 return IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
1130 IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary);
1131 }
1132
GetVirtualNodeTargetFromInputs(const AnfNodePtr & node)1133 std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
1134 MS_EXCEPTION_IF_NULL(node);
1135 auto cnode = node->cast_ptr<CNode>();
1136 MS_EXCEPTION_IF_NULL(cnode);
1137 auto &weak_inputs = cnode->weak_inputs();
1138 #ifndef ENABLE_SECURITY
1139 if (IsSummaryPrimitiveCNode(node)) {
1140 if (weak_inputs.size() > 1) {
1141 return GetOriginNodeTarget(weak_inputs[1].lock());
1142 }
1143 return kDeviceUnDefined;
1144 }
1145 #endif
1146 if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
1147 const size_t node_inputs_num = 3;
1148 if (weak_inputs.size() >= node_inputs_num) {
1149 size_t use_index = 1;
1150 auto use_node = weak_inputs[use_index].lock();
1151 MS_EXCEPTION_IF_NULL(use_node);
1152 if (!use_node->isa<CNode>()) {
1153 use_index = 2;
1154 use_node = weak_inputs[use_index].lock();
1155 MS_EXCEPTION_IF_NULL(use_node);
1156 }
1157 return GetOriginNodeTarget(use_node);
1158 }
1159 } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
1160 return GetNodeTargetForVarInputNode(node->cast<CNodePtr>());
1161 } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1162 return GetOriginNodeTarget(cnode->input(1));
1163 }
1164 return kDeviceUnDefined;
1165 }
1166
GetVirtualNodeTargetFromUsers(const AnfNodePtr & node)1167 std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) {
1168 MS_EXCEPTION_IF_NULL(node);
1169 auto cnode = node->cast<CNodePtr>();
1170 MS_EXCEPTION_IF_NULL(cnode);
1171 auto func_graph = cnode->func_graph();
1172 if (func_graph == nullptr) {
1173 return kDeviceUnDefined;
1174 }
1175 auto manager = func_graph->manager();
1176 if (manager == nullptr) {
1177 return kDeviceUnDefined;
1178 }
1179 auto users = manager->node_users()[cnode];
1180 std::string first_user_target = kDeviceUnDefined;
1181 bool has_diff_target =
1182 std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
1183 auto target = GetOriginNodeTarget(u.first);
1184 if (target == kDeviceUnDefined) {
1185 return false;
1186 }
1187 if (first_user_target == kDeviceUnDefined) {
1188 first_user_target = target;
1189 }
1190 return target != first_user_target;
1191 });
1192 if (!has_diff_target) {
1193 return first_user_target;
1194 }
1195 return kDeviceUnDefined;
1196 }
1197
GetVirtualNodeTarget(const AnfNodePtr & node)1198 std::string GetVirtualNodeTarget(const AnfNodePtr &node) {
1199 MS_EXCEPTION_IF_NULL(node);
1200 node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(kDeviceUnDefined));
1201 auto target = GetVirtualNodeTargetFromInputs(node);
1202 node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
1203 if (target != kDeviceUnDefined) {
1204 return target;
1205 }
1206 target = GetVirtualNodeTargetFromUsers(node);
1207 node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
1208 return target;
1209 }
1210
GetTargetFromAttr(const AnfNodePtr & node)1211 std::string GetTargetFromAttr(const AnfNodePtr &node) {
1212 MS_EXCEPTION_IF_NULL(node);
1213 auto cnode = node->cast_ptr<CNode>();
1214 MS_EXCEPTION_IF_NULL(cnode);
1215 auto attr_input = cnode->input(0);
1216 auto primitive = GetPrimitiveFromValueNode(attr_input);
1217 if (primitive == nullptr) {
1218 return kDeviceUnDefined;
1219 }
1220 auto att_target = primitive->GetAttr(kPrimitiveTarget);
1221 if (att_target != nullptr) {
1222 if (!att_target->isa<StringImm>()) {
1223 MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
1224 }
1225 auto target = GetValue<std::string>(att_target);
1226 if (kTargetSet.find(target) == kTargetSet.end()) {
1227 MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
1228 }
1229 return target;
1230 }
1231 return kDeviceUnDefined;
1232 }
1233 } // namespace
1234
GetOriginNodeTarget(const AnfNodePtr & node)1235 std::string GetOriginNodeTarget(const AnfNodePtr &node) {
1236 MS_EXCEPTION_IF_NULL(node);
1237 if (!node->isa<CNode>()) {
1238 return kDeviceUnDefined;
1239 }
1240 auto cnode = node->cast_ptr<CNode>();
1241 MS_EXCEPTION_IF_NULL(cnode);
1242 auto ud_target = cnode->user_data<std::string>(kPrimitiveTarget);
1243 if (ud_target != nullptr) {
1244 return *ud_target.get();
1245 }
1246 auto target = GetTargetFromAttr(node);
1247 if (target != kDeviceUnDefined) {
1248 return target;
1249 }
1250 #ifndef ENABLE_SECURITY
1251 if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
1252 IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
1253 IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
1254 IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
1255 IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1256 return GetVirtualNodeTarget(node);
1257 }
1258 #else
1259 if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
1260 IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
1261 IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
1262 return GetVirtualNodeTarget(node);
1263 }
1264 #endif
1265 auto context_ptr = MsContext::GetInstance();
1266 MS_EXCEPTION_IF_NULL(context_ptr);
1267 return context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1268 }
1269
GetCNodeTarget(const AnfNodePtr & node)1270 std::string GetCNodeTarget(const AnfNodePtr &node) {
1271 auto kernel_info = node->kernel_info();
1272 if (kernel_info != nullptr) {
1273 auto runtime_cache = kernel_info->runtime_cache();
1274 if (runtime_cache.runtime_cache().is_valid()) {
1275 auto tmp_target = runtime_cache.runtime_cache().device_target();
1276 if (!tmp_target.empty()) {
1277 return tmp_target;
1278 }
1279 }
1280 }
1281
1282 std::string target;
1283 auto ori_target = GetOriginNodeTarget(node);
1284 if (ori_target != kDeviceUnDefined) {
1285 target = ori_target;
1286 } else {
1287 auto context_ptr = MsContext::GetInstance();
1288 MS_EXCEPTION_IF_NULL(context_ptr);
1289 target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1290 }
1291
1292 if (kernel_info != nullptr) {
1293 auto runtime_cache = kernel_info->runtime_cache();
1294 if (runtime_cache.runtime_cache().is_valid()) {
1295 runtime_cache.runtime_cache().set_device_target(target);
1296 }
1297 }
1298 return target;
1299 }
1300
ContainMultiTarget(const AnfNodePtrList & nodes)1301 bool ContainMultiTarget(const AnfNodePtrList &nodes) {
1302 auto context_ptr = MsContext::GetInstance();
1303 MS_EXCEPTION_IF_NULL(context_ptr);
1304 std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
1305 for (auto &node : nodes) {
1306 if (node->isa<CNode>()) {
1307 std::string cur_target = GetCNodeTarget(node);
1308 if (last_target != cur_target) {
1309 return true;
1310 }
1311 last_target = cur_target;
1312 }
1313 }
1314 return false;
1315 }
1316
IsOneOfPrimitive(const AnfNodePtr & node,const PrimitiveSet & prim_set)1317 bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
1318 PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
1319 return (prim && prim_set.find(prim) != prim_set.end());
1320 }
1321
IsOneOfPrimitiveCNode(const AnfNodePtr & node,const PrimitiveSet & prim_set)1322 bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) {
1323 MS_EXCEPTION_IF_NULL(node);
1324 auto cnode = node->cast_ptr<CNode>();
1325 if (cnode == nullptr || cnode->size() == 0) {
1326 return false;
1327 }
1328 return IsOneOfPrimitive(cnode->input(0), prim_set);
1329 }
1330
1331 // Set the sequence nodes' elements use flags to 'new_flag' at specific 'index' position.
SetSequenceElementsUseFlags(const AbstractBasePtr & abs,std::size_t index,bool new_flag)1332 void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, std::size_t index, bool new_flag) {
1333 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1334 if (!enable_eliminate_unused_element) {
1335 return;
1336 }
1337
1338 auto sequence_abs = dyn_cast_ptr<abstract::AbstractSequence>(abs);
1339 if (sequence_abs == nullptr) {
1340 return;
1341 }
1342 if (sequence_abs->sequence_nodes() == nullptr || sequence_abs->sequence_nodes()->empty()) {
1343 return;
1344 }
1345 for (auto &node : *sequence_abs->sequence_nodes()) {
1346 auto sequence_node = node.lock();
1347 if (sequence_node == nullptr) {
1348 MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
1349 continue;
1350 }
1351 auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
1352 if (flags == nullptr) {
1353 continue;
1354 }
1355 if (index >= flags->size()) {
1356 MS_LOG(ERROR) << "The index " << index << " is out of range, size is " << flags->size() << ", for "
1357 << sequence_node->DebugString();
1358 return;
1359 }
1360 (*flags)[index] = new_flag;
1361 MS_LOG(DEBUG) << "Set item[" << index << "] use flag as " << new_flag << ", for " << sequence_node->DebugString();
1362 }
1363 }
1364
1365 // Set the sequence nodes' elements use flags all to 'new_flag'.
SetSequenceElementsUseFlags(const AbstractBasePtr & abs,bool new_flag)1366 void SetSequenceElementsUseFlags(const AbstractBasePtr &abs, bool new_flag) {
1367 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1368 if (!enable_eliminate_unused_element) {
1369 return;
1370 }
1371
1372 auto sequence_abs = dyn_cast_ptr<abstract::AbstractSequence>(abs);
1373 if (sequence_abs == nullptr) {
1374 return;
1375 }
1376 if (sequence_abs->sequence_nodes() == nullptr || sequence_abs->sequence_nodes()->empty()) {
1377 return;
1378 }
1379 for (auto &weak_node : *sequence_abs->sequence_nodes()) {
1380 auto sequence_node = weak_node.lock();
1381 if (sequence_node == nullptr) {
1382 MS_LOG(DEBUG) << "The node in sequence_nodes is free.";
1383 continue;
1384 }
1385 auto flags = GetSequenceNodeElementsUseFlags(sequence_node);
1386 if (flags != nullptr) {
1387 auto &all_flags = (*flags);
1388 (void)std::transform(all_flags.begin(), all_flags.end(), all_flags.begin(),
1389 [&new_flag](bool) -> bool { return new_flag; });
1390 }
1391 }
1392 }
1393
1394 // Set the sequence nodes' elements use flags all to 'new_flag' recursively.
SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr & abs,bool new_flag)1395 void SetSequenceElementsUseFlagsRecursively(const AbstractBasePtr &abs, bool new_flag) {
1396 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1397 if (!enable_eliminate_unused_element) {
1398 return;
1399 }
1400
1401 SetSequenceElementsUseFlags(abs, new_flag);
1402
1403 // Check its elements if it's a sequence node.
1404 auto sequence_abs = dyn_cast_ptr<abstract::AbstractSequence>(abs);
1405 if (sequence_abs != nullptr) {
1406 for (auto &element : sequence_abs->elements()) {
1407 SetSequenceElementsUseFlagsRecursively(element, new_flag);
1408 }
1409 return;
1410 }
1411
1412 // Check its elements if it's a dictionary node.
1413 auto dictionary_abs = dyn_cast_ptr<abstract::AbstractDictionary>(abs);
1414 if (dictionary_abs != nullptr) {
1415 for (auto &element : dictionary_abs->elements()) {
1416 SetSequenceElementsUseFlagsRecursively(element.second, new_flag);
1417 }
1418 }
1419 }
1420 } // namespace mindspore
1421