• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "utils/anf_utils.h"
18 #include <memory>
19 #include <string>
20 #include <list>
21 #include <algorithm>
22 #include "ops/structure_ops.h"
23 #include "ops/sequence_ops.h"
24 #include "ops/other_ops.h"
25 #include "ops/framework_ops.h"
26 #include "utils/trace_base.h"
27 #include "utils/hash_map.h"
28 #include "utils/os.h"
29 #include "include/common/utils/utils.h"
30 #include "utils/ms_context.h"
31 
32 namespace mindspore {
33 namespace {
34 class AbstractMutexManager {
35  public:
GetInstance()36   static AbstractMutexManager &GetInstance() {
37     static AbstractMutexManager instance;
38     return instance;
39   }
40 
GetAbstractLock(const AnfNode * node)41   std::recursive_mutex *GetAbstractLock(const AnfNode *node) {
42     std::lock_guard<std::recursive_mutex> lock(mu_);
43     if (is_valid_) {
44       return &mu_for_nodes_[node];
45     } else {
46       return nullptr;
47     }
48   }
49 
Close()50   void Close() {
51     // cppcheck-suppress unreadVariable
52     std::lock_guard<std::recursive_mutex> lock(mu_);
53     is_valid_ = false;
54     mu_for_nodes_.clear();
55   }
56 
Open()57   void Open() {
58     // cppcheck-suppress unreadVariable
59     std::lock_guard<std::recursive_mutex> lock(mu_);
60     is_valid_ = true;
61   }
62 
63  private:
64   mindspore::HashMap<const AnfNode *, std::recursive_mutex> mu_for_nodes_;
65   std::recursive_mutex mu_;
66   bool is_valid_ = false;
67 };
68 
69 struct CustomActorInfo {
CustomActorInfomindspore::__anon6ebeaf0a0111::CustomActorInfo70   CustomActorInfo(const AnfUtils::CustomActorCallback &func, const std::string &type_name, const CNodePtr &cnode)
71       : actor_func(func), type_name(type_name), base_cnode_ptr(cnode) {}
72   ~CustomActorInfo() = default;
73 
74   // Key for user data.
75   constexpr static char key[] = "CustomActor";
76   AnfUtils::CustomActorCallback actor_func = {};
77   std::string type_name;
78   CNodeWeakPtr base_cnode_ptr;
79 };
80 using CustomActorInfoPtr = std::shared_ptr<CustomActorInfo>;
81 
82 struct CNodeCustomInfo {
CNodeCustomInfomindspore::__anon6ebeaf0a0111::CNodeCustomInfo83   CNodeCustomInfo(const AnfNodePtr &inferop, const AnfNodePtr &initop) : infer_node(inferop), init_node(initop) {}
84   ~CNodeCustomInfo() = default;
85   // Key for user data.
86   constexpr static char key[] = "CustomNodeInfo";
87   AnfNodeWeakPtr infer_node;
88   AnfNodeWeakPtr init_node;
89 };
90 using CNodeCustomInfoPtr = std::shared_ptr<CNodeCustomInfo>;
91 struct RealInputInfo {
RealInputInfomindspore::__anon6ebeaf0a0111::RealInputInfo92   explicit RealInputInfo(const CNodePtr &cnode) : base_cnode_ptr(cnode), real_input_nodes() {}
93   ~RealInputInfo() = default;
94   // Key for user data.
95   constexpr static char key[] = "RealInputInfo";
96   CNodeWeakPtr base_cnode_ptr;
97   // HashMap <input_index, pair<pre_node, pre_node_output_index>> is used to record the real input node to infer the
98   // dynamic shape information of the nodes located at the boundary of the graph partition, such as heterogeneous
99   // scenario and so on.
100   mindspore::HashMap<size_t, std::pair<AnfNodeWeakPtr, size_t>> real_input_nodes;
101 };
102 
NewCustomActorNode(const CustomActorInfoPtr & actor_info,const FuncGraphPtr & g)103 AnfNodePtr NewCustomActorNode(const CustomActorInfoPtr &actor_info, const FuncGraphPtr &g) {
104   MS_EXCEPTION_IF_NULL(g);
105   auto custom_actor_node = std::make_shared<AnfNode>(g);
106   custom_actor_node->set_user_data<CustomActorInfo>(actor_info);
107   return custom_actor_node;
108 }
109 }  // namespace
110 
AbstractScope(std::recursive_mutex * mu)111 AbstractScope::AbstractScope(std::recursive_mutex *mu) : mu_(mu) {
112   if (mu_ != nullptr) {
113     mu_->lock();
114   }
115 }
116 
AbstractScope(AbstractScope && other)117 AbstractScope::AbstractScope(AbstractScope &&other) {
118   mu_ = other.mu_;
119   other.mu_ = nullptr;
120 }
121 
operator =(AbstractScope && other)122 AbstractScope &AbstractScope::operator=(AbstractScope &&other) {
123   mu_ = other.mu_;
124   other.mu_ = nullptr;
125   return *this;
126 }
127 
~AbstractScope()128 AbstractScope::~AbstractScope() {
129   if (mu_ != nullptr) {
130     mu_->unlock();
131   }
132 }
133 
GetAbstractLock(const AnfNode * node)134 AbstractScope AnfUtils::GetAbstractLock(const AnfNode *node) {
135   return AbstractScope(AbstractMutexManager::GetInstance().GetAbstractLock(node));
136 }
137 
OpenAbstractLock()138 void AnfUtils::OpenAbstractLock() { AbstractMutexManager::GetInstance().Open(); }
139 
CloseAbstractLock()140 void AnfUtils::CloseAbstractLock() { AbstractMutexManager::GetInstance().Close(); }
141 
142 // If the node's shape is dynamic shape or dynamic rank, return true.
IsNodeOutputShapeDynamic(const AnfNodePtr & node)143 bool AnfUtils::IsNodeOutputShapeDynamic(const AnfNodePtr &node) {
144   MS_EXCEPTION_IF_NULL(node);
145   auto base_shape = node->Shape();
146   if (base_shape == nullptr) {
147     MS_LOG(INFO) << "Invalid base shape, node: " << node->fullname_with_scope();
148     return false;
149   }
150   return base_shape->IsDynamic();
151 }
152 
IsRealKernel(const AnfNodePtr & node)153 bool AnfUtils::IsRealKernel(const AnfNodePtr &node) {
154   MS_EXCEPTION_IF_NULL(node);
155 #ifndef ENABLE_SECURITY
156   static const PrimitiveSet virtual_prims = {
157     prim::kPrimMakeTuple,   prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
158     prim::kPrimReturn,      prim::kPrimPartial,      prim::kPrimDepend,
159     prim::kPrimUpdateState, prim::kPrimLoad,         prim::kPrimDynamicLossScale,
160     prim::kPrimMakeList,    prim::kPrimListGetItem,  prim::kPrimIs_,
161     prim::kPrimIsNot,       prim::kPrimIsInstance};
162 #else
163   static const PrimitiveSet virtual_prims = {
164     prim::kPrimMakeTuple,   prim::kPrimStateSetItem, prim::kPrimTupleGetItem,
165     prim::kPrimReturn,      prim::kPrimPartial,      prim::kPrimDepend,
166     prim::kPrimUpdateState, prim::kPrimLoad,         prim::kPrimDynamicLossScale};
167 #endif
168   auto cnode = node->cast<CNodePtr>();
169   if (cnode == nullptr) {
170     // parameter and value node is a real kernel too
171     return true;
172   }
173   if (cnode->size() == 0) {
174     MS_LOG(INTERNAL_EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString()
175                                << trace::DumpSourceLines(node);
176   }
177 
178   auto kernel_info = cnode->kernel_info();
179   if (kernel_info) {
180     auto runtime_cache = kernel_info->runtime_cache();
181     if (runtime_cache.runtime_cache().is_real_kernel() != Uncached) {
182       return (runtime_cache.runtime_cache().is_real_kernel() == True);
183     }
184   }
185 
186   // In the GE backend, summary is the actual operator,
187   // and the corresponding back-end operator is OutfeedEnqueueOpV2
188   static const PrimitiveSet summary_prims = {
189     prim::kPrimImageSummary,
190     prim::kPrimScalarSummary,
191     prim::kPrimTensorSummary,
192     prim::kPrimHistogramSummary,
193   };
194 
195   bool res = !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), virtual_prims);
196   static std::string backend = MsContext::GetInstance()->backend_policy();
197   if (backend != "ge") {
198     res = res && !IsOneOfPrimitive(cnode->input(kAnfPrimitiveIndex), summary_prims);
199   }
200 
201   if (kernel_info) {
202     auto runtime_cache = kernel_info->runtime_cache();
203     if (res) {
204       runtime_cache.runtime_cache().set_real_kernel(True);
205     } else {
206       runtime_cache.runtime_cache().set_real_kernel(False);
207     }
208   }
209 
210   return res;
211 }
212 
IsRealCNodeKernel(const AnfNodePtr & node)213 bool AnfUtils::IsRealCNodeKernel(const AnfNodePtr &node) {
214   MS_EXCEPTION_IF_NULL(node);
215   if (!node->isa<CNode>()) {
216     return false;
217   }
218   if (IsPrimitiveCNode(node, prim::kPrimReturn)) {
219     return true;
220   }
221   return AnfUtils::IsRealKernel(node);
222 }
223 
GetCNodeName(const AnfNodePtr & node)224 std::string AnfUtils::GetCNodeName(const AnfNodePtr &node) {
225   MS_EXCEPTION_IF_NULL(node);
226   if (node->isa<CNode>()) {
227     auto primitive = GetCNodePrimitive(node);
228     if (primitive != nullptr) {
229       if (primitive->name() == "Custom") {
230         auto uniq_name = primitive->GetAttr("uniq_name");
231         if (uniq_name) {
232           return GetValue<std::string>(uniq_name);
233         }
234       }
235       return primitive->name();
236     }
237 
238     // Check whether call node's input is not a value node which contains FuncGraph.
239     auto cnode = dyn_cast<CNode>(node);
240     MS_EXCEPTION_IF_NULL(cnode);
241     if (cnode->size() == 0 || !IsValueNode<FuncGraph>(cnode->input(0))) {
242       return "";
243     }
244 
245     auto func_graph = GetCNodeFuncGraph(node);
246     MS_EXCEPTION_IF_NULL(func_graph);
247     if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
248       std::string fg_name = "GraphKernel_";
249       fg_name += GetValue<std::string>(func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
250       return fg_name;
251     }
252     return func_graph->ToString();
253   }
254   MS_LOG(INTERNAL_EXCEPTION) << "Unknown anf node type " << node->DebugString() << trace::DumpSourceLines(node);
255 }
256 
GetInputTensorNum(const AnfNodePtr & node)257 size_t AnfUtils::GetInputTensorNum(const AnfNodePtr &node) {
258   MS_EXCEPTION_IF_NULL(node);
259   auto cnode = node->cast<CNodePtr>();
260   if (cnode == nullptr) {
261     MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString()
262                                << trace::DumpSourceLines(node);
263   }
264   {
265     // cppcheck-suppress unreadVariable
266     auto lock = AnfUtils::GetAbstractLock(cnode.get());
267     ssize_t input_tensor_num = cnode->input_tensor_num();
268     if (input_tensor_num >= 0) {
269       return static_cast<size_t>(input_tensor_num);
270     }
271   }
272 
273   size_t input_num = cnode->size();
274   if (input_num == 0) {
275     MS_LOG(INTERNAL_EXCEPTION) << "Cnode inputs size can't be zero" << trace::DumpSourceLines(node);
276   }
277   // Exclude inputs[0].
278   --input_num;
279 
280   // Exclude monad inputs for real cnodes.
281   if (input_num > 0 && AnfUtils::IsRealKernel(cnode)) {
282     auto &inputs = cnode->inputs();
283     // Search monad inputs, backward.
284     for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
285       // cppcheck-suppress unreadVariable
286       auto lock = AnfUtils::GetAbstractLock((*iter).get());
287       if (!HasAbstractMonad(*iter)) {
288         // Stop count if we encounter a non-monad input.
289         break;
290       }
291       --input_num;
292     }
293   }
294   // cppcheck-suppress unreadVariable
295   auto lock = AnfUtils::GetAbstractLock(cnode.get());
296   cnode->set_input_tensor_num(static_cast<ssize_t>(input_num));
297   return input_num;
298 }
299 
GetOutputTensorNum(const AnfNodePtr & node)300 size_t AnfUtils::GetOutputTensorNum(const AnfNodePtr &node) {
301   MS_EXCEPTION_IF_NULL(node);
302   auto kernel_info = node->kernel_info();
303   bool is_valid_cache = false;
304   if (kernel_info != nullptr) {
305     auto runtime_cache = kernel_info->runtime_cache();
306     if (runtime_cache.runtime_cache().is_valid()) {
307       ssize_t output_tensor_num = runtime_cache.runtime_cache().output_tensor_num();
308       if (output_tensor_num >= 0) {
309         return static_cast<size_t>(output_tensor_num);
310       }
311       is_valid_cache = true;
312     }
313   }
314 
315   size_t res = 1;
316   TypePtr type = node->Type();
317   if (type == nullptr) {
318     res = 0;
319   } else if (type->isa<Tuple>()) {
320     auto tuple_type = type->cast<TuplePtr>();
321     MS_EXCEPTION_IF_NULL(tuple_type);
322     res = tuple_type->size();
323     if (res == 0) {
324       return res;
325     }
326     auto last_type = tuple_type->elements()[res - 1];
327     MS_EXCEPTION_IF_NULL(last_type);
328     // Some nodes could have monad outputs like RpcRecv. We need to jump these outputs.
329     if (NeedJumpMonadOutput(node) && last_type->isa<MonadType>()) {
330       for (size_t i = 0; i < tuple_type->elements().size(); i++) {
331         auto tuple_type_elem = tuple_type->elements()[i];
332         MS_EXCEPTION_IF_NULL(tuple_type_elem);
333         if (tuple_type_elem->isa<MonadType>()) {
334           res = i;
335           break;
336         }
337       }
338     }
339   } else if (type->isa<List>()) {
340     auto list_type = type->cast<ListPtr>();
341     MS_EXCEPTION_IF_NULL(list_type);
342     res = list_type->size();
343   } else if (type->isa<TypeNone>()) {
344     res = 0;
345   } else if (type->isa<CSRTensorType>()) {
346     // Currently, CSRTensor only supports 2-D matrix (shape has 2 values). 5 outputs = 3 Tensors + 2 shape values.
347     constexpr size_t kCSRTensorOutputNum = 5;
348     res = kCSRTensorOutputNum;
349   } else if (type->isa<COOTensorType>()) {
350     // Currently, COOTensor only supports 2-D matrix (shape has 2 values). 4 outputs = 2 Tensors + 2 shape values.
351     constexpr size_t kCOOTensorOutputNum = 4;
352     res = kCOOTensorOutputNum;
353   } else if (NeedJumpMonadOutput(node) && type->isa<MonadType>()) {
354     // Some nodes could have monad outputs like RpcRecv. We need to jump these outputs.
355     res = 0;
356   }
357 
358   if (is_valid_cache) {
359     kernel_info->runtime_cache().runtime_cache().set_output_tensor_num(static_cast<ssize_t>(res));
360   }
361   return res;
362 }
363 
SetNodeAttr(const std::string & key,const ValuePtr & value,const AnfNodePtr & node)364 void AnfUtils::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
365   MS_EXCEPTION_IF_NULL(node);
366   if (!node->isa<CNode>()) {
367     MS_LOG(INTERNAL_EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString()
368                                << trace::DumpSourceLines(node);
369   }
370   // single op cnode.
371   auto primitive = GetCNodePrimitive(node);
372   if (primitive != nullptr) {
373     primitive->set_attr(key, value);
374     return;
375   }
376   // graph kernel cnode.
377   auto fg = GetCNodeFuncGraph(node);
378   MS_EXCEPTION_IF_NULL(fg);
379   fg->set_attr(key, value);
380 }
381 
GetIntValue(const AnfNodePtr & anf_node)382 int64_t AnfUtils::GetIntValue(const AnfNodePtr &anf_node) {
383   MS_EXCEPTION_IF_NULL(anf_node);
384   auto value_node = anf_node->cast<ValueNodePtr>();
385   MS_EXCEPTION_IF_NULL(value_node);
386   auto value = value_node->value();
387   return GetIntValue(value);
388 }
389 
GetIntValue(const ValuePtr & value)390 int64_t AnfUtils::GetIntValue(const ValuePtr &value) {
391   MS_EXCEPTION_IF_NULL(value);
392   if (value->isa<Int64Imm>()) {
393     return GetValue<int64_t>(value);
394   } else if (value->isa<Int32Imm>()) {
395     return IntToLong(GetValue<int>(value));
396   } else {
397     MS_LOG(EXCEPTION) << "The value should be Int32Imm or Int64Imm, but got " << value->ToString();
398   }
399 }
400 
VisitKernel(const AnfNodePtr & anf_node,size_t index)401 std::pair<AnfNodePtr, size_t> AnfUtils::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
402   MS_EXCEPTION_IF_NULL(anf_node);
403   const PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad};
404   if (anf_node->isa<ValueNode>()) {
405     return std::make_pair(anf_node, 0);
406   } else if (anf_node->isa<Parameter>()) {
407     return std::make_pair(anf_node, 0);
408   } else if (IsCustomActorNode(anf_node)) {
409     return std::make_pair(anf_node, 0);
410   } else if (anf_node->isa<CNode>()) {
411     auto cnode = anf_node->cast<CNodePtr>();
412     MS_EXCEPTION_IF_NULL(cnode);
413     auto input0 = cnode->input(0);
414     MS_EXCEPTION_IF_NULL(input0);
415     if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
416       if (GetInputTensorNum(cnode) == 0) {
417         return std::make_pair(nullptr, 0);
418       }
419       auto node = cnode->input(index + IntToSize(1));
420       MS_EXCEPTION_IF_NULL(node);
421       return VisitKernel(node, 0);
422     } else if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
423       if (cnode->size() != kTupleGetItemInputSize) {
424         MS_LOG(INTERNAL_EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
425       }
426       auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
427       auto item_idx = AnfUtils::GetIntValue(input2);
428       return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), LongToSize(item_idx));
429     } else if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
430       return VisitKernel(cnode->input(kUpdateStateRealInput), 0);
431     } else if (IsOneOfPrimitive(input0, follow_first_input_prims)) {
432       return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
433     } else {
434       return std::make_pair(anf_node, index);
435     }
436   } else {
437     MS_LOG(INTERNAL_EXCEPTION) << "The input is invalid";
438   }
439 }
440 
IsGraphKernel(const AnfNodePtr & node)441 bool AnfUtils::IsGraphKernel(const AnfNodePtr &node) {
442   MS_EXCEPTION_IF_NULL(node);
443   auto func_graph = GetCNodeFuncGraph(node);
444   return func_graph != nullptr && func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
445 }
446 
IsNodeInGraphKernel(const AnfNodePtr & node)447 bool AnfUtils::IsNodeInGraphKernel(const AnfNodePtr &node) {
448   MS_EXCEPTION_IF_NULL(node);
449   return node->func_graph() != nullptr && node->func_graph()->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
450 }
451 
SetDumpFlag(const AnfNodePtr & node)452 void AnfUtils::SetDumpFlag(const AnfNodePtr &node) {
453   if (node == nullptr || !node->isa<CNode>()) {
454     return;
455   }
456   auto prim = GetCNodePrimitive(node);
457   if (prim != nullptr) {
458     prim->set_attr(kAttrDump, MakeValue(kValueTrue));
459   }
460 }
461 
GetDumpFlag(const AnfNodePtr & node)462 bool AnfUtils::GetDumpFlag(const AnfNodePtr &node) {
463   if (node == nullptr || !node->isa<CNode>()) {
464     return false;
465   }
466   auto prim = GetCNodePrimitive(node);
467   if (prim != nullptr) {
468     auto attr = prim->GetAttr(kAttrDump);
469     if (attr != nullptr && attr->isa<StringImm>() && attr->cast<StringImmPtr>()->value() == kValueTrue) {
470       return true;
471     }
472   }
473   return false;
474 }
475 
HasDumpFlag(const AnfNodePtr & node)476 bool AnfUtils::HasDumpFlag(const AnfNodePtr &node) {
477   if (node == nullptr || !node->isa<CNode>()) {
478     return false;
479   }
480   auto prim = GetCNodePrimitive(node);
481   if (prim != nullptr) {
482     return prim->HasAttr(kAttrDump);
483   }
484   return false;
485 }
486 
IsCustomActorNode(const AnfNodePtr & node)487 bool AnfUtils::IsCustomActorNode(const AnfNodePtr &node) {
488   MS_EXCEPTION_IF_NULL(node);
489   return node->has_user_data<CustomActorInfo>();
490 }
491 
IsCutomActorNodeSame(const AnfNodePtr & node1,const AnfNodePtr & node2)492 bool AnfUtils::IsCutomActorNodeSame(const AnfNodePtr &node1, const AnfNodePtr &node2) {
493   MS_EXCEPTION_IF_NULL(node1);
494   MS_EXCEPTION_IF_NULL(node2);
495   if (!IsCustomActorNode(node1) || !IsCustomActorNode(node2)) {
496     MS_LOG(INTERNAL_EXCEPTION) << "Two node are not all Custom Actor Node!";
497   }
498 
499   auto actor_info1 = node1->user_data<CustomActorInfo>();
500   MS_EXCEPTION_IF_NULL(actor_info1);
501   std::string actor_type1 = actor_info1->type_name;
502 
503   auto actor_info2 = node2->user_data<CustomActorInfo>();
504   MS_EXCEPTION_IF_NULL(actor_info2);
505   std::string actor_type2 = actor_info2->type_name;
506 
507   return (actor_type1 == actor_type2);
508 }
509 
GetCustomActorType(const AnfNodePtr & node)510 std::string AnfUtils::GetCustomActorType(const AnfNodePtr &node) {
511   MS_EXCEPTION_IF_NULL(node);
512   if (!IsCustomActorNode(node)) {
513     MS_LOG(INTERNAL_EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
514   }
515 
516   auto actor_info = node->user_data<CustomActorInfo>();
517   MS_EXCEPTION_IF_NULL(actor_info);
518   return actor_info->type_name;
519 }
520 
GetCustomActorName(const AnfNodePtr & node)521 std::string AnfUtils::GetCustomActorName(const AnfNodePtr &node) {
522   MS_EXCEPTION_IF_NULL(node);
523   if (!IsCustomActorNode(node)) {
524     MS_LOG(INTERNAL_EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
525   }
526 
527   auto actor_info = node->user_data<CustomActorInfo>();
528   MS_EXCEPTION_IF_NULL(actor_info);
529   auto base_node = actor_info->base_cnode_ptr.lock();
530   MS_EXCEPTION_IF_NULL(base_node);
531   std::string actor_name = actor_info->type_name + "_of_" + base_node->fullname_with_scope();
532   return actor_name;
533 }
534 
GetCustomActorBaseNode(const AnfNodePtr & node)535 CNodePtr AnfUtils::GetCustomActorBaseNode(const AnfNodePtr &node) {
536   MS_EXCEPTION_IF_NULL(node);
537   if (!IsCustomActorNode(node)) {
538     MS_LOG(INTERNAL_EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
539   }
540 
541   auto actor_info = node->user_data<CustomActorInfo>();
542   MS_EXCEPTION_IF_NULL(actor_info);
543   return actor_info->base_cnode_ptr.lock();
544 }
545 
GetCustomFunc(const AnfNodePtr & node)546 AnfUtils::CustomActorCallback AnfUtils::GetCustomFunc(const AnfNodePtr &node) {
547   MS_EXCEPTION_IF_NULL(node);
548   if (!IsCustomActorNode(node)) {
549     MS_LOG(INTERNAL_EXCEPTION) << node->fullname_with_scope() << " is not a custom actor node!";
550   }
551 
552   auto actor_info = node->user_data<CustomActorInfo>();
553   MS_EXCEPTION_IF_NULL(actor_info);
554   return actor_info->actor_func;
555 }
556 
NewInitActorNode(AnfUtils::CustomActorCallback f,const CNodePtr & base_cnode)557 AnfNodePtr AnfUtils::NewInitActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode) {
558   MS_EXCEPTION_IF_NULL(base_cnode);
559   auto actor_info = std::make_shared<CustomActorInfo>(f, kInit, base_cnode);
560   return NewCustomActorNode(actor_info, base_cnode->func_graph());
561 }
562 
NewInferActorNode(AnfUtils::CustomActorCallback f,const CNodePtr & base_cnode)563 AnfNodePtr AnfUtils::NewInferActorNode(AnfUtils::CustomActorCallback f, const CNodePtr &base_cnode) {
564   MS_EXCEPTION_IF_NULL(base_cnode);
565   auto actor_info = std::make_shared<CustomActorInfo>(f, kInfer, base_cnode);
566   return NewCustomActorNode(actor_info, base_cnode->func_graph());
567 }
568 
SetCustomInfoToBaseNode(const AnfNodePtr & base_cnode,const AnfNodePtr & inferop,const AnfNodePtr & initop)569 void AnfUtils::SetCustomInfoToBaseNode(const AnfNodePtr &base_cnode, const AnfNodePtr &inferop,
570                                        const AnfNodePtr &initop) {
571   MS_EXCEPTION_IF_NULL(base_cnode);
572   MS_EXCEPTION_IF_NULL(inferop);
573   MS_EXCEPTION_IF_NULL(initop);
574 
575   auto actor_info = std::make_shared<CNodeCustomInfo>(inferop, initop);
576   base_cnode->set_user_data<CNodeCustomInfo>(actor_info);
577 }
578 
// Returns the infer actor node recorded on `base_cnode`. Returns nullptr when none was recorded or
// it has expired.
AnfNodePtr AnfUtils::GetCustomInferopNode(const AnfNodePtr &base_cnode) {
  MS_EXCEPTION_IF_NULL(base_cnode);
  if (auto actor_info = base_cnode->user_data<CNodeCustomInfo>(); actor_info != nullptr) {
    return actor_info->infer_node.lock();
  }
  return nullptr;
}
587 
GetRealInputNodes(const CNodePtr & cnode)588 mindspore::HashMap<size_t, std::pair<AnfNodeWeakPtr, size_t>> &AnfUtils::GetRealInputNodes(const CNodePtr &cnode) {
589   MS_EXCEPTION_IF_NULL(cnode);
590   auto real_input_info = cnode->user_data<RealInputInfo>();
591   if (real_input_info == nullptr) {
592     real_input_info = std::make_shared<RealInputInfo>(cnode);
593     cnode->set_user_data(real_input_info);
594   }
595   return real_input_info->real_input_nodes;
596 }
597 
NeedJumpMonadOutput(const AnfNodePtr & node)598 bool AnfUtils::NeedJumpMonadOutput(const AnfNodePtr &node) {
599   MS_EXCEPTION_IF_NULL(node);
600   auto cnode = node->cast<CNodePtr>();
601   if (cnode == nullptr) {
602     return false;
603   }
604 
605   std::vector<std::string> jump_monad_output_nodes = {kRpcRecvOpName, prim::kPrimConditionSwitch->name(),
606                                                       prim::kPrimConditionGather->name()};
607   if (std::find(jump_monad_output_nodes.begin(), jump_monad_output_nodes.end(), GetCNodeName(cnode)) !=
608       jump_monad_output_nodes.end()) {
609     return true;
610   }
611   return false;
612 }
613 
AddParameter(const ParameterPtr & param)614 void FlatParameterFinder::AddParameter(const ParameterPtr &param) {
615   auto tensor = dyn_cast<tensor::Tensor>(param->default_param());
616   if (tensor == nullptr) {
617     return;
618   }
619   auto [chunk, offset] = tensor->GetChunkOffset();
620   if (chunk != nullptr) {
621     (void)param_to_flat_param_.emplace(param, FlatParamInfo{nullptr, chunk, offset});
622     return;
623   }
624   if (tensor->shape_c().size() == 1) {
625     (void)candidate_flat_params_.emplace(tensor->data_c(), param);
626   }
627 }
628 
AddNodes(const std::vector<AnfNodePtr> & nodes)629 void FlatParameterFinder::AddNodes(const std::vector<AnfNodePtr> &nodes) {
630   for (auto &node : nodes) {
631     auto param = dyn_cast<Parameter>(node);
632     if (param != nullptr) {
633       AddParameter(param);
634     }
635   }
636 }
637 
UpdateFlatParameters()638 void FlatParameterFinder::UpdateFlatParameters() {
639   if (candidate_flat_params_.empty()) {
640     return;
641   }
642   for (auto &entry : param_to_flat_param_) {
643     auto &info = entry.second;
644     if (info.flat_param == nullptr) {
645       auto iter = candidate_flat_params_.find(info.chunk);
646       if (iter != candidate_flat_params_.end()) {
647         (void)flat_params_.emplace(iter->second);
648         info.flat_param = iter->second;
649       }
650     }
651   }
652   candidate_flat_params_.clear();
653 }
654 
FindFlatParameter(const ParameterPtr & param)655 std::pair<ParameterPtr, size_t> FlatParameterFinder::FindFlatParameter(const ParameterPtr &param) {
656   UpdateFlatParameters();
657   auto iter = param_to_flat_param_.find(param);
658   if (iter == param_to_flat_param_.end()) {
659     return {nullptr, 0};
660   }
661   auto &flat_param = iter->second.flat_param;
662   if (flat_param == nullptr) {
663     MS_LOG(WARNING) << "Find flat Parameter for " << param->ToString() << " failed";
664     return {nullptr, 0};
665   }
666   return {flat_param, iter->second.offset};
667 }
668 
GetFlatParameters()669 const std::set<ParameterPtr> &FlatParameterFinder::GetFlatParameters() {
670   UpdateFlatParameters();
671   return flat_params_;
672 }
673 }  // namespace mindspore
674