• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "ir/anf.h"
20 
21 #include <algorithm>
22 #include <sstream>
23 #include <vector>
24 #include <queue>
25 #include <unordered_map>
26 
27 #include "base/core_ops.h"
28 #include "ir/func_graph.h"
29 #include "ir/primitive.h"
30 #include "utils/ms_context.h"
31 
32 namespace mindspore {
33 // namespace to support intermediate representation definition
CNode(const std::vector<AnfNodePtr> & inputs,const FuncGraphPtr & func_graph)34 CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
35     : AnfNode(func_graph),
36       inputs_(inputs),
37       stop_gradient_(false),
38       output_value_(std::make_pair(nullptr, "")),
39       input_tensor_num_(-1) {
40   primal_attrs_ = PrimalAttrManager::GetInstance().GetCurrentPrimalAttr();
41   primal_debug_infos_ = PrimalDebugInfoManager::GetInstance().GetCurrentPrimalDebugInfo();
42 }
43 
44 // Check if CNode is an apply with the specific Primitive.
IsApply(const PrimitivePtr & value) const45 bool CNode::IsApply(const PrimitivePtr &value) const {
46   if (value == nullptr) {
47     return false;
48   }
49 
50   if (inputs_.size() != 0 && IsValueNode<Primitive>(inputs_[0])) {
51     PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(inputs_[0]);
52     if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) {
53       return true;
54     }
55   }
56 
57   return false;
58 }
59 
add_input(const AnfNodePtr & input)60 void CNode::add_input(const AnfNodePtr &input) {
61   inputs_.push_back(input);
62   input_tensor_num_ = -1;
63 }
64 
set_input(size_t i,const AnfNodePtr & new_input)65 void CNode::set_input(size_t i, const AnfNodePtr &new_input) {
66   if (i >= inputs_.size()) {
67     MS_LOG(EXCEPTION) << "i:" << i << " out of range:" << inputs_.size() << ", cnode:" << DebugString();
68   }
69   inputs_[i] = new_input;
70   input_tensor_num_ = -1;
71 }
72 
set_inputs(const std::vector<AnfNodePtr> & inputs)73 void CNode::set_inputs(const std::vector<AnfNodePtr> &inputs) {
74   inputs_ = inputs;
75   input_tensor_num_ = -1;
76 }
77 
input(size_t i) const78 const AnfNodePtr &CNode::input(size_t i) const {
79   if (i >= inputs_.size()) {
80     MS_LOG(EXCEPTION) << "i:" << i << "out of range:" << inputs_.size() << ", cnode:" << DebugString();
81   }
82   return inputs_.at(i);
83 }
84 
DebugString(int recursive_level) const85 std::string CNode::DebugString(int recursive_level) const {
86   std::ostringstream buffer;
87   if (recursive_level > 0) {
88     if (func_graph() != nullptr) {
89       buffer << func_graph()->ToString() << ":";
90     }
91     buffer << ToString() << "{";
92     bool is_first_node = true;
93     int idx = 0;
94     for (auto &node : inputs_) {
95       MS_EXCEPTION_IF_NULL(node);
96       if (is_first_node) {
97         is_first_node = false;
98       } else {
99         buffer << ", ";
100       }
101       buffer << "[" << idx << "]: " << node->DebugString(recursive_level - 1);
102       idx++;
103     }
104     buffer << "}";
105   } else {
106     buffer << ToString();
107   }
108   return buffer.str();
109 }
110 
DebugString(int recursive_level) const111 std::string Parameter::DebugString(int recursive_level) const {
112   std::ostringstream buffer;
113   if (recursive_level > 0) {
114     if (func_graph() != nullptr) {
115       buffer << func_graph()->ToString() << ":";
116     }
117   }
118   buffer << ToString();
119   return buffer.str();
120 }
121 
param_info() const122 ParamInfoPtr Parameter::param_info() const {
123   if (!has_default()) {
124     return nullptr;
125   }
126   auto tensor = default_param()->cast<tensor::MetaTensorPtr>();
127   if (tensor == nullptr || !tensor->is_parameter()) {
128     return nullptr;
129   }
130   return tensor->param_info();
131 }
132 
ToString() const133 std::string ValueNode::ToString() const {
134   MS_EXCEPTION_IF_NULL(value_);
135   if (value_->isa<FuncGraph>()) {
136     return value_->ToString();
137   }
138   std::ostringstream buffer;
139   buffer << AnfNode::ToString();
140   buffer << "(" << value_->ToString() << ")";
141   return buffer.str();
142 }
143 
DebugString(int) const144 std::string ValueNode::DebugString(int) const {
145   MS_EXCEPTION_IF_NULL(value_);
146   std::ostringstream buffer;
147   buffer << "ValueNode<" << value_->type_name() << "> " << value_->ToString();
148   return buffer.str();
149 }
150 
fullname_with_scope()151 std::string ValueNode::fullname_with_scope() {
152   if (!fullname_with_scope_.empty()) {
153     return fullname_with_scope_;
154   }
155 
156   MS_EXCEPTION_IF_NULL(scope());
157   fullname_with_scope_ = scope()->name() + "/" + "data-" + id_generator::get_id(shared_from_base<ValueNode>());
158   return fullname_with_scope_;
159 }
160 
IsPrimitiveCNode(const AnfNodePtr & node,const PrimitivePtr & value)161 bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
162   auto cnode = dyn_cast<CNode>(node);
163   if (cnode == nullptr) {
164     return false;
165   }
166   if (value != nullptr) {
167     return cnode->IsApply(value);
168   }
169   const auto &prim = GetValueNode<PrimitivePtr>(cnode->input(0));
170   return prim != nullptr;
171 }
172 
GetCNodePrimitive(const AnfNodePtr & node)173 PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
174   auto cnode = dyn_cast<CNode>(node);
175   if (cnode != nullptr) {
176     if (cnode->size() > 0) {
177       auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
178       return prim;
179     }
180   }
181   return nullptr;
182 }
183 
GetCNodeFuncName(const CNodePtr cnode)184 std::string GetCNodeFuncName(const CNodePtr cnode) {
185   if (cnode->inputs().empty()) {
186     return "";
187   }
188 
189   AnfNodePtr valuenode = cnode->input(0);
190   auto value = GetValueNode(valuenode);
191   if (value != nullptr) {
192     // check whether the valuenode is primitive
193     if (value->isa<Primitive>()) {
194       return value->cast<PrimitivePtr>()->name();
195     }
196     return value->ToString();
197   }
198   return "";
199 }
200 
GetCNodeFuncGraph(const AnfNodePtr & node)201 FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node) {
202   auto cnode = dyn_cast<CNode>(node);
203   if (cnode != nullptr && cnode->size() > 0) {
204     return GetValueNode<FuncGraphPtr>(cnode->input(0));
205   }
206   return nullptr;
207 }
208 
IsPrimitive(const AnfNodePtr & node,const PrimitivePtr & value)209 bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) {
210   if (IsValueNode<Primitive>(node)) {
211     PrimitivePtr fn_value = GetValueNode<PrimitivePtr>(node);
212     MS_EXCEPTION_IF_NULL(value);
213     if (fn_value->Hash() == value->Hash() && fn_value->name() == value->name()) {
214       return true;
215     }
216   }
217   return false;
218 }
219 
IsPrimitiveEquals(const PrimitivePtr & prim1,const PrimitivePtr & prim2)220 bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2) {
221   if (prim1 == nullptr || prim2 == nullptr) {
222     return false;
223   }
224   return (prim1 == prim2) || (prim1->Hash() == prim2->Hash() && prim1->name() == prim2->name());
225 }
226 
GetAbstractMonadNum(const AbstractBasePtrList & args)227 size_t GetAbstractMonadNum(const AbstractBasePtrList &args) {
228   size_t num = 0;
229   for (auto &arg : args) {
230     if (arg->isa<abstract::AbstractMonad>()) {
231       ++num;
232     }
233   }
234   return num;
235 }
236 
237 template <typename T>
HasAbstract(const AnfNodePtr & node)238 bool HasAbstract(const AnfNodePtr &node) {
239   if (node == nullptr) {
240     return false;
241   }
242   const auto &abs = node->abstract();
243   return (abs != nullptr && abs->isa<T>());
244 }
245 
HasAbstractMonad(const AnfNodePtr & node)246 bool HasAbstractMonad(const AnfNodePtr &node) { return HasAbstract<abstract::AbstractMonad>(node); }
247 
HasAbstractUMonad(const AnfNodePtr & node)248 bool HasAbstractUMonad(const AnfNodePtr &node) { return HasAbstract<abstract::AbstractUMonad>(node); }
249 
HasAbstractIOMonad(const AnfNodePtr & node)250 bool HasAbstractIOMonad(const AnfNodePtr &node) { return HasAbstract<abstract::AbstractIOMonad>(node); }
251 
GetPrimitiveFlag(const PrimitivePtr & prim,const std::string & attr)252 bool GetPrimitiveFlag(const PrimitivePtr &prim, const std::string &attr) {
253   if (prim != nullptr) {
254     auto flag = prim->GetAttr(attr);
255     if (flag && flag->isa<BoolImm>()) {
256       return GetValue<bool>(flag);
257     }
258   }
259   return false;
260 }
261 
GetPrimEffectInfo(const PrimitivePtr & prim)262 EffectInfo GetPrimEffectInfo(const PrimitivePtr &prim) {
263   bool mem = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_MEM);
264   bool io = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_IO);
265   return {EffectInfo::kDetected, mem, io, false};
266 }
267 
GetMonadState(const AnfNodePtr & node,const AnfNodePtr & skip_input)268 MonadState GetMonadState(const AnfNodePtr &node, const AnfNodePtr &skip_input) {
269   if (node == nullptr) {
270     return {};
271   }
272   MonadState state;
273   size_t seen = NewSeenGeneration();
274   std::queue<AnfNodePtr> que;
275   que.push(node);
276   while (!que.empty()) {
277     auto n = que.front();
278     que.pop();
279 
280     // check whether this node has been matched or should be skipped.
281     if (n == nullptr || n->seen_ == seen || n == skip_input) {
282       continue;
283     }
284     n->seen_ = seen;
285 
286     // check whether this node has monad abstract.
287     if (state.u == nullptr && HasAbstractUMonad(n)) {
288       state.u = n;
289     } else if (state.io == nullptr && HasAbstractIOMonad(n)) {
290       state.io = n;
291     } else {
292       auto cnode = dyn_cast<CNode>(n);
293       if (cnode != nullptr) {
294         for (auto it = cnode->inputs().rbegin(); it != cnode->inputs().rend(); ++it) {
295           que.push(*it);
296         }
297       }
298       continue;
299     }
300 
301     if (state.u != nullptr && state.io != nullptr) {
302       return state;
303     }
304   }
305   return state;
306 }
307 
IsStateEquivalent(const MonadState & state1,const MonadState & state2)308 bool IsStateEquivalent(const MonadState &state1, const MonadState &state2) {
309   return (state1.u == nullptr || state2.u == nullptr || state1.u == state2.u) &&
310          (state1.io == nullptr || state2.io == nullptr || state1.io == state2.io);
311 }
312 
IsStateStrictEquivalent(const AnfNodePtr & outer,const AnfNodePtr & inner)313 bool IsStateStrictEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) {
314   MonadState state_matmul = GetMonadState(inner);
315   MonadState state_node = GetMonadState(outer, inner);
316   return IsStateEquivalent(state_matmul, state_node);
317 }
318 
GetLoadInputs(const AnfNodePtr & node)319 std::set<CNodePtr> GetLoadInputs(const AnfNodePtr &node) {
320   std::set<CNodePtr> loads;
321   auto cnode = dyn_cast<CNode>(node);
322   if (cnode == nullptr) {
323     return loads;
324   }
325   auto &inputs = cnode->inputs();
326   for (size_t i = 1; i < inputs.size(); ++i) {
327     auto &input = inputs.at(i);
328     if (IsPrimitiveCNode(input, prim::kPrimLoad)) {
329       loads.insert(input->cast<CNodePtr>());
330     } else if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
331       loads.merge(GetLoadInputs(input));
332     }
333   }
334   return loads;
335 }
336 
IsStateEquivalent(const AnfNodePtr & outer,const AnfNodePtr & inner)337 bool IsStateEquivalent(const AnfNodePtr &outer, const AnfNodePtr &inner) {
338   constexpr size_t kMonadInput = 2;
339   auto outer_loads = GetLoadInputs(outer);
340   if (outer_loads.empty()) {
341     return true;
342   }
343   auto inner_loads = GetLoadInputs(inner);
344   if (inner_loads.empty()) {
345     return true;
346   }
347   outer_loads.merge(inner_loads);
348   auto &monad = (*outer_loads.begin())->inputs().at(kMonadInput);
349   return std::all_of(++outer_loads.begin(), outer_loads.end(),
350                      [&monad, kMonadInput](const CNodePtr &load) { return load->inputs().at(kMonadInput) == monad; });
351 }
352 
// Returns a fresh visit-generation number by incrementing a function-local counter.
// The first call returns 1. No synchronization is applied.
size_t NewSeenGeneration() {
  static size_t seen_generation = 0;
  seen_generation += 1;
  return seen_generation;
}
357 
358 namespace id_generator {
359 static std::unordered_map<std::string, int> node_ids;
get_id(const AnfNodePtr & node)360 std::string get_id(const AnfNodePtr &node) {
361   auto type_name = node->type_name();
362   if (node_ids.find(type_name) == node_ids.end()) {
363     node_ids[type_name] = 0;
364   } else {
365     node_ids[type_name]++;
366   }
367   return std::to_string(node_ids[type_name]);
368 }
369 
reset_id()370 void reset_id() { node_ids.clear(); }
371 }  // namespace id_generator
// Sentinel meaning "no device target could be determined".
constexpr auto kTargetUnDefined = "kTargetUnDefined";
// Attribute / user-data key under which a node's device target is stored.
constexpr auto kPrimitiveTarget = "primitive_target";
374 namespace {
GetPrimitiveFromValueNode(const AnfNodePtr & node)375 PrimitivePtr GetPrimitiveFromValueNode(const AnfNodePtr &node) {
376   if (node == nullptr) {
377     return nullptr;
378   }
379   auto value_node = node->cast<ValueNodePtr>();
380   if (value_node == nullptr) {
381     return nullptr;
382   }
383   auto value = value_node->value();
384   if (value == nullptr || !value->isa<Primitive>()) {
385     return nullptr;
386   }
387   return value->cast<PrimitivePtr>();
388 }
389 
GetVirtualNodeTargetFromInputs(const AnfNodePtr & node)390 std::string GetVirtualNodeTargetFromInputs(const AnfNodePtr &node) {
391   MS_EXCEPTION_IF_NULL(node);
392   auto cnode = node->cast<CNodePtr>();
393   MS_EXCEPTION_IF_NULL(cnode);
394   auto &inputs = cnode->inputs();
395 #ifndef ENABLE_SECURITY
396   if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
397       IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary)) {
398     if (inputs.size() > 1) {
399       return GetOriginNodeTarget(inputs[1]);
400     }
401     return kTargetUnDefined;
402   }
403 #endif
404   if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad)) {
405     const size_t node_inputs_num = 3;
406     if (inputs.size() >= node_inputs_num) {
407       size_t use_index = 1;
408       if (!inputs[use_index]->isa<CNode>()) {
409         use_index = 2;
410       }
411       return GetOriginNodeTarget(inputs[use_index]);
412     }
413   } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) || IsPrimitiveCNode(node, prim::kPrimUpdateState)) {
414     std::vector<AnfNodePtr> real_inputs;
415     const size_t update_state_valid_input_index = 2;
416     const size_t make_tuple_valid_input_index = 1;
417     if (IsPrimitiveCNode(node, prim::kPrimUpdateState) && inputs.size() > update_state_valid_input_index) {
418       (void)std::copy(inputs.begin() + SizeToLong(update_state_valid_input_index), inputs.end(),
419                       std::back_inserter(real_inputs));
420     } else if (IsPrimitiveCNode(node, prim::kPrimMakeTuple) && inputs.size() > make_tuple_valid_input_index) {
421       (void)std::copy(inputs.begin() + SizeToLong(make_tuple_valid_input_index), inputs.end(),
422                       std::back_inserter(real_inputs));
423     }
424     std::string first_input_target = kTargetUnDefined;
425     bool has_diff_target =
426       std::any_of(std::rbegin(real_inputs), std::rend(real_inputs), [&first_input_target](const AnfNodePtr &n) {
427         auto target = GetOriginNodeTarget(n);
428         if (target == kTargetUnDefined) {
429           return false;
430         }
431         if (first_input_target == kTargetUnDefined) {
432           first_input_target = target;
433         }
434         return target != first_input_target;
435       });
436     if (!has_diff_target) {
437       return first_input_target;
438     }
439   } else if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
440     return GetOriginNodeTarget(cnode->input(1));
441   }
442   return kTargetUnDefined;
443 }
444 
GetVirtualNodeTargetFromUsers(const AnfNodePtr & node)445 std::string GetVirtualNodeTargetFromUsers(const AnfNodePtr &node) {
446   MS_EXCEPTION_IF_NULL(node);
447   auto cnode = node->cast<CNodePtr>();
448   MS_EXCEPTION_IF_NULL(cnode);
449   auto func_graph = cnode->func_graph();
450   if (func_graph == nullptr) {
451     return kTargetUnDefined;
452   }
453   auto manager = func_graph->manager();
454   if (manager == nullptr) {
455     return kTargetUnDefined;
456   }
457   auto users = manager->node_users()[cnode];
458   std::string first_user_target = kTargetUnDefined;
459   bool has_diff_target =
460     std::any_of(std::begin(users), std::end(users), [&first_user_target](const std::pair<AnfNodePtr, int> &u) {
461       auto target = GetOriginNodeTarget(u.first);
462       if (target == kTargetUnDefined) {
463         return false;
464       }
465       if (first_user_target == kTargetUnDefined) {
466         first_user_target = target;
467       }
468       return target != first_user_target;
469     });
470   if (!has_diff_target) {
471     return first_user_target;
472   }
473   return kTargetUnDefined;
474 }
475 
GetVirtualNodeTarget(const AnfNodePtr & node)476 std::string GetVirtualNodeTarget(const AnfNodePtr &node) {
477   MS_EXCEPTION_IF_NULL(node);
478   node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(kTargetUnDefined));
479   auto target = GetVirtualNodeTargetFromInputs(node);
480   node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
481   if (target != kTargetUnDefined) {
482     return target;
483   }
484   target = GetVirtualNodeTargetFromUsers(node);
485   node->set_user_data(kPrimitiveTarget, std::make_shared<std::string>(target));
486   return target;
487 }
488 
GetTargetFromAttr(const AnfNodePtr & node)489 std::string GetTargetFromAttr(const AnfNodePtr &node) {
490   MS_EXCEPTION_IF_NULL(node);
491   auto cnode = node->cast<CNodePtr>();
492   MS_EXCEPTION_IF_NULL(cnode);
493   auto attr_input = cnode->input(0);
494   auto primitive = GetPrimitiveFromValueNode(attr_input);
495   if (primitive == nullptr) {
496     return kTargetUnDefined;
497   }
498   auto att_target = primitive->GetAttr(kPrimitiveTarget);
499   if (att_target != nullptr) {
500     if (!att_target->isa<StringImm>()) {
501       MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target";
502     }
503     auto target = GetValue<std::string>(att_target);
504     if (kTargetSet.find(target) == kTargetSet.end()) {
505       MS_LOG(EXCEPTION) << "Only support string CPU|GPU|Ascend for primitive_target, but get " << target;
506     }
507     return target;
508   }
509   return kTargetUnDefined;
510 }
511 }  // namespace
512 
GetOriginNodeTarget(const AnfNodePtr & node)513 std::string GetOriginNodeTarget(const AnfNodePtr &node) {
514   MS_EXCEPTION_IF_NULL(node);
515   if (!node->isa<CNode>()) {
516     return kTargetUnDefined;
517   }
518   auto cnode = node->cast<CNodePtr>();
519   MS_EXCEPTION_IF_NULL(cnode);
520   auto ud_target = cnode->user_data<std::string>(kPrimitiveTarget);
521   if (ud_target != nullptr) {
522     return *ud_target.get();
523   }
524   auto target = GetTargetFromAttr(node);
525   if (target != kTargetUnDefined) {
526     return target;
527   }
528 #ifndef ENABLE_SECURITY
529   if (IsPrimitiveCNode(node, prim::kPrimImageSummary) || IsPrimitiveCNode(node, prim::kPrimScalarSummary) ||
530       IsPrimitiveCNode(node, prim::kPrimTensorSummary) || IsPrimitiveCNode(node, prim::kPrimHistogramSummary) ||
531       IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
532       IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
533       IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
534     return GetVirtualNodeTarget(node);
535   }
536 #else
537   if (IsPrimitiveCNode(node, prim::kPrimDepend) || IsPrimitiveCNode(node, prim::kPrimLoad) ||
538       IsPrimitiveCNode(node, prim::kPrimUpdateState) || IsPrimitiveCNode(node, prim::kPrimMakeTuple) ||
539       IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
540     return GetVirtualNodeTarget(node);
541   }
542 #endif
543   auto context_ptr = MsContext::GetInstance();
544   MS_EXCEPTION_IF_NULL(context_ptr);
545   return context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
546 }
547 
GetCNodeTarget(const AnfNodePtr & node)548 std::string GetCNodeTarget(const AnfNodePtr &node) {
549   auto context_ptr = MsContext::GetInstance();
550   MS_EXCEPTION_IF_NULL(context_ptr);
551   std::string default_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
552   auto target = GetOriginNodeTarget(node);
553   if (target != kTargetUnDefined) {
554     return target;
555   }
556   return default_target;
557 }
558 
ContainMultiTarget(const std::vector<AnfNodePtr> & nodes)559 bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes) {
560   auto context_ptr = MsContext::GetInstance();
561   MS_EXCEPTION_IF_NULL(context_ptr);
562   std::string last_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
563   for (auto &node : nodes) {
564     if (node->isa<CNode>()) {
565       std::string cur_target = GetCNodeTarget(node);
566       if (last_target != cur_target) {
567         return true;
568       }
569       last_target = cur_target;
570     }
571   }
572   return false;
573 }
574 }  // namespace mindspore
575