1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "include/transform/graph_ir/utils.h"
17 #include "mindspore/core/ops/other_ops.h"
18 #include "mindspore/core/ops/framework_ops.h"
19 #include "mindspore/core/ops/sequence_op_name.h"
20 #include "transform/graph_ir/aoe_util.h"
21 #include "transform/graph_ir/convert.h"
22 #include "transform/graph_ir/op_adapter_map.h"
23 #include "transform/graph_ir/op_adapter_util.h"
24 #include "transform/graph_ir/df_graph_manager.h"
25 #include "transform/graph_ir/op_adapter_desc.h"
26 #include "transform/graph_ir/transform_util.h"
27 #include "transform/graph_ir/graph_builder.h"
28 #include "include/common/utils/anfalgo.h"
29
30 namespace mindspore {
31 namespace transform {
namespace {
// Input positions of a Switch CNode. Index 0 holds the Switch primitive, followed by
// the condition, the true branch and the false branch.
constexpr size_t kSwitchInputSize{4};
constexpr size_t kSwitchCondIndex{1};
constexpr size_t kSwitchTrueBranchIndex{2};
constexpr size_t kSwitchFalseBranchIndex{3};
// Position of the callee value node inside a Partial CNode (index 0 is the Partial primitive).
constexpr size_t kPartialCNodeValue{1};
}  // namespace
39
FindAdapter(const AnfNodePtr node,bool train)40 OpAdapterPtr FindAdapter(const AnfNodePtr node, bool train) {
41 MS_EXCEPTION_IF_NULL(node);
42 if (node->isa<CNode>()) {
43 auto cnode = node->cast<CNodePtr>();
44
45 std::string name = kNameCustomOp;
46 if (!IsCustomCNode(cnode)) {
47 name = GetCNodeTargetFuncName(cnode);
48 }
49
50 // Convert TupleGetItem to control edge when it has monad.
51 if (name == kNameTupleGetItem) {
52 if (HasAbstractMonad(node)) {
53 name = kNameUpdateState;
54 }
55 }
56 auto it_adpt = OpAdapterMap::get().find(name);
57 if (it_adpt != OpAdapterMap::get().end()) {
58 return it_adpt->second->Get(train);
59 }
60
61 std::set<std::string> cpu_only_ops{kRealMakeTupleOpName, kRealTupleGetItemOpName};
62 auto iter = cpu_only_ops.find(name);
63 if (iter != cpu_only_ops.end()) {
64 MS_LOG(INFO) << "Can't find OpAdapter for " << name;
65 } else {
66 MS_LOG(WARNING) << "Can't find OpAdapter for " << name;
67 }
68
69 return OpAdapterPtr(nullptr);
70 }
71
72 if (node->isa<ValueNode>()) {
73 return OpAdapterMap::get()[kNameConst]->Get(train);
74 }
75 if (node->isa<Parameter>()) {
76 return OpAdapterMap::get()[kNameParam]->Get(train);
77 }
78 return OpAdapterPtr(nullptr);
79 }
80
FindAdapter(const std::string & name,bool train)81 OpAdapterPtr FindAdapter(const std::string &name, bool train) {
82 auto it = OpAdapterMap::get().find(name);
83 if (it != OpAdapterMap::get().end()) {
84 return it->second->Get(train);
85 }
86
87 std::set<std::string> cpu_only_ops{kRealMakeTupleOpName, kRealTupleGetItemOpName, kShapeCalcOpName};
88 auto iter = cpu_only_ops.find(name);
89 // If ops in cpu only list or ops is scalar ops or is sequence ops
90 if (iter != cpu_only_ops.end() || name.find("Scalar") != std::string::npos ||
91 name.find("Sequence") != std::string::npos || name.find("Tuple") != std::string::npos ||
92 name.find("List") != std::string::npos) {
93 MS_LOG(INFO) << "Can't find OpAdapter for " << name;
94 return nullptr;
95 }
96 MS_LOG(WARNING) << "Can't find OpAdapter for " << name;
97 return nullptr;
98 }
99
ClearGeSessionAndRunner()100 void ClearGeSessionAndRunner() {
101 DfGraphManager::GetInstance().DeleteGraphRunner();
102 DfGraphManager::GetInstance().DeleteGeSession();
103 }
104
IsPartialSuccNode(const AnfNodePtr node)105 bool IsPartialSuccNode(const AnfNodePtr node) {
106 MS_EXCEPTION_IF_NULL(node);
107 if (!node->isa<CNode>()) {
108 return false;
109 }
110 auto cnode = node->cast<CNodePtr>();
111 if (!cnode->inputs().empty()) {
112 for (size_t i = 0; i < cnode->size(); i++) {
113 if (IsPartialCNode(cnode->input(i))) {
114 return true;
115 }
116 }
117 }
118 return false;
119 }
120
IsPartialCNode(const AnfNodePtr node)121 bool IsPartialCNode(const AnfNodePtr node) {
122 MS_EXCEPTION_IF_NULL(node);
123 if (!node->isa<CNode>()) {
124 return false;
125 }
126 auto cnode = node->cast<CNodePtr>();
127 if (GetCNodeFuncName(cnode) == prim::kPrimPartial->name()) {
128 return true;
129 }
130 return false;
131 }
132
IsWhileNode(const AnfNodePtr & node)133 bool IsWhileNode(const AnfNodePtr &node) {
134 if (!node->isa<CNode>()) {
135 return false;
136 }
137 auto graph = node->func_graph();
138 MS_EXCEPTION_IF_NULL(graph);
139 bool in_kg = graph->type_name() == kKernelGraphTypeName;
140 auto cnode = node->cast<CNodePtr>();
141 ValueNodePtr graph_node = nullptr;
142 if (in_kg && IsPrimitiveCNode(node, prim::kPrimCall) && cnode->input(1)->isa<ValueNode>()) {
143 graph_node = cnode->input(1)->cast<ValueNodePtr>();
144 }
145 if (!in_kg) {
146 if (IsPrimitiveCNode(cnode->input(0), prim::kPrimPartial)) {
147 auto partial_node = cnode->input(0)->cast<CNodePtr>();
148 MS_EXCEPTION_IF_NULL(partial_node);
149 auto graph_node_input = partial_node->input(1);
150 MS_EXCEPTION_IF_NULL(graph_node_input);
151 graph_node = graph_node_input->cast<ValueNodePtr>();
152 } else if (cnode->input(0)->cast<ValueNodePtr>()) {
153 graph_node = cnode->input(0)->cast<ValueNodePtr>();
154 }
155 }
156 if (graph_node == nullptr) {
157 return false;
158 }
159
160 auto graph_node_value = graph_node->value();
161 MS_EXCEPTION_IF_NULL(graph_node_value);
162 if (!graph_node_value->isa<FuncGraph>()) {
163 return false;
164 }
165 auto cond_graph = graph_node_value->cast<FuncGraphPtr>();
166 MS_EXCEPTION_IF_NULL(cond_graph);
167 if (!cond_graph->recursive()) {
168 return false;
169 }
170 const auto &cond_set = cond_graph->nodes();
171 for (auto beg = cond_set.begin(); beg != cond_set.end(); ++beg) {
172 if (!((*beg)->isa<CNode>())) {
173 continue;
174 }
175 auto c_beg = (*beg)->cast<CNodePtr>();
176 if (IsPrimitiveCNode(c_beg, prim::kPrimSwitch)) {
177 auto func_graph = node->func_graph();
178 MS_LOG(DEBUG) << "There is while node: " << node->ToString() << " in graph: " << func_graph->ToString();
179 return true;
180 }
181 }
182 return false;
183 }
184
IsCallNode(const AnfNodePtr & node)185 bool IsCallNode(const AnfNodePtr &node) {
186 MS_EXCEPTION_IF_NULL(node);
187 if (!node->isa<CNode>()) {
188 return false;
189 }
190 auto graph = node->func_graph();
191 MS_EXCEPTION_IF_NULL(graph);
192 bool in_kg = graph->type_name() == kKernelGraphTypeName;
193 auto cnode = node->cast<CNodePtr>();
194 MS_EXCEPTION_IF_NULL(cnode);
195 if (in_kg && IsPrimitiveCNode(node, prim::kPrimCall) && cnode->input(1) != nullptr &&
196 cnode->input(1)->isa<ValueNode>()) {
197 return true;
198 }
199 return false;
200 }
201
CheckSwitchBranch(const AnfNodePtr & node)202 bool CheckSwitchBranch(const AnfNodePtr &node) {
203 AnfNodePtr value_node = nullptr;
204 if (IsPartialCNode(node)) {
205 auto cnode = node->cast<CNodePtr>();
206 MS_EXCEPTION_IF_NULL(cnode);
207 value_node = cnode->input(kPartialCNodeValue);
208 } else if (IsValueNode<FuncGraph>(node)) {
209 value_node = node;
210 } else {
211 return false;
212 }
213 auto graph = GetValueNode<FuncGraphPtr>(value_node);
214 MS_EXCEPTION_IF_NULL(graph);
215 if (graph->recursive()) {
216 return false;
217 }
218 return true;
219 }
220
IsIfNode(const AnfNodePtr & node)221 bool IsIfNode(const AnfNodePtr &node) {
222 if (!node->isa<CNode>()) {
223 return false;
224 }
225 auto graph = node->func_graph();
226 MS_EXCEPTION_IF_NULL(graph);
227 bool in_kg = graph->type_name() == kKernelGraphTypeName;
228 auto cnode = node->cast<CNodePtr>();
229 MS_EXCEPTION_IF_NULL(cnode);
230 CNodePtr switch_node = nullptr;
231 if (in_kg && IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
232 switch_node = cnode;
233 } else if (!in_kg && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
234 switch_node = cnode->input(0)->cast<CNodePtr>();
235 } else {
236 return false;
237 }
238 auto true_branch = switch_node->input(kSwitchTrueBranchIndex);
239 MS_EXCEPTION_IF_NULL(true_branch);
240 auto false_branch = switch_node->input(kSwitchFalseBranchIndex);
241 MS_EXCEPTION_IF_NULL(false_branch);
242
243 if (!CheckSwitchBranch(switch_node->input(kSwitchTrueBranchIndex))) {
244 return false;
245 }
246 auto func_graph = node->func_graph();
247 MS_LOG(DEBUG) << "There is if node: " << node->ToString() << " in graph: " << func_graph->ToString();
248 return true;
249 }
250
IsInitDataSetQueueNode(const AnfNodePtr & node)251 bool IsInitDataSetQueueNode(const AnfNodePtr &node) {
252 if (!node->isa<CNode>()) {
253 return false;
254 }
255 auto cnode = node->cast<CNodePtr>();
256 MS_EXCEPTION_IF_NULL(cnode);
257 if (IsPrimitiveCNode(cnode, prim::kPrimInitDataSetQueue)) {
258 return true;
259 }
260 return false;
261 }
262
GetCNodeTargetFuncName(const CNodePtr cnode)263 std::string GetCNodeTargetFuncName(const CNodePtr cnode) {
264 if (IsCaseNode(cnode)) {
265 return string(kNameCase);
266 }
267 if (IsWhileNode(cnode)) {
268 return string(kNameWhile);
269 }
270 if (IsIfNode(cnode)) {
271 return string(kNameIf);
272 }
273 if (IsCallNode(cnode)) {
274 return string(kNamePartitionedCall);
275 }
276 return GetCNodeFuncName(cnode);
277 }
278
IsCaseNode(const AnfNodePtr & node)279 bool IsCaseNode(const AnfNodePtr &node) {
280 MS_EXCEPTION_IF_NULL(node);
281 if (!node->isa<CNode>()) {
282 return false;
283 }
284 auto cnode = node->cast<CNodePtr>();
285 MS_EXCEPTION_IF_NULL(cnode);
286 auto graph = node->func_graph();
287 MS_EXCEPTION_IF_NULL(graph);
288 bool in_kg = graph->type_name() == kKernelGraphTypeName;
289 if (in_kg && IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
290 return true;
291 }
292 if (!in_kg && IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitchLayer)) {
293 return true;
294 }
295 return false;
296 }
297
ConvertInputTensors(const std::vector<MeTensorPtr> & me_tensors,const std::string & format)298 std::vector<GeTensorPtr> ConvertInputTensors(const std::vector<MeTensorPtr> &me_tensors, const std::string &format) {
299 return TransformUtil::ConvertInputTensors(me_tensors, format);
300 }
301
ConvertGeTensors(const std::vector<GeTensorPtr> & ge_tensors)302 std::vector<MeTensorPtr> ConvertGeTensors(const std::vector<GeTensorPtr> &ge_tensors) {
303 return TransformUtil::ConvertGeTensors(ge_tensors);
304 }
305
ConvertDataType(const MeDataType & type)306 GeDataType ConvertDataType(const MeDataType &type) { return TransformUtil::ConvertDataType(type); }
307
ConvertGeTensor(const GeTensorPtr & ge_tensor,const ShapeVector & request_dims,bool ref_mem)308 MeTensorPtr ConvertGeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &request_dims, bool ref_mem) {
309 return TransformUtil::ConvertGeTensor(ge_tensor, request_dims, ref_mem);
310 }
311
ConvertGeTensor(const GeTensorPtr & tensor)312 MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor) { return TransformUtil::ConvertGeTensor(tensor); }
313
ConvertGeTensor(const GeTensorPtr & tensor,const TypeId & me_type)314 MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor, const TypeId &me_type) {
315 return TransformUtil::ConvertGeTensor(tensor, me_type);
316 }
317
GetGraphRunner()318 std::shared_ptr<transform::GraphRunner> GetGraphRunner() { return DfGraphManager::GetInstance().GetGraphRunner(); }
319
CheckAndGetGraphRunner(const transform::RunOptions & run_options)320 std::shared_ptr<transform::GraphRunner> CheckAndGetGraphRunner(const transform::RunOptions &run_options) {
321 if (transform::GetGraphByName(run_options.name) == nullptr) {
322 MS_LOG(WARNING) << "Can not find " << run_options.name
323 << " sub graph, don't need data init subgraph in INFER mode.";
324 return nullptr;
325 }
326
327 auto graph_runner = transform::GetGraphRunner();
328 if (graph_runner == nullptr) {
329 MS_LOG(EXCEPTION) << "Can not found GraphRunner.";
330 }
331 return graph_runner;
332 }
333
GetGeSession()334 std::shared_ptr<::ge::Session> GetGeSession() { return DfGraphManager::GetInstance().GetGeSession(); }
335
SetGeSession(const std::shared_ptr<::ge::Session> & sess_ptr)336 void SetGeSession(const std::shared_ptr<::ge::Session> &sess_ptr) {
337 DfGraphManager::GetInstance().SetGeSession(sess_ptr);
338 }
339
NewGraphRunner(const GraphRunnerOptions & options)340 GraphRunnerPtr NewGraphRunner(const GraphRunnerOptions &options) {
341 auto graph_runner = std::make_shared<transform::GraphRunner>(options);
342 return graph_runner;
343 }
344
SetGraphRunner(const GraphRunnerPtr & runner)345 void SetGraphRunner(const GraphRunnerPtr &runner) { DfGraphManager::GetInstance().SetGraphRunner(runner); }
ClearGraph()346 void ClearGraph() { DfGraphManager::GetInstance().ClearGraph(); }
347
AddGraph(const std::string & name,const DfGraphPtr & graph,const OptionMap & options,const bool & is_cloud,const bool & need_aoe)348 Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options, const bool &is_cloud,
349 const bool &need_aoe) {
350 auto ret = DfGraphManager::GetInstance().AddGraph(name, graph, options, is_cloud);
351 if (ret != Status::SUCCESS) {
352 return ret;
353 }
354 if (need_aoe) {
355 transform::AddOptimizeGraph(name);
356 transform::DfGraphManager::GetInstance().AoeGeGraph();
357 }
358 auto graph_runner = transform::GetGraphRunner();
359 if (graph_runner == nullptr) {
360 // lite may not use graph_runner
361 MS_LOG(INFO) << "There is no GraphRunner.";
362 return ret;
363 }
364 return graph_runner->AddGraph(name);
365 }
366
SetAnfGraph(const std::string & name,const AnfGraphPtr & anf_graph_ptr)367 void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) {
368 DfGraphManager::GetInstance().SetAnfGraph(name, anf_graph_ptr);
369 }
370
GetAnfGraph(uint32_t graph_id)371 FuncGraphPtr GetAnfGraph(uint32_t graph_id) { return DfGraphManager::GetInstance().GetAnfGraph(graph_id); }
372
GetGraphByName(const std::string & name)373 DfGraphWrapperPtr GetGraphByName(const std::string &name) { return DfGraphManager::GetInstance().GetGraphByName(name); }
374
AddOptimizeGraph(const std::string & name)375 void AddOptimizeGraph(const std::string &name) { AoeUtil::GetInstance().AddOptimizeGraph(name); }
376
InitializeAoeUtil()377 void InitializeAoeUtil() { AoeUtil::GetInstance().Initialize(); }
378
DestroyAoeUtil()379 void DestroyAoeUtil() { AoeUtil::GetInstance().Destroy(); }
380
EnableAoeOffline()381 void EnableAoeOffline() { AoeUtil::GetInstance().SetOfflineEnvDumpGeGraph(); }
382
// Graph conversion helpers: thin wrappers around DfGraphConvertor.
384
NewConverter(const FuncGraphPtr & graph,const std::string & phase_prefix,RefModeFlag ref_mode_type,bool offline_convert)385 DfGraphConvertorPtr NewConverter(const FuncGraphPtr &graph, const std::string &phase_prefix, RefModeFlag ref_mode_type,
386 bool offline_convert) {
387 std::vector<std::string> extra_variables_names = {};
388 auto converter = std::make_shared<transform::DfGraphConvertor>(graph, phase_prefix, ref_mode_type,
389 extra_variables_names, nullptr, offline_convert);
390 return converter;
391 }
392
SetTraining(const DfGraphConvertorPtr & converter,bool training)393 void SetTraining(const DfGraphConvertorPtr &converter, bool training) {
394 MS_EXCEPTION_IF_NULL(converter);
395 converter->set_training(training);
396 }
397
SetExportAir(const DfGraphConvertorPtr & converter,bool export_air)398 void SetExportAir(const DfGraphConvertorPtr &converter, bool export_air) {
399 MS_EXCEPTION_IF_NULL(converter);
400 converter->set_export_air(export_air);
401 }
402
BuildGraph(const std::string & name,const DfGraphConvertorPtr & converter,const std::map<std::string,std::shared_ptr<tensor::Tensor>> & maps)403 void BuildGraph(const std::string &name, const DfGraphConvertorPtr &converter,
404 const std::map<std::string, std::shared_ptr<tensor::Tensor>> &maps) {
405 MS_EXCEPTION_IF_NULL(converter);
406 (void)converter->ConvertAllNode().InitParam(maps).BuildGraph(name);
407 }
408
GenerateBroadcastGraph(const DfGraphConvertorPtr & converter,const TensorOrderMap & tensors)409 void GenerateBroadcastGraph(const DfGraphConvertorPtr &converter, const TensorOrderMap &tensors) {
410 MS_EXCEPTION_IF_NULL(converter);
411 (void)converter->GenerateBroadcastGraph(tensors);
412 }
GenerateCheckpointGraph(const DfGraphConvertorPtr & converter)413 void GenerateCheckpointGraph(const DfGraphConvertorPtr &converter) {
414 MS_EXCEPTION_IF_NULL(converter);
415 (void)converter->GenerateCheckpointGraph();
416 }
ErrCode(const DfGraphConvertorPtr & converter)417 int ErrCode(const DfGraphConvertorPtr &converter) {
418 MS_EXCEPTION_IF_NULL(converter);
419 return converter->ErrCode();
420 }
421
GenFakeGraph(const std::string & name,const DfGraphConvertorPtr & converter)422 void GenFakeGraph(const std::string &name, const DfGraphConvertorPtr &converter) {
423 MS_EXCEPTION_IF_NULL(converter);
424 converter->GenFakeGraph(name);
425 }
426
GetComputeGraph(const DfGraphConvertorPtr & converter)427 DfGraphPtr GetComputeGraph(const DfGraphConvertorPtr &converter) {
428 MS_EXCEPTION_IF_NULL(converter);
429 return converter->GetComputeGraph();
430 }
GetInitGraph(const DfGraphConvertorPtr & converter)431 DfGraphPtr GetInitGraph(const DfGraphConvertorPtr &converter) {
432 MS_EXCEPTION_IF_NULL(converter);
433 return converter->GetInitGraph();
434 }
GetSaveCheckpointGraph(const DfGraphConvertorPtr & converter)435 DfGraphPtr GetSaveCheckpointGraph(const DfGraphConvertorPtr &converter) {
436 MS_EXCEPTION_IF_NULL(converter);
437 return converter->GetSaveCheckpointGraph();
438 }
GetBroadcastGraph(const DfGraphConvertorPtr & converter)439 DfGraphPtr GetBroadcastGraph(const DfGraphConvertorPtr &converter) {
440 MS_EXCEPTION_IF_NULL(converter);
441 return converter->GetBroadcastGraph();
442 }
443
NewSession(const SessionOptions & sess_options)444 std::shared_ptr<::ge::Session> NewSession(const SessionOptions &sess_options) {
445 return transform::GraphRunner::NewSession(sess_options);
446 }
447
RunGraph(const std::shared_ptr<transform::GraphRunner> & runner,const RunOptions & options,const std::vector<GeTensorPtr> & inputs,std::vector<GeTensorPtr> * outputs)448 Status RunGraph(const std::shared_ptr<transform::GraphRunner> &runner, const RunOptions &options,
449 const std::vector<GeTensorPtr> &inputs, std::vector<GeTensorPtr> *outputs) {
450 MS_EXCEPTION_IF_NULL(runner);
451 return runner->RunGraph(options, inputs, outputs);
452 }
453
RunGraphAsync(const std::shared_ptr<GraphRunner> & runner,const RunOptions & options,const std::vector<GeTensorPtr> & inputs,std::vector<GeTensorPtr> * outputs)454 Status RunGraphAsync(const std::shared_ptr<GraphRunner> &runner, const RunOptions &options,
455 const std::vector<GeTensorPtr> &inputs, std::vector<GeTensorPtr> *outputs) {
456 MS_EXCEPTION_IF_NULL(runner);
457 return runner->RunGraphAsync(options, inputs, outputs);
458 }
459
RunGraphWithStreamAsync(const std::shared_ptr<GraphRunner> & runner,const RunOptions & options,void * stream,const std::vector<GeTensor> & inputs,std::vector<GeTensor> * outputs)460 Status RunGraphWithStreamAsync(const std::shared_ptr<GraphRunner> &runner, const RunOptions &options, void *stream,
461 const std::vector<GeTensor> &inputs, std::vector<GeTensor> *outputs) {
462 MS_EXCEPTION_IF_NULL(runner);
463 return runner->RunGraphWithStreamAsync(options, stream, inputs, outputs);
464 }
465
RegisterExternalAllocator(const std::shared_ptr<GraphRunner> & runner,const void * const stream,GeAllocatorPtr allocator)466 Status RegisterExternalAllocator(const std::shared_ptr<GraphRunner> &runner, const void *const stream,
467 GeAllocatorPtr allocator) {
468 MS_EXCEPTION_IF_NULL(runner);
469 return runner->RegisterExternalAllocator(stream, allocator);
470 }
471
UnregisterExternalAllocator(const std::shared_ptr<GraphRunner> & runner,const void * const stream)472 Status UnregisterExternalAllocator(const std::shared_ptr<GraphRunner> &runner, const void *const stream) {
473 MS_EXCEPTION_IF_NULL(runner);
474 return runner->UnregisterExternalAllocator(stream);
475 }
476
CompileDatasetGraph(const DatasetGraphParam & param,const std::string & phase)477 transform::Status CompileDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase) {
478 return BuildDatasetGraph(param, phase);
479 }
480
ConvertCheck(const AnfNodePtr & node)481 bool ConvertCheck(const AnfNodePtr &node) {
482 if (!node->cast<CNodePtr>() || !AnfUtils::IsRealKernel(node)) {
483 return true;
484 }
485 PrimitivePtr prim = common::AnfAlgo::GetCNodePrimitive(node);
486 auto &adapter_map = OpAdapterMap::get();
487 return adapter_map.find(prim->name()) != adapter_map.end();
488 }
489
DynamicShapeSupportCheck(const AnfNodePtr & node,bool train)490 bool DynamicShapeSupportCheck(const AnfNodePtr &node, bool train) {
491 auto adpt = FindAdapter(node, train);
492 MS_EXCEPTION_IF_NULL(adpt);
493 return adpt->GetDynamicShapeSupport();
494 }
495
SinkGraphCheck(const AnfNodePtr & node,bool train)496 bool SinkGraphCheck(const AnfNodePtr &node, bool train) {
497 PrimitivePtr prim = common::AnfAlgo::GetCNodePrimitive(node);
498 auto adpt = FindAdapter(prim->name(), train);
499 MS_EXCEPTION_IF_NULL(adpt);
500 auto input_attr_map = adpt->getInputAttrMap();
501 auto cnode = node->cast<CNodePtr>();
502 MS_EXCEPTION_IF_NULL(cnode);
503 auto input_size = cnode->size();
504 for (auto &it : input_attr_map) {
505 if (it.first >= input_size) {
506 continue;
507 }
508 if (!cnode->input(it.first)->isa<ValueNode>()) {
509 MS_LOG(DEBUG) << node->fullname_with_scope() << " inputs[" << it.first << "]"
510 << " is not a ValueNode";
511 return false;
512 }
513 }
514 auto input_map = adpt->getInputMap();
515 for (auto &it : input_map) {
516 if (static_cast<size_t>(it.first) >= input_size) {
517 continue;
518 }
519 auto abs = cnode->input(it.first)->abstract();
520 MS_EXCEPTION_IF_NULL(abs);
521 if (abs->isa<abstract::AbstractAny>()) {
522 MS_LOG(DEBUG) << node->fullname_with_scope() << " inputs[" << it.first << "]"
523 << " is a AbstractAny";
524 return false;
525 }
526 }
527 return true;
528 }
529 } // namespace transform
530 } // namespace mindspore
531