• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "include/transform/graph_ir/utils.h"
17 #include "mindspore/core/ops/other_ops.h"
18 #include "mindspore/core/ops/framework_ops.h"
19 #include "mindspore/core/ops/sequence_op_name.h"
20 #include "transform/graph_ir/aoe_util.h"
21 #include "transform/graph_ir/convert.h"
22 #include "transform/graph_ir/op_adapter_map.h"
23 #include "transform/graph_ir/op_adapter_util.h"
24 #include "transform/graph_ir/df_graph_manager.h"
25 #include "transform/graph_ir/op_adapter_desc.h"
26 #include "transform/graph_ir/transform_util.h"
27 #include "transform/graph_ir/graph_builder.h"
28 #include "include/common/utils/anfalgo.h"
29 
30 namespace mindspore {
31 namespace transform {
32 namespace {
33 constexpr size_t kSwitchInputSize = 4;
34 constexpr size_t kSwitchCondIndex = 1;
35 constexpr size_t kSwitchTrueBranchIndex = 2;
36 constexpr size_t kSwitchFalseBranchIndex = 3;
37 constexpr size_t kPartialCNodeValue = 1;
38 }  // namespace
39 
FindAdapter(const AnfNodePtr node,bool train)40 OpAdapterPtr FindAdapter(const AnfNodePtr node, bool train) {
41   MS_EXCEPTION_IF_NULL(node);
42   if (node->isa<CNode>()) {
43     auto cnode = node->cast<CNodePtr>();
44 
45     std::string name = kNameCustomOp;
46     if (!IsCustomCNode(cnode)) {
47       name = GetCNodeTargetFuncName(cnode);
48     }
49 
50     // Convert TupleGetItem to control edge when it has monad.
51     if (name == kNameTupleGetItem) {
52       if (HasAbstractMonad(node)) {
53         name = kNameUpdateState;
54       }
55     }
56     auto it_adpt = OpAdapterMap::get().find(name);
57     if (it_adpt != OpAdapterMap::get().end()) {
58       return it_adpt->second->Get(train);
59     }
60 
61     std::set<std::string> cpu_only_ops{kRealMakeTupleOpName, kRealTupleGetItemOpName};
62     auto iter = cpu_only_ops.find(name);
63     if (iter != cpu_only_ops.end()) {
64       MS_LOG(INFO) << "Can't find OpAdapter for " << name;
65     } else {
66       MS_LOG(WARNING) << "Can't find OpAdapter for " << name;
67     }
68 
69     return OpAdapterPtr(nullptr);
70   }
71 
72   if (node->isa<ValueNode>()) {
73     return OpAdapterMap::get()[kNameConst]->Get(train);
74   }
75   if (node->isa<Parameter>()) {
76     return OpAdapterMap::get()[kNameParam]->Get(train);
77   }
78   return OpAdapterPtr(nullptr);
79 }
80 
FindAdapter(const std::string & name,bool train)81 OpAdapterPtr FindAdapter(const std::string &name, bool train) {
82   auto it = OpAdapterMap::get().find(name);
83   if (it != OpAdapterMap::get().end()) {
84     return it->second->Get(train);
85   }
86 
87   std::set<std::string> cpu_only_ops{kRealMakeTupleOpName, kRealTupleGetItemOpName, kShapeCalcOpName};
88   auto iter = cpu_only_ops.find(name);
89   // If ops in cpu only list or ops is scalar ops or is sequence ops
90   if (iter != cpu_only_ops.end() || name.find("Scalar") != std::string::npos ||
91       name.find("Sequence") != std::string::npos || name.find("Tuple") != std::string::npos ||
92       name.find("List") != std::string::npos) {
93     MS_LOG(INFO) << "Can't find OpAdapter for " << name;
94     return nullptr;
95   }
96   MS_LOG(WARNING) << "Can't find OpAdapter for " << name;
97   return nullptr;
98 }
99 
ClearGeSessionAndRunner()100 void ClearGeSessionAndRunner() {
101   DfGraphManager::GetInstance().DeleteGraphRunner();
102   DfGraphManager::GetInstance().DeleteGeSession();
103 }
104 
IsPartialSuccNode(const AnfNodePtr node)105 bool IsPartialSuccNode(const AnfNodePtr node) {
106   MS_EXCEPTION_IF_NULL(node);
107   if (!node->isa<CNode>()) {
108     return false;
109   }
110   auto cnode = node->cast<CNodePtr>();
111   if (!cnode->inputs().empty()) {
112     for (size_t i = 0; i < cnode->size(); i++) {
113       if (IsPartialCNode(cnode->input(i))) {
114         return true;
115       }
116     }
117   }
118   return false;
119 }
120 
IsPartialCNode(const AnfNodePtr node)121 bool IsPartialCNode(const AnfNodePtr node) {
122   MS_EXCEPTION_IF_NULL(node);
123   if (!node->isa<CNode>()) {
124     return false;
125   }
126   auto cnode = node->cast<CNodePtr>();
127   if (GetCNodeFuncName(cnode) == prim::kPrimPartial->name()) {
128     return true;
129   }
130   return false;
131 }
132 
IsWhileNode(const AnfNodePtr & node)133 bool IsWhileNode(const AnfNodePtr &node) {
134   if (!node->isa<CNode>()) {
135     return false;
136   }
137   auto graph = node->func_graph();
138   MS_EXCEPTION_IF_NULL(graph);
139   bool in_kg = graph->type_name() == kKernelGraphTypeName;
140   auto cnode = node->cast<CNodePtr>();
141   ValueNodePtr graph_node = nullptr;
142   if (in_kg && IsPrimitiveCNode(node, prim::kPrimCall) && cnode->input(1)->isa<ValueNode>()) {
143     graph_node = cnode->input(1)->cast<ValueNodePtr>();
144   }
145   if (!in_kg) {
146     if (IsPrimitiveCNode(cnode->input(0), prim::kPrimPartial)) {
147       auto partial_node = cnode->input(0)->cast<CNodePtr>();
148       MS_EXCEPTION_IF_NULL(partial_node);
149       auto graph_node_input = partial_node->input(1);
150       MS_EXCEPTION_IF_NULL(graph_node_input);
151       graph_node = graph_node_input->cast<ValueNodePtr>();
152     } else if (cnode->input(0)->cast<ValueNodePtr>()) {
153       graph_node = cnode->input(0)->cast<ValueNodePtr>();
154     }
155   }
156   if (graph_node == nullptr) {
157     return false;
158   }
159 
160   auto graph_node_value = graph_node->value();
161   MS_EXCEPTION_IF_NULL(graph_node_value);
162   if (!graph_node_value->isa<FuncGraph>()) {
163     return false;
164   }
165   auto cond_graph = graph_node_value->cast<FuncGraphPtr>();
166   MS_EXCEPTION_IF_NULL(cond_graph);
167   if (!cond_graph->recursive()) {
168     return false;
169   }
170   const auto &cond_set = cond_graph->nodes();
171   for (auto beg = cond_set.begin(); beg != cond_set.end(); ++beg) {
172     if (!((*beg)->isa<CNode>())) {
173       continue;
174     }
175     auto c_beg = (*beg)->cast<CNodePtr>();
176     if (IsPrimitiveCNode(c_beg, prim::kPrimSwitch)) {
177       auto func_graph = node->func_graph();
178       MS_LOG(DEBUG) << "There is while node: " << node->ToString() << " in graph: " << func_graph->ToString();
179       return true;
180     }
181   }
182   return false;
183 }
184 
IsCallNode(const AnfNodePtr & node)185 bool IsCallNode(const AnfNodePtr &node) {
186   MS_EXCEPTION_IF_NULL(node);
187   if (!node->isa<CNode>()) {
188     return false;
189   }
190   auto graph = node->func_graph();
191   MS_EXCEPTION_IF_NULL(graph);
192   bool in_kg = graph->type_name() == kKernelGraphTypeName;
193   auto cnode = node->cast<CNodePtr>();
194   MS_EXCEPTION_IF_NULL(cnode);
195   if (in_kg && IsPrimitiveCNode(node, prim::kPrimCall) && cnode->input(1) != nullptr &&
196       cnode->input(1)->isa<ValueNode>()) {
197     return true;
198   }
199   return false;
200 }
201 
CheckSwitchBranch(const AnfNodePtr & node)202 bool CheckSwitchBranch(const AnfNodePtr &node) {
203   AnfNodePtr value_node = nullptr;
204   if (IsPartialCNode(node)) {
205     auto cnode = node->cast<CNodePtr>();
206     MS_EXCEPTION_IF_NULL(cnode);
207     value_node = cnode->input(kPartialCNodeValue);
208   } else if (IsValueNode<FuncGraph>(node)) {
209     value_node = node;
210   } else {
211     return false;
212   }
213   auto graph = GetValueNode<FuncGraphPtr>(value_node);
214   MS_EXCEPTION_IF_NULL(graph);
215   if (graph->recursive()) {
216     return false;
217   }
218   return true;
219 }
220 
IsIfNode(const AnfNodePtr & node)221 bool IsIfNode(const AnfNodePtr &node) {
222   if (!node->isa<CNode>()) {
223     return false;
224   }
225   auto graph = node->func_graph();
226   MS_EXCEPTION_IF_NULL(graph);
227   bool in_kg = graph->type_name() == kKernelGraphTypeName;
228   auto cnode = node->cast<CNodePtr>();
229   MS_EXCEPTION_IF_NULL(cnode);
230   CNodePtr switch_node = nullptr;
231   if (in_kg && IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
232     switch_node = cnode;
233   } else if (!in_kg && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
234     switch_node = cnode->input(0)->cast<CNodePtr>();
235   } else {
236     return false;
237   }
238   auto true_branch = switch_node->input(kSwitchTrueBranchIndex);
239   MS_EXCEPTION_IF_NULL(true_branch);
240   auto false_branch = switch_node->input(kSwitchFalseBranchIndex);
241   MS_EXCEPTION_IF_NULL(false_branch);
242 
243   if (!CheckSwitchBranch(switch_node->input(kSwitchTrueBranchIndex))) {
244     return false;
245   }
246   auto func_graph = node->func_graph();
247   MS_LOG(DEBUG) << "There is if node: " << node->ToString() << " in graph: " << func_graph->ToString();
248   return true;
249 }
250 
IsInitDataSetQueueNode(const AnfNodePtr & node)251 bool IsInitDataSetQueueNode(const AnfNodePtr &node) {
252   if (!node->isa<CNode>()) {
253     return false;
254   }
255   auto cnode = node->cast<CNodePtr>();
256   MS_EXCEPTION_IF_NULL(cnode);
257   if (IsPrimitiveCNode(cnode, prim::kPrimInitDataSetQueue)) {
258     return true;
259   }
260   return false;
261 }
262 
GetCNodeTargetFuncName(const CNodePtr cnode)263 std::string GetCNodeTargetFuncName(const CNodePtr cnode) {
264   if (IsCaseNode(cnode)) {
265     return string(kNameCase);
266   }
267   if (IsWhileNode(cnode)) {
268     return string(kNameWhile);
269   }
270   if (IsIfNode(cnode)) {
271     return string(kNameIf);
272   }
273   if (IsCallNode(cnode)) {
274     return string(kNamePartitionedCall);
275   }
276   return GetCNodeFuncName(cnode);
277 }
278 
IsCaseNode(const AnfNodePtr & node)279 bool IsCaseNode(const AnfNodePtr &node) {
280   MS_EXCEPTION_IF_NULL(node);
281   if (!node->isa<CNode>()) {
282     return false;
283   }
284   auto cnode = node->cast<CNodePtr>();
285   MS_EXCEPTION_IF_NULL(cnode);
286   auto graph = node->func_graph();
287   MS_EXCEPTION_IF_NULL(graph);
288   bool in_kg = graph->type_name() == kKernelGraphTypeName;
289   if (in_kg && IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
290     return true;
291   }
292   if (!in_kg && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitchLayer)) {
293     return true;
294   }
295   return false;
296 }
297 
ConvertInputTensors(const std::vector<MeTensorPtr> & me_tensors,const std::string & format)298 std::vector<GeTensorPtr> ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors, const std::string &format) {
299   return TransformUtil::ConvertInputTensors(me_tensors, format);
300 }
301 
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors)302 std::vector<MeTensorPtr> ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
303   return TransformUtil::ConvertGeTensors(ge_tensors);
304 }
305 
ConvertDataType(const MeDataType & type)306 GeDataType ConvertDataType(const MeDataType &type) { return TransformUtil::ConvertDataType(type); }
307 
ConvertGeTensor(const GeTensorPtr & ge_tensor,const ShapeVector & request_dims,bool ref_mem)308 MeTensorPtr ConvertGeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &request_dims, bool ref_mem) {
309   return TransformUtil::ConvertGeTensor(ge_tensor, request_dims, ref_mem);
310 }
311 
ConvertGeTensor(const GeTensorPtr & tensor)312 MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor) { return TransformUtil::ConvertGeTensor(tensor); }
313 
ConvertGeTensor(const GeTensorPtr & tensor,const TypeId & me_type)314 MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor, const TypeId &me_type) {
315   return TransformUtil::ConvertGeTensor(tensor, me_type);
316 }
317 
GetGraphRunner()318 std::shared_ptr<transform::GraphRunner> GetGraphRunner() { return DfGraphManager::GetInstance().GetGraphRunner(); }
319 
CheckAndGetGraphRunner(const transform::RunOptions & run_options)320 std::shared_ptr<transform::GraphRunner> CheckAndGetGraphRunner(const transform::RunOptions &run_options) {
321   if (transform::GetGraphByName(run_options.name) == nullptr) {
322     MS_LOG(WARNING) << "Can not find " << run_options.name
323                     << " sub graph, don't need data init subgraph in INFER mode.";
324     return nullptr;
325   }
326 
327   auto graph_runner = transform::GetGraphRunner();
328   if (graph_runner == nullptr) {
329     MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
330   }
331   return graph_runner;
332 }
333 
GetGeSession()334 std::shared_ptr<::ge::Session> GetGeSession() { return DfGraphManager::GetInstance().GetGeSession(); }
335 
SetGeSession(const std::shared_ptr<::ge::Session> & sess_ptr)336 void SetGeSession(const std::shared_ptr<::ge::Session> &sess_ptr) {
337   DfGraphManager::GetInstance().SetGeSession(sess_ptr);
338 }
339 
NewGraphRunner(const GraphRunnerOptions & options)340 GraphRunnerPtr NewGraphRunner(const GraphRunnerOptions &options) {
341   auto graph_runner = std::make_shared<transform::GraphRunner>(options);
342   return graph_runner;
343 }
344 
SetGraphRunner(const GraphRunnerPtr & runner)345 void SetGraphRunner(const GraphRunnerPtr &runner) { DfGraphManager::GetInstance().SetGraphRunner(runner); }
ClearGraph()346 void ClearGraph() { DfGraphManager::GetInstance().ClearGraph(); }
347 
AddGraph(const std::string & name,const DfGraphPtr & graph,const OptionMap & options,const bool & is_cloud,const bool & need_aoe)348 Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options, const bool &is_cloud,
349                 const bool &need_aoe) {
350   auto ret = DfGraphManager::GetInstance().AddGraph(name, graph, options, is_cloud);
351   if (ret != Status::SUCCESS) {
352     return ret;
353   }
354   if (need_aoe) {
355     transform::AddOptimizeGraph(name);
356     transform::DfGraphManager::GetInstance().AoeGeGraph();
357   }
358   auto graph_runner = transform::GetGraphRunner();
359   if (graph_runner == nullptr) {
360     // lite may not use graph_runner
361     MS_LOG(INFO) << "There is no GraphRunner.";
362     return ret;
363   }
364   return graph_runner->AddGraph(name);
365 }
366 
SetAnfGraph(const std::string & name,const AnfGraphPtr & anf_graph_ptr)367 void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) {
368   DfGraphManager::GetInstance().SetAnfGraph(name, anf_graph_ptr);
369 }
370 
GetAnfGraph(uint32_t graph_id)371 FuncGraphPtr GetAnfGraph(uint32_t graph_id) { return DfGraphManager::GetInstance().GetAnfGraph(graph_id); }
372 
GetGraphByName(const std::string & name)373 DfGraphWrapperPtr GetGraphByName(const std::string &name) { return DfGraphManager::GetInstance().GetGraphByName(name); }
374 
AddOptimizeGraph(const std::string & name)375 void AddOptimizeGraph(const std::string &name) { AoeUtil::GetInstance().AddOptimizeGraph(name); }
376 
InitializeAoeUtil()377 void InitializeAoeUtil() { AoeUtil::GetInstance().Initialize(); }
378 
DestroyAoeUtil()379 void DestroyAoeUtil() { AoeUtil::GetInstance().Destroy(); }
380 
EnableAoeOffline()381 void EnableAoeOffline() { AoeUtil::GetInstance().SetOfflineEnvDumpGeGraph(); }
382 
383 // convert
384 
NewConverter(const FuncGraphPtr & graph,const std::string & phase_prefix,RefModeFlag ref_mode_type,bool offline_convert)385 DfGraphConvertorPtr NewConverter(const FuncGraphPtr &graph, const std::string &phase_prefix, RefModeFlag ref_mode_type,
386                                  bool offline_convert) {
387   std::vector<std::string> extra_variables_names = {};
388   auto converter = std::make_shared<transform::DfGraphConvertor>(graph, phase_prefix, ref_mode_type,
389                                                                  extra_variables_names, nullptr, offline_convert);
390   return converter;
391 }
392 
SetTraining(const DfGraphConvertorPtr & converter,bool training)393 void SetTraining(const DfGraphConvertorPtr &converter, bool training) {
394   MS_EXCEPTION_IF_NULL(converter);
395   converter->set_training(training);
396 }
397 
SetExportAir(const DfGraphConvertorPtr & converter,bool export_air)398 void SetExportAir(const DfGraphConvertorPtr &converter, bool export_air) {
399   MS_EXCEPTION_IF_NULL(converter);
400   converter->set_export_air(export_air);
401 }
402 
BuildGraph(const std::string & name,const DfGraphConvertorPtr & converter,const std::map<std::string,std::shared_ptr<tensor::Tensor>> & maps)403 void BuildGraph(const std::string &name, const DfGraphConvertorPtr &converter,
404                 const std::map<std::string, std::shared_ptr<tensor::Tensor>> &maps) {
405   MS_EXCEPTION_IF_NULL(converter);
406   (void)converter->ConvertAllNode().InitParam(maps).BuildGraph(name);
407 }
408 
GenerateBroadcastGraph(const DfGraphConvertorPtr & converter,const TensorOrderMap & tensors)409 void GenerateBroadcastGraph(const DfGraphConvertorPtr &converter, const TensorOrderMap &tensors) {
410   MS_EXCEPTION_IF_NULL(converter);
411   (void)converter->GenerateBroadcastGraph(tensors);
412 }
GenerateCheckpointGraph(const DfGraphConvertorPtr & converter)413 void GenerateCheckpointGraph(const DfGraphConvertorPtr &converter) {
414   MS_EXCEPTION_IF_NULL(converter);
415   (void)converter->GenerateCheckpointGraph();
416 }
ErrCode(const DfGraphConvertorPtr & converter)417 int ErrCode(const DfGraphConvertorPtr &converter) {
418   MS_EXCEPTION_IF_NULL(converter);
419   return converter->ErrCode();
420 }
421 
GenFakeGraph(const std::string & name,const DfGraphConvertorPtr & converter)422 void GenFakeGraph(const std::string &name, const DfGraphConvertorPtr &converter) {
423   MS_EXCEPTION_IF_NULL(converter);
424   converter->GenFakeGraph(name);
425 }
426 
GetComputeGraph(const DfGraphConvertorPtr & converter)427 DfGraphPtr GetComputeGraph(const DfGraphConvertorPtr &converter) {
428   MS_EXCEPTION_IF_NULL(converter);
429   return converter->GetComputeGraph();
430 }
GetInitGraph(const DfGraphConvertorPtr & converter)431 DfGraphPtr GetInitGraph(const DfGraphConvertorPtr &converter) {
432   MS_EXCEPTION_IF_NULL(converter);
433   return converter->GetInitGraph();
434 }
GetSaveCheckpointGraph(const DfGraphConvertorPtr & converter)435 DfGraphPtr GetSaveCheckpointGraph(const DfGraphConvertorPtr &converter) {
436   MS_EXCEPTION_IF_NULL(converter);
437   return converter->GetSaveCheckpointGraph();
438 }
GetBroadcastGraph(const DfGraphConvertorPtr & converter)439 DfGraphPtr GetBroadcastGraph(const DfGraphConvertorPtr &converter) {
440   MS_EXCEPTION_IF_NULL(converter);
441   return converter->GetBroadcastGraph();
442 }
443 
NewSession(const SessionOptions & sess_options)444 std::shared_ptr<::ge::Session> NewSession(const SessionOptions &sess_options) {
445   return transform::GraphRunner::NewSession(sess_options);
446 }
447 
RunGraph(const std::shared_ptr<transform::GraphRunner> & runner,const RunOptions & options,const std::vector<GeTensorPtr> & inputs,std::vector<GeTensorPtr> * outputs)448 Status RunGraph(const std::shared_ptr<transform::GraphRunner> &runner, const RunOptions &options,
449                 const std::vector<GeTensorPtr> &inputs, std::vector<GeTensorPtr> *outputs) {
450   MS_EXCEPTION_IF_NULL(runner);
451   return runner->RunGraph(options, inputs, outputs);
452 }
453 
RunGraphAsync(const std::shared_ptr<GraphRunner> & runner,const RunOptions & options,const std::vector<GeTensorPtr> & inputs,std::vector<GeTensorPtr> * outputs)454 Status RunGraphAsync(const std::shared_ptr<GraphRunner> &runner, const RunOptions &options,
455                      const std::vector<GeTensorPtr> &inputs, std::vector<GeTensorPtr> *outputs) {
456   MS_EXCEPTION_IF_NULL(runner);
457   return runner->RunGraphAsync(options, inputs, outputs);
458 }
459 
RunGraphWithStreamAsync(const std::shared_ptr<GraphRunner> & runner,const RunOptions & options,void * stream,const std::vector<GeTensor> & inputs,std::vector<GeTensor> * outputs)460 Status RunGraphWithStreamAsync(const std::shared_ptr<GraphRunner> &runner, const RunOptions &options, void *stream,
461                                const std::vector<GeTensor> &inputs, std::vector<GeTensor> *outputs) {
462   MS_EXCEPTION_IF_NULL(runner);
463   return runner->RunGraphWithStreamAsync(options, stream, inputs, outputs);
464 }
465 
RegisterExternalAllocator(const std::shared_ptr<GraphRunner> & runner,const void * const stream,GeAllocatorPtr allocator)466 Status RegisterExternalAllocator(const std::shared_ptr<GraphRunner> &runner, const void *const stream,
467                                  GeAllocatorPtr allocator) {
468   MS_EXCEPTION_IF_NULL(runner);
469   return runner->RegisterExternalAllocator(stream, allocator);
470 }
471 
UnregisterExternalAllocator(const std::shared_ptr<GraphRunner> & runner,const void * const stream)472 Status UnregisterExternalAllocator(const std::shared_ptr<GraphRunner> &runner, const void *const stream) {
473   MS_EXCEPTION_IF_NULL(runner);
474   return runner->UnregisterExternalAllocator(stream);
475 }
476 
CompileDatasetGraph(const DatasetGraphParam & param,const std::string & phase)477 transform::Status CompileDatasetGraph(const DatasetGraphParam &param, const std::string &phase) {
478   return BuildDatasetGraph(param, phase);
479 }
480 
ConvertCheck(const AnfNodePtr & node)481 bool ConvertCheck(const AnfNodePtr &node) {
482   if (!node->cast<CNodePtr>() || !AnfUtils::IsRealKernel(node)) {
483     return true;
484   }
485   PrimitivePtr prim = common::AnfAlgo::GetCNodePrimitive(node);
486   auto &adapter_map = OpAdapterMap::get();
487   return adapter_map.find(prim->name()) != adapter_map.end();
488 }
489 
DynamicShapeSupportCheck(const AnfNodePtr & node,bool train)490 bool DynamicShapeSupportCheck(const AnfNodePtr &node, bool train) {
491   auto adpt = FindAdapter(node, train);
492   MS_EXCEPTION_IF_NULL(adpt);
493   return adpt->GetDynamicShapeSupport();
494 }
495 
SinkGraphCheck(const AnfNodePtr & node,bool train)496 bool SinkGraphCheck(const AnfNodePtr &node, bool train) {
497   PrimitivePtr prim = common::AnfAlgo::GetCNodePrimitive(node);
498   auto adpt = FindAdapter(prim->name(), train);
499   MS_EXCEPTION_IF_NULL(adpt);
500   auto input_attr_map = adpt->getInputAttrMap();
501   auto cnode = node->cast<CNodePtr>();
502   MS_EXCEPTION_IF_NULL(cnode);
503   auto input_size = cnode->size();
504   for (auto &it : input_attr_map) {
505     if (it.first >= input_size) {
506       continue;
507     }
508     if (!cnode->input(it.first)->isa<ValueNode>()) {
509       MS_LOG(DEBUG) << node->fullname_with_scope() << " inputs[" << it.first << "]"
510                     << " is not a ValueNode";
511       return false;
512     }
513   }
514   auto input_map = adpt->getInputMap();
515   for (auto &it : input_map) {
516     if (static_cast<size_t>(it.first) >= input_size) {
517       continue;
518     }
519     auto abs = cnode->input(it.first)->abstract();
520     MS_EXCEPTION_IF_NULL(abs);
521     if (abs->isa<abstract::AbstractAny>()) {
522       MS_LOG(DEBUG) << node->fullname_with_scope() << " inputs[" << it.first << "]"
523                     << " is a AbstractAny";
524       return false;
525     }
526   }
527   return true;
528 }
529 }  // namespace transform
530 }  // namespace mindspore
531