/**
 * Copyright 2021-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "utils/anf_utils.h"
#include <memory>
#include <string>
#include <list>
#include <vector>
#include <utility>
#include <algorithm>
#include "ops/structure_ops.h"
#include "ops/sequence_ops.h"
#include "ops/other_ops.h"
#include "ops/framework_ops.h"
#include "utils/trace_base.h"
#include "utils/hash_map.h"
#include "utils/os.h"
#include "include/common/utils/utils.h"
#include "utils/ms_context.h"

namespace mindspore {
namespace {
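// Manages one recursive mutex per AnfNode so that concurrent accesses to a node's abstract can be serialized.
// When the manager is closed (is_valid_ == false), GetAbstractLock returns nullptr and callers skip locking.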
class AbstractMutexManager {
 public:
  static AbstractMutexManager &GetInstance() {
    static AbstractMutexManager instance;
    return instance;
  }

  std::recursive_mutex *GetAbstractLock(const AnfNode *node) {
    std::lock_guard<std::recursive_mutex> lock(mu_);
    if (is_valid_) {
      return &mu_for_nodes_[node];
    } else {
      return nullptr;
    }
  }

  void Close() {
    // cppcheck-suppress unreadVariable
    std::lock_guard<std::recursive_mutex> lock(mu_);
    is_valid_ = false;
    mu_for_nodes_.clear();
  }

  void Open() {
    // cppcheck-suppress unreadVariable
    std::lock_guard<std::recursive_mutex> lock(mu_);
    is_valid_ = true;
  }

 private:
  mindspore::HashMap<const AnfNode *, std::recursive_mutex> mu_for_nodes_;
  std::recursive_mutex mu_;
  bool is_valid_ = false;
};

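// User data attached to a custom actor node: the callback to run, the actor type name (e.g. kInit or kInfer)
// and a weak reference to the cnode the actor was created for.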
struct CustomActorInfo {
  CustomActorInfo(const AnfUtils::CustomActorCallback &func, const std::string &type_name, const CNodePtr &cnode)
      : actor_func(func), type_name(type_name), base_cnode_ptr(cnode) {}
  ~CustomActorInfo() = default;

  // Key for user data.
  constexpr static char key[] = "CustomActor";
  AnfUtils::CustomActorCallback actor_func = {};
  std::string type_name;
  CNodeWeakPtr base_cnode_ptr;
};
using CustomActorInfoPtr = std::shared_ptr<CustomActorInfo>;

struct CNodeCustomInfo {
  CNodeCustomInfo(const AnfNodePtr &inferop, const AnfNodePtr &initop) : infer_node(inferop), init_node(initop) {}
  ~CNodeCustomInfo() = default;
  // Key for user data.
  constexpr static char key[] = "CustomNodeInfo";
  AnfNodeWeakPtr infer_node;
  AnfNodeWeakPtr init_node;
};
using CNodeCustomInfoPtr = std::shared_ptr<CNodeCustomInfo>;

struct RealInputInfo {
  explicit RealInputInfo(const CNodePtr &cnode) : base_cnode_ptr(cnode), real_input_nodes() {}
  ~RealInputInfo() = default;
  // Key for user data.
  constexpr static char key[] = "RealInputInfo";
  CNodeWeakPtr base_cnode_ptr;
  // HashMap<input_index, pair<pre_node, pre_node_output_index>> records the real input nodes used to infer
  // dynamic shape information for nodes located at the boundary of a graph partition, e.g. in heterogeneous
  // scenarios.
  mindspore::HashMap<size_t, std::pair<AnfNodeWeakPtr, size_t>> real_input_nodes;
};

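// Creates a free-standing AnfNode that belongs to graph g and carries the given CustomActorInfo as user data.
// Such nodes have no inputs; they only act as anchors for custom actors.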
AnfNodePtr NewCustomActorNode(const CustomActorInfoPtr &actor_info, const FuncGraphPtr &g) {
  MS_EXCEPTION_IF_NULL(g);
  auto custom_actor_node = std::make_shared<AnfNode>(g);
  custom_actor_node->set_user_data<CustomActorInfo>(actor_info);
  return custom_actor_node;
}
}  // namespace

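// RAII guard around the per-node recursive mutex returned by AbstractMutexManager. A null mutex (locking
// disabled) is tolerated; move construction/assignment transfers ownership of the held lock.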
AbstractScope::AbstractScope(std::recursive_mutex *mu) : mu_(mu) {
  if (mu_ != nullptr) {
    mu_->lock();
  }
}

AbstractScope::AbstractScope(AbstractScope &&other) {
  mu_ = other.mu_;
  other.mu_ = nullptr;
}

AbstractScope &AbstractScope::operator=(AbstractScope &&other) {
  mu_ = other.mu_;
  other.mu_ = nullptr;
  return *this;
}

AbstractScope::~AbstractScope() {
  if (mu_ != nullptr) {
    mu_->unlock();
  }
}

AbstractScope AnfUtils::GetAbstractLock(const AnfNode *node) {
  return AbstractScope(AbstractMutexManager::GetInstance().GetAbstractLock(node));
}

void AnfUtils::OpenAbstractLock() { AbstractMutexManager::GetInstance().Open(); }

void AnfUtils::CloseAbstractLock() { AbstractMutexManager::GetInstance().Close(); }

// If the node's shape is dynamic shape or dynamic rank, return true.
bool AnfUtils::IsNodeOutputShapeDynamic(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto base_shape = node->Shape();
  if (base_shape == nullptr) {
    MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
    return false;
  }
  return base_shape->IsDynamic();
}

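// A node is a "real kernel" if it is not one of the virtual/bookkeeping primitives (MakeTuple, TupleGetItem,
// Depend, UpdateState, ...). Parameters and value nodes always count as real kernels. Under non-GE backends,
// summary operators are also treated as virtual. The result is cached in the node's runtime cache.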
bool AnfUtils::IsRealKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
#ifndef ENABLE_SECURITY
  static const PrimitiveSet virtual_prims = {
    prim::kPrimMakeTuple,   prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
    prim::kPrimReturn,      prim::kPrimPartial,      prim::kPrimDepend,
    prim::kPrimUpdateState, prim::kPrimLoad,         prim::kPrimDynamicLossScale,
    prim::kPrimMakeList,    prim::kPrimListGetItem,  prim::kPrimIs_,
    prim::kPrimIsNot,       prim::kPrimIsInstance};
#else
  static const PrimitiveSet virtual_prims = {
    prim::kPrimMakeTuple,   prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
    prim::kPrimReturn,      prim::kPrimPartial,      prim::kPrimDepend,
    prim::kPrimUpdateState, prim::kPrimLoad,         prim::kPrimDynamicLossScale};
#endif
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    // Parameters and value nodes are real kernels too.
    return true;
  }
  if (cnode->size() == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "Illegal null input of cnode: " << node->DebugString()
                               << trace::DumpSourceLines(node);
  }

  auto kernel_info = cnode->kernel_info();
  if (kernel_info) {
    auto runtime_cache = kernel_info->runtime_cache();
    if (runtime_cache.runtime_cache().is_real_kernel() != Uncached) {
      return (runtime_cache.runtime_cache().is_real_kernel() == True);
    }
  }

  // In the GE backend, summary is a real operator; its corresponding backend operator is OutfeedEnqueueOpV2.
  static const PrimitiveSet summary_prims = {
    prim::kPrimImageSummary,
    prim::kPrimScalarSummary,
    prim::kPrimTensorSummary,
    prim::kPrimHistogramSummary,
  };

  bool res = !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims);
  static std::string backend = MsContext::GetInstance()->backend_policy();
  if (backend != "ge") {
    res = res && !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), summary_prims);
  }

  if (kernel_info) {
    auto runtime_cache = kernel_info->runtime_cache();
    if (res) {
      runtime_cache.runtime_cache().set_real_kernel(True);
    } else {
      runtime_cache.runtime_cache().set_real_kernel(False);
    }
  }

  return res;
}

bool AnfUtils::IsRealCNodeKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
    return true;
  }
  return AnfUtils::IsRealKernel(node);
}

std::string AnfUtils::GetCNodeName(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    auto primitive = GetCNodePrimitive(node);
    if (primitive != nullptr) {
      if (primitive->name() == "Custom") {
        auto uniq_name = primitive->GetAttr("uniq_name");
        if (uniq_name) {
          return GetValue<std::string>(uniq_name);
        }
      }
      return primitive->name();
    }

    // For call nodes, input(0) must be a value node holding a FuncGraph; otherwise return an empty name.
    auto cnode = dyn_cast<CNode>(node);
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->size() == 0 || !IsValueNode<FuncGraph>(cnode->input(0))) {
      return "";
    }

    auto func_graph = GetCNodeFuncGraph(node);
    MS_EXCEPTION_IF_NULL(func_graph);
    if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
      std::string fg_name = "GraphKernel_";
      fg_name += GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
      return fg_name;
    }
    return func_graph->ToString();
  }
  MS_LOG(INTERNAL_EXCEPTION) << "Unknown anf node type " << node->DebugString() << trace::DumpSourceLines(node);
}

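// Returns the number of real (tensor) inputs of a cnode: input(0) (the primitive) is excluded, and for real
// kernels trailing monad inputs are excluded as well. The result is cached on the cnode.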
size_t AnfUtils::GetInputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString()
                               << trace::DumpSourceLines(node);
  }
  {
    // cppcheck-suppress unreadVariable
    auto lock = AnfUtils::GetAbstractLock(cnode.get());
    ssize_t input_tensor_num = cnode->input_tensor_num();
    if (input_tensor_num >= 0) {
      return static_cast<size_t>(input_tensor_num);
    }
  }

  size_t input_num = cnode->size();
  if (input_num == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "Cnode inputs size can't be zero" << trace::DumpSourceLines(node);
  }
  // Exclude inputs[0].
  --input_num;

  // Exclude monad inputs for real cnodes.
  if (input_num > 0 && AnfUtils::IsRealKernel(cnode)) {
    auto &inputs = cnode->inputs();
    // Search monad inputs, backward.
    for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
      // cppcheck-suppress unreadVariable
      auto lock = AnfUtils::GetAbstractLock((*iter).get());
      if (!HasAbstractMonad(*iter)) {
        // Stop counting when a non-monad input is encountered.
        break;
      }
      --input_num;
    }
  }
  // cppcheck-suppress unreadVariable
  auto lock = AnfUtils::GetAbstractLock(cnode.get());
  cnode->set_input_tensor_num(static_cast<ssize_t>(input_num));
  return input_num;
}

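// Derives the number of outputs from the node's Type: Tuple/List size (trailing monad outputs are excluded for
// nodes that skip them), 0 for None, fixed counts for CSRTensor (5) and COOTensor (4), and 1 otherwise.
// The result is written back to the runtime cache when the cache is valid.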
size_t AnfUtils::GetOutputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  bool is_valid_cache = false;
  if (kernel_info != nullptr) {
    auto runtime_cache = kernel_info->runtime_cache();
    if (runtime_cache.runtime_cache().is_valid()) {
      ssize_t output_tensor_num = runtime_cache.runtime_cache().output_tensor_num();
      if (output_tensor_num >= 0) {
        return static_cast<size_t>(output_tensor_num);
      }
      is_valid_cache = true;
    }
  }

  size_t res = 1;
  TypePtr type = node->Type();
  if (type == nullptr) {
    res = 0;
  } else if (type->isa<Tuple>()) {
    auto tuple_type = type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    res = tuple_type->size();
    if (res == 0) {
      return res;
    }
    auto last_type = tuple_type->elements()[res - 1];
    MS_EXCEPTION_IF_NULL(last_type);
    // Some nodes, such as RpcRecv, can have monad outputs; skip them when counting.
    if (NeedJumpMonadOutput(node) && last_type->isa<MonadType>()) {
      for (size_t i = 0; i < tuple_type->elements().size(); i++) {
        auto tuple_type_elem = tuple_type->elements()[i];
        MS_EXCEPTION_IF_NULL(tuple_type_elem);
        if (tuple_type_elem->isa<MonadType>()) {
          res = i;
          break;
        }
      }
    }
  } else if (type->isa<List>()) {
    auto list_type = type->cast<ListPtr>();
    MS_EXCEPTION_IF_NULL(list_type);
    res = list_type->size();
  } else if (type->isa<TypeNone>()) {
    res = 0;
  } else if (type->isa<CSRTensorType>()) {
    // Currently, CSRTensor only supports 2-D matrices (shape has 2 values). 5 outputs = 3 tensors + 2 shape values.
    constexpr size_t kCSRTensorOutputNum = 5;
    res = kCSRTensorOutputNum;
  } else if (type->isa<COOTensorType>()) {
    // Currently, COOTensor only supports 2-D matrices (shape has 2 values). 4 outputs = 2 tensors + 2 shape values.
    constexpr size_t kCOOTensorOutputNum = 4;
    res = kCOOTensorOutputNum;
  } else if (NeedJumpMonadOutput(node) && type->isa<MonadType>()) {
    // Some nodes, such as RpcRecv, can have monad outputs; skip them when counting.
    res = 0;
  }

  if (is_valid_cache) {
    kernel_info->runtime_cache().runtime_cache().set_output_tensor_num(static_cast<ssize_t>(res));
  }
  return res;
}

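// Sets an attribute on a cnode: for a single-op cnode the attribute goes on its Primitive, and for a graph
// kernel cnode it goes on the embedded FuncGraph.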
void AnfUtils::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString()
                               << trace::DumpSourceLines(node);
  }
  // single op cnode.
  auto primitive = GetCNodePrimitive(node);
  if (primitive != nullptr) {
    primitive->set_attr(key, value);
    return;
  }
  // graph kernel cnode.
  auto fg = GetCNodeFuncGraph(node);
  MS_EXCEPTION_IF_NULL(fg);
  fg->set_attr(key, value);
}

int64_t AnfUtils::GetIntValue(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  auto value_node = anf_node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  return GetIntValue(value);
}

int64_t AnfUtils::GetIntValue(const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<Int64Imm>()) {
    return GetValue<int64_t>(value);
  } else if (value->isa<Int32Imm>()) {
    return IntToLong(GetValue<int>(value));
  } else {
    MS_LOG(EXCEPTION) << "The value should be Int32Imm or Int64Imm, but got " << value->ToString();
  }
}

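// Resolves the real kernel node and output index behind "transparent" nodes: MakeTuple selects the element at
// the given index, TupleGetItem is followed through its real input with the stored item index, and
// UpdateState/Depend/Load are followed through their real input. Value nodes, parameters and custom actor
// nodes are returned directly.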
std::pair<AnfNodePtr, size_t> AnfUtils::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
  if (anf_node->isa<ValueNode>()) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  } else if (IsCustomActorNode(anf_node)) {
    return std::make_pair(anf_node, 0);
  } else if (anf_node->isa<CNode>()) {
    auto cnode = anf_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    auto input0 = cnode->input(0);
    MS_EXCEPTION_IF_NULL(input0);
    if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
      if (GetInputTensorNum(cnode) == 0) {
        return std::make_pair(nullptr, 0);
      }
      auto node = cnode->input(index + IntToSize(1));
      MS_EXCEPTION_IF_NULL(node);
      return VisitKernel(node, 0);
    } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
      if (cnode->size() != kTupleGetItemInputSize) {
        MS_LOG(INTERNAL_EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
      }
      auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
      auto item_idx = AnfUtils::GetIntValue(input2);
      return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), LongToSize(item_idx));
    } else if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
      return VisitKernel(cnode->input(kUpdateStateRealInput), 0);
    } else if (IsOneOfPrimitive(input0, follow_first_input_prims)) {
      return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
    } else {
      return std::make_pair(anf_node, index);
    }
  } else {
    MS_LOG(INTERNAL_EXCEPTION) << "The input is invalid";
  }
}

bool AnfUtils::IsGraphKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto func_graph = GetCNodeFuncGraph(node);
  return func_graph != nullptr && func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}

bool AnfUtils::IsNodeInGraphKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
}

void AnfUtils::SetDumpFlag(const AnfNodePtr &node) {
  if (node == nullptr || !node->isa<CNode>()) {
    return;
  }
  auto prim = GetCNodePrimitive(node);
  if (prim != nullptr) {
    prim->set_attr(kAttrDump, MakeValue(kValueTrue));
  }
}

bool AnfUtils::GetDumpFlag(const AnfNodePtr &node) {
  if (node == nullptr || !node->isa<CNode>()) {
    return false;
  }
  auto prim = GetCNodePrimitive(node);
  if (prim != nullptr) {
    auto attr = prim->GetAttr(kAttrDump);
    if (attr != nullptr && attr->isa<StringImm>() && attr->cast<StringImmPtr>()->value() == kValueTrue) {
      return true;
    }
  }
  return false;
}

bool AnfUtils::HasDumpFlag(const AnfNodePtr &node) {
  if (node == nullptr || !node->isa<CNode>()) {
    return false;
  }
  auto prim = GetCNodePrimitive(node);
  if (prim != nullptr) {
    return prim->HasAttr(kAttrDump);
  }
  return false;
}

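// Custom actor nodes are plain AnfNodes tagged with CustomActorInfo user data. The helpers below check for
// that tag and expose the stored type name, callback and base cnode.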
bool AnfUtils::IsCustomActorNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return node->has_user_data<CustomActorInfo>();
}

bool AnfUtils::IsCutomActorNodeSame(const AnfNodePtr &node1, const AnfNodePtr &node2) {
  MS_EXCEPTION_IF_NULL(node1);
  MS_EXCEPTION_IF_NULL(node2);
  if (!IsCustomActorNode(node1) || !IsCustomActorNode(node2)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Both nodes must be custom actor nodes!";
  }

  auto actor_info1 = node1->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(actor_info1);
  std::string actor_type1 = actor_info1->type_name;

  auto actor_info2 = node2->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(actor_info2);
  std::string actor_type2 = actor_info2->type_name;

  return (actor_type1 == actor_type2);
}

std::string AnfUtils::GetCustomActorType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsCustomActorNode(node)) {
    MS_LOG(INTERNAL_EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
  }

  auto actor_info = node->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(actor_info);
  return actor_info->type_name;
}

std::string AnfUtils::GetCustomActorName(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsCustomActorNode(node)) {
    MS_LOG(INTERNAL_EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
  }

  auto actor_info = node->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(actor_info);
  auto base_node = actor_info->base_cnode_ptr.lock();
  MS_EXCEPTION_IF_NULL(base_node);
  std::string actor_name = actor_info->type_name + "_of_" + base_node->fullname_with_scope();
  return actor_name;
}

CNodePtr AnfUtils::GetCustomActorBaseNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsCustomActorNode(node)) {
    MS_LOG(INTERNAL_EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
  }

  auto actor_info = node->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(actor_info);
  return actor_info->base_cnode_ptr.lock();
}

AnfUtils::CustomActorCallback AnfUtils::GetCustomFunc(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsCustomActorNode(node)) {
    MS_LOG(INTERNAL_EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
  }

  auto actor_info = node->user_data<CustomActorInfo>();
  MS_EXCEPTION_IF_NULL(actor_info);
  return actor_info->actor_func;
}

AnfNodePtr AnfUtils::NewInitActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  auto actor_info = std::make_shared<CustomActorInfo>(f, kInit, base_cnode);
  return NewCustomActorNode(actor_info, base_cnode->func_graph());
}

AnfNodePtr AnfUtils::NewInferActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  auto actor_info = std::make_shared<CustomActorInfo>(f, kInfer, base_cnode);
  return NewCustomActorNode(actor_info, base_cnode->func_graph());
}

void AnfUtils::SetCustomInfoToBaseNode(const AnfNodePtr &base_cnode, const AnfNodePtr &inferop,
                                       const AnfNodePtr &initop) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  MS_EXCEPTION_IF_NULL(inferop);
  MS_EXCEPTION_IF_NULL(initop);

  auto actor_info = std::make_shared<CNodeCustomInfo>(inferop, initop);
  base_cnode->set_user_data<CNodeCustomInfo>(actor_info);
}

AnfNodePtr AnfUtils::GetCustomInferopNode(const AnfNodePtr &base_cnode) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  auto actor_info = base_cnode->user_data<CNodeCustomInfo>();
  if (actor_info == nullptr) {
    return nullptr;
  }
  return actor_info->infer_node.lock();
}

mindspore::HashMap<size_t, std::pair<AnfNodeWeakPtr, size_t>> &AnfUtils::GetRealInputNodes(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  auto real_input_info = cnode->user_data<RealInputInfo>();
  if (real_input_info == nullptr) {
    real_input_info = std::make_shared<RealInputInfo>(cnode);
    cnode->set_user_data(real_input_info);
  }
  return real_input_info->real_input_nodes;
}

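// Nodes such as RpcRecv, ConditionSwitch and ConditionGather may carry monad outputs that should not be
// counted as tensor outputs.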
bool AnfUtils::NeedJumpMonadOutput(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    return false;
  }

  std::vector<std::string> jump_monad_output_nodes = {kRpcRecvOpName, prim::kPrimConditionSwitch->name(),
                                                      prim::kPrimConditionGather->name()};
  if (std::find(jump_monad_output_nodes.begin(), jump_monad_output_nodes.end(), GetCNodeName(cnode)) !=
      jump_monad_output_nodes.end()) {
    return true;
  }
  return false;
}

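// FlatParameterFinder maps parameters whose tensors live inside a fused ("flat") memory chunk back to the
// parameter that owns that chunk. AddParameter records chunk/offset info for chunked parameters and keeps
// 1-D parameters as candidate flat parameters, keyed by their data pointer.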
void FlatParameterFinder::AddParameter(const ParameterPtr &param) {
  auto tensor = dyn_cast<tensor::Tensor>(param->default_param());
  if (tensor == nullptr) {
    return;
  }
  auto [chunk, offset] = tensor->GetChunkOffset();
  if (chunk != nullptr) {
    (void)param_to_flat_param_.emplace(param, FlatParamInfo{nullptr, chunk, offset});
    return;
  }
  if (tensor->shape_c().size() == 1) {
    (void)candidate_flat_params_.emplace(tensor->data_c(), param);
  }
}

void FlatParameterFinder::AddNodes(const std::vector<AnfNodePtr> &nodes) {
  for (auto &node : nodes) {
    auto param = dyn_cast<Parameter>(node);
    if (param != nullptr) {
      AddParameter(param);
    }
  }
}

void FlatParameterFinder::UpdateFlatParameters() {
  if (candidate_flat_params_.empty()) {
    return;
  }
  for (auto &entry : param_to_flat_param_) {
    auto &info = entry.second;
    if (info.flat_param == nullptr) {
      auto iter = candidate_flat_params_.find(info.chunk);
      if (iter != candidate_flat_params_.end()) {
        (void)flat_params_.emplace(iter->second);
        info.flat_param = iter->second;
      }
    }
  }
  candidate_flat_params_.clear();
}

std::pair<ParameterPtr, size_t> FlatParameterFinder::FindFlatParameter(const ParameterPtr &param) {
  UpdateFlatParameters();
  auto iter = param_to_flat_param_.find(param);
  if (iter == param_to_flat_param_.end()) {
    return {nullptr, 0};
  }
  auto &flat_param = iter->second.flat_param;
  if (flat_param == nullptr) {
    MS_LOG(WARNING) << "Failed to find the flat parameter for " << param->ToString();
    return {nullptr, 0};
  }
  return {flat_param, iter->second.offset};
}

const std::set<ParameterPtr> &FlatParameterFinder::GetFlatParameters() {
  UpdateFlatParameters();
  return flat_params_;
}
}  // namespace mindspore