1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "load_mindir/infer_mindir.h"
17 #include <deque>
18 #include <set>
19 #include <map>
20 #include <memory>
21 #include <algorithm>
22 #include <string>
23 #include "ir/func_graph.h"
24 #include "abstract/abstract_function.h"
25 #include "ops/primitive_c.h"
26 #include "abstract/abstract_value.h"
27
28 namespace mindspore {
29 namespace {
30 class MindIREngine {
31 public:
MindIREngine(const FuncGraphPtr & root)32 explicit MindIREngine(const FuncGraphPtr &root) : func_graph_(root), nodeuser_map_(root->manager()->node_users()) {}
33 ~MindIREngine() = default;
34 MindIREngine(const MindIREngine &) = delete;
35 MindIREngine &operator=(const MindIREngine &) = delete;
36
37 bool InferShape(const AbstractBasePtrList &args);
38
39 private:
40 using AbstractBasePtrListPtr = std::shared_ptr<AbstractBasePtrList>;
41
42 void Init(const AbstractBasePtrList &args);
43 static AbstractBasePtr InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
44 void EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args);
45 void EvalPartialPrimitive(const CNodePtr &node, const AbstractBasePtrListPtr &args);
46 void EvalReturnPrimitive(const CNodePtr &node);
47 void InferParameter(const AnfNodePtr &node);
48 void InferValueNode(const AnfNodePtr &node);
49 void InferCNode(const AnfNodePtr &node);
50 void EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &abstractFunc, const CNodePtr &node,
51 const AbstractBasePtrListPtr &args);
52 void EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
53 const AbstractBasePtrListPtr &args);
54 void EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
55 const AbstractBasePtrListPtr &args);
56 void EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
57 const AbstractBasePtrListPtr &args);
58 bool CheckCNodeNotReady(const CNodePtr &node);
59 void UpdateReady(const AnfNodePtr &node);
60 void SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result);
61 AbstractBasePtr GetCNodeOperatorAbstract(const AnfNodePtr &node);
62
63 FuncGraphPtr func_graph_;
64 std::map<AnfNodePtr, int> node_input_depends_;
65 std::map<AnfNodePtr, AbstractBasePtr> infer_resut_;
66 std::map<std::string, AbstractBasePtr> func_graph_result_;
67 std::map<std::string, std::set<AnfNodePtr>> func_graph_visited_;
68 std::deque<AnfNodePtr> ready_;
69 std::set<AnfNodePtr> todo_;
70 NodeUsersMap nodeuser_map_;
71 };
72
73 // Infer the root function graph.
InferShape(const AbstractBasePtrList & args)74 bool MindIREngine::InferShape(const AbstractBasePtrList &args) {
75 Init(args);
76 while (!ready_.empty()) {
77 auto current = ready_.front();
78 MS_EXCEPTION_IF_NULL(current);
79 ready_.pop_front();
80 if (current->isa<CNode>()) {
81 InferCNode(current);
82 } else if (current->isa<ValueNode>()) {
83 InferValueNode(current);
84 } else if (current->isa<Parameter>()) {
85 InferParameter(current);
86 } else {
87 MS_LOG(WARNING) << " There is something changed. Please check the code.";
88 }
89 }
90
91 // Set abstract of node.
92 for (const auto &item : infer_resut_) {
93 item.first->set_abstract(item.second);
94 }
95
96 if (todo_.empty()) {
97 MS_LOG(DEBUG) << "Finish to Infere.";
98 return true;
99 }
100 MS_LOG(WARNING) << "Not finished to infer: " << todo_.size();
101 for (const auto &node : todo_) {
102 MS_LOG(DEBUG) << "Node uninfered: " << node->DebugString();
103 }
104 return false;
105 }
106
Init(const AbstractBasePtrList & args)107 void MindIREngine::Init(const AbstractBasePtrList &args) {
108 MS_EXCEPTION_IF_NULL(func_graph_);
109 auto manager = func_graph_->manager();
110 MS_EXCEPTION_IF_NULL(manager);
111 for (const auto &node : manager->all_nodes()) {
112 MS_EXCEPTION_IF_NULL(node);
113 if (node->isa<CNode>()) {
114 auto cnode = node->cast<CNodePtr>();
115 MS_EXCEPTION_IF_NULL(cnode);
116 (void)todo_.insert(node);
117 node_input_depends_[node] = SizeToInt(cnode->inputs().size());
118 } else if (node->isa<Parameter>()) {
119 auto param = node->cast<ParameterPtr>();
120 MS_EXCEPTION_IF_NULL(param);
121 if (param->has_default()) {
122 node_input_depends_[node] = 0;
123 infer_resut_[node] = param->default_param()->ToAbstract();
124 ready_.push_back(node);
125 } else {
126 node_input_depends_[node] = 1;
127 (void)todo_.insert(node);
128 }
129 } else {
130 // Value Node
131 node_input_depends_[node] = 0;
132 ready_.push_back(node);
133 }
134 }
135
136 auto inputs = func_graph_->get_inputs();
137 if (inputs.size() != args.size()) {
138 MS_LOG(EXCEPTION) << "The input parameters is not Compatible. mindir:" << inputs.size()
139 << " inputs: " << args.size() << " FuncGraph:" << func_graph_->ToString();
140 }
141 // Root Func Parameters
142 for (size_t i = 0; i < args.size(); ++i) {
143 this->SaveNodeInferResult(inputs[i], args[i]);
144 }
145 MS_LOG(DEBUG) << "Finish init. Size of nodes:" << manager->all_nodes().size();
146 }
147
148 // Infer primitive using C++ implement.
InferPrimitiveShape(const PrimitivePtr & prim,const AbstractBasePtrList & args_spec_list)149 AbstractBasePtr MindIREngine::InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
150 MS_EXCEPTION_IF_NULL(prim);
151 try {
152 MS_LOG_TRY_CATCH_SCOPE;
153 static auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap();
154 auto ret = prim_eval_implement_map.find(prim);
155 if (ret != prim_eval_implement_map.end()) {
156 // fing infer function in the front infer map and restore input abastract form dynamic inputs and reg attr
157 MS_EXCEPTION_IF_NULL(ret->second.infer_shape_impl_);
158 return ret->second.infer_shape_impl_(nullptr, prim, args_spec_list);
159 } else {
160 // if the infer function has been not founded in the front infer map find it in the backend infer map instead
161 static auto &prim_backend_eval_impl_map = abstract::GetPrimitiveToBackendEvalImplMap();
162 auto ret_backend = prim_backend_eval_impl_map.find(prim);
163 if (ret_backend != prim_backend_eval_impl_map.end()) {
164 MS_EXCEPTION_IF_NULL(ret_backend->second.infer_shape_impl_);
165 return ret_backend->second.infer_shape_impl_(nullptr, prim, args_spec_list);
166 }
167 }
168 MS_LOG(WARNING) << "Get infer shape function failed, primitive name:" << prim->name()
169 << " primitive type:" << prim->type_name() << " It will keep the prevalue witch danger.";
170 } catch (const std::exception &ex) {
171 MS_LOG(WARNING) << "Catch primitive:" << prim->ToString() << " InferPrimitiveShape exception:" << ex.what()
172 << " It will keep the prevalue witch danger.";
173 }
174 return nullptr;
175 }
176
EvalCommonPrimitive(const PrimitivePtr & prim,const CNodePtr & node,const AbstractBasePtrListPtr & args)177 void MindIREngine::EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node,
178 const AbstractBasePtrListPtr &args) {
179 AbstractBasePtrList args_spec_list;
180 // Args has been resolved by partial
181 if (args != nullptr) {
182 (void)args_spec_list.insert(args_spec_list.end(), args->begin(), args->end());
183 } else {
184 (void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(args_spec_list),
185 [this](const AnfNodePtr &arg) { return infer_resut_[arg]; });
186 }
187
188 // Call C++ infer
189 auto result = InferPrimitiveShape(prim, args_spec_list);
190 if (result == nullptr) {
191 MS_LOG(INFO) << node->ToString()
192 << " can't be inferred shape. It will keep the prevalue witch danger. Prim: " << prim->ToString();
193 result = node->abstract();
194 }
195 SaveNodeInferResult(node, result);
196 }
197
EvalReturnPrimitive(const CNodePtr & node)198 void MindIREngine::EvalReturnPrimitive(const CNodePtr &node) {
199 if (node->inputs().size() < 2) {
200 MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
201 }
202 auto result = infer_resut_[node->inputs()[1]];
203 auto funcName = node->func_graph()->ToString();
204 auto it = func_graph_result_.find(funcName);
205 if (it != func_graph_result_.end()) {
206 try {
207 MS_LOG_TRY_CATCH_SCOPE;
208 result = result->Join(it->second);
209 } catch (const std::exception &e) {
210 MS_LOG(WARNING) << "Join abstract for return node " << node->DebugString() << " failed, exception: " << e.what();
211 }
212 }
213 this->func_graph_result_[funcName] = result;
214 SaveNodeInferResult(node, result);
215 MS_LOG(DEBUG) << funcName << " result: " << result->ToString();
216
217 // Set the result of the node whose Operator is this funcGraph
218 for (const auto &item : func_graph_visited_[funcName]) {
219 SaveNodeInferResult(item, result);
220 }
221 }
222
EvalPartialPrimitive(const CNodePtr & node,const AbstractBasePtrListPtr & args)223 void MindIREngine::EvalPartialPrimitive(const CNodePtr &node, const AbstractBasePtrListPtr &args) {
224 // Args has been resolved
225 if (args != nullptr) {
226 if (args->size() < 2) {
227 MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
228 }
229 auto real_func = (*args)[0]->cast<abstract::AbstractFuncAtomPtr>();
230 if (real_func == nullptr) {
231 MS_LOG(EXCEPTION) << (*args)[0]->ToString() << " is not a function abstract.";
232 }
233 AbstractBasePtrList partial_args_list;
234 (void)partial_args_list.insert(partial_args_list.end(), args->begin() + 1, args->end());
235 auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
236 SaveNodeInferResult(node, partial_func);
237 return;
238 }
239 // Not Resolved.
240 if (node->inputs().size() < 2) {
241 MS_LOG(EXCEPTION) << node->DebugString() << " input size < 2";
242 }
243 auto &func = infer_resut_[node->inputs()[1]];
244 auto real_func = func->cast<abstract::AbstractFuncAtomPtr>();
245 if (real_func == nullptr) {
246 MS_LOG(EXCEPTION) << func->ToString() << " is not a function abstract.";
247 }
248 AbstractBasePtrList partial_args_list;
249 (void)std::transform(node->inputs().begin() + 2, node->inputs().end(), std::back_inserter(partial_args_list),
250 [this](const AnfNodePtr &arg) { return infer_resut_[arg]; });
251 auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
252 SaveNodeInferResult(node, partial_func);
253 }
254
EvalPartialAbastract(const abstract::PartialAbstractClosurePtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)255 void MindIREngine::EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
256 const AbstractBasePtrListPtr &args) {
257 AbstractBasePtrListPtr partial_args_list = std::make_shared<AbstractBasePtrList>();
258 // Join arguments in partial and the rest arguments from args_conf_list.
259 auto func_args = func->args();
260 (void)partial_args_list->insert(partial_args_list->end(), func_args.begin(), func_args.end());
261 if (args == nullptr) {
262 // Not Recursive
263 (void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(*partial_args_list),
264 [this](const AnfNodePtr &arg) { return infer_resut_[arg]; });
265 } else {
266 // Recursive
267 (void)partial_args_list->insert(partial_args_list->end(), args->begin(), args->end());
268 }
269
270 // Get real function
271 abstract::AbstractFuncAtomPtrList abstractFuncList;
272 auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
273 abstractFuncList.push_back(poss);
274 };
275 func->fn()->Visit(build_fuction);
276 for (const auto &abstractFunc : abstractFuncList) {
277 EvalAbstractFunction(abstractFunc, node, partial_args_list);
278 }
279 }
280
SaveNodeInferResult(const AnfNodePtr & node,const AbstractBasePtr & result)281 void MindIREngine::SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result) {
282 auto answer = result;
283 try {
284 MS_LOG_TRY_CATCH_SCOPE;
285 auto it = infer_resut_.find(node);
286 if (it != infer_resut_.end()) {
287 MS_LOG(DEBUG) << node->ToString() << " result: " << it->second->ToString();
288 answer = result->Join(it->second);
289 if (*answer == *(it->second)) {
290 MS_LOG(DEBUG) << node->ToString() << " The value is not changed.";
291 return;
292 }
293 }
294 } catch (const std::exception &e) {
295 MS_LOG(WARNING) << "Join abstract for node " << node->DebugString() << " failed, exception: " << e.what();
296 return;
297 }
298
299 MS_LOG(DEBUG) << node->ToString() << " result: " << answer->ToString();
300 infer_resut_[node] = answer;
301 UpdateReady(node);
302 }
303
EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)304 void MindIREngine::EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
305 const AbstractBasePtrListPtr &args) {
306 auto prim = func->prim();
307 // Return Primitive
308 if (prim->name() == prim::kPrimReturn->name()) {
309 EvalReturnPrimitive(node);
310 return;
311 }
312 // Partial Primitive
313 if (prim->name() == prim::kPrimPartial->name()) {
314 EvalPartialPrimitive(node, args);
315 return;
316 }
317 // common Primitive
318 EvalCommonPrimitive(prim, node, args);
319 }
320
CheckCNodeNotReady(const CNodePtr & node)321 bool MindIREngine::CheckCNodeNotReady(const CNodePtr &node) {
322 int depend = 0;
323 for (const auto &input : node->inputs()) {
324 depend += infer_resut_.find(input) != infer_resut_.end() ? 0 : 1;
325 }
326 this->node_input_depends_[node] = depend;
327 return (depend != 0);
328 }
329
EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)330 void MindIREngine::EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
331 const AbstractBasePtrListPtr &args) {
332 MS_EXCEPTION_IF_NULL(node);
333 MS_EXCEPTION_IF_NULL(func);
334 MS_EXCEPTION_IF_NULL(func->func_graph());
335 // Has Processd
336 MS_LOG(DEBUG) << node->ToString() << " FuncGraph: " << func->ToString();
337 auto funcName = func->func_graph()->ToString();
338 auto it = func_graph_result_.find(funcName);
339 if (it != func_graph_result_.end()) {
340 MS_LOG(DEBUG) << "The abstract of " << node->ToString() << " = abstract of " << func->ToString();
341 SaveNodeInferResult(node, it->second);
342
343 // Process only one return valueNode function graph
344 auto func_inputs = func->func_graph()->get_inputs();
345 // args has been resolved in partial.
346 if (args != nullptr) {
347 if (func_inputs.size() != args->size()) {
348 MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
349 << " CNode:" << node->DebugString() << " input size:" << args->size();
350 }
351 for (size_t i = 0; i < func_inputs.size(); ++i) {
352 infer_resut_[func_inputs[i]] =
353 (*args)[i]; // Not use SaveNodeInferResult because this function has been evaluated.
354 (void)todo_.erase(func_inputs[i]);
355 }
356 return;
357 }
358 // args is not resolved.
359 auto &cnode_inputs = node->inputs();
360 if (func_inputs.size() != cnode_inputs.size() - 1) {
361 MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
362 << " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size();
363 }
364 for (size_t i = 0; i < func_inputs.size(); ++i) {
365 infer_resut_[func_inputs[i]] = infer_resut_[cnode_inputs[i + 1]];
366 (void)todo_.erase(func_inputs[i]);
367 }
368 return;
369 }
370
371 // Be handling
372 auto visitIt = func_graph_visited_.find(funcName);
373 if (visitIt != func_graph_visited_.end()) {
374 (void)visitIt->second.insert(node);
375 return;
376 }
377 func_graph_visited_[funcName] = std::set<AnfNodePtr>({node});
378
379 // Call the funcGraph
380 auto func_inputs = func->func_graph()->get_inputs();
381
382 // args has been resolved in partial.
383 if (args != nullptr) {
384 if (func_inputs.size() != args->size()) {
385 MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
386 << " CNode:" << node->DebugString() << " input size:" << args->size();
387 }
388 for (size_t i = 0; i < func_inputs.size(); ++i) {
389 SaveNodeInferResult(func_inputs[i], (*args)[i]);
390 }
391 return;
392 }
393 // args is not resolved.
394 auto &cnode_inputs = node->inputs();
395 if (func_inputs.size() != cnode_inputs.size() - 1) {
396 MS_LOG(EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
397 << " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size();
398 }
399
400 for (size_t i = 0; i < func_inputs.size(); ++i) {
401 SaveNodeInferResult(func_inputs[i], infer_resut_[cnode_inputs[i + 1]]);
402 }
403 }
404
InferParameter(const AnfNodePtr & node)405 void MindIREngine::InferParameter(const AnfNodePtr &node) { UpdateReady(node); }
406
InferValueNode(const AnfNodePtr & node)407 void MindIREngine::InferValueNode(const AnfNodePtr &node) {
408 MS_EXCEPTION_IF_NULL(node);
409 auto value_node = node->cast<ValueNodePtr>();
410 MS_EXCEPTION_IF_NULL(value_node);
411 auto value = GetValueNode(node);
412 MS_EXCEPTION_IF_NULL(value);
413 AbstractBasePtr result;
414 if (value->isa<FuncGraph>()) {
415 auto func_graph = value->cast<FuncGraphPtr>();
416 auto temp_context = abstract::AnalysisContext::DummyContext();
417 result = std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, node);
418 } else if (value->isa<Primitive>()) {
419 auto prim = value->cast<PrimitivePtr>();
420 result = std::make_shared<abstract::PrimitiveAbstractClosure>(prim, node);
421 } else {
422 result = value->ToAbstract();
423 }
424
425 if (result->isa<abstract::AbstractTensor>()) {
426 result = result->Broaden();
427 }
428 SaveNodeInferResult(node, result);
429 }
430
GetCNodeOperatorAbstract(const AnfNodePtr & node)431 AbstractBasePtr MindIREngine::GetCNodeOperatorAbstract(const AnfNodePtr &node) {
432 MS_EXCEPTION_IF_NULL(node);
433 auto cnode = node->cast<CNodePtr>();
434 MS_EXCEPTION_IF_NULL(cnode);
435 auto op = cnode->inputs()[0];
436 auto it = infer_resut_.find(op);
437 if (it != infer_resut_.end()) {
438 return it->second;
439 }
440 MS_LOG(EXCEPTION) << "Can't get the abstract of Node:" << op->DebugString();
441 }
442
443 // If args is nullPtr, it is called by InferCNode, else it is called recursively by EvalPartialAbastract.
EvalAbstractFunction(const abstract::AbstractFuncAtomPtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)444 void MindIREngine::EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &func, const CNodePtr &node,
445 const AbstractBasePtrListPtr &args) {
446 MS_EXCEPTION_IF_NULL(func);
447 if (func->isa<abstract::PrimitiveAbstractClosure>()) {
448 // C++ Primitive
449 auto prim = func->cast<abstract::PrimitiveAbstractClosurePtr>();
450 EvalPrimitiveAbastract(prim, node, args);
451 } else if (func->isa<abstract::FuncGraphAbstractClosure>()) {
452 // FuncGraph
453 auto funcGraph = func->cast<abstract::FuncGraphAbstractClosurePtr>();
454 EvalFuncGraphAbastract(funcGraph, node, args);
455 } else if (func->isa<abstract::PartialAbstractClosure>()) {
456 // Partial
457 auto partialPrim = func->cast<abstract::PartialAbstractClosurePtr>();
458 EvalPartialAbastract(partialPrim, node, args);
459 } else {
460 MS_LOG(EXCEPTION) << "MindIR can't process the abstractFunc: " << func->DumpText();
461 }
462 }
463
UpdateReady(const AnfNodePtr & node)464 void MindIREngine::UpdateReady(const AnfNodePtr &node) {
465 (void)todo_.erase(node);
466 auto it = nodeuser_map_.find(node);
467 if (it == nodeuser_map_.end()) {
468 return;
469 }
470 const auto &users = it->second;
471 MS_LOG(DEBUG) << node->ToString() << " has users: " << users.size();
472 for (const auto &user : users) {
473 int count = node_input_depends_[user.first];
474 node_input_depends_[user.first] = count - 1;
475 if (count <= 1) {
476 ready_.push_back(user.first);
477 MS_LOG(DEBUG) << "Node:" << user.first->ToString() << " is ready.";
478 if (count < 1) {
479 MS_LOG(INFO) << " There is something to do. Node:" << node->ToString() << " user:" << user.first->DebugString();
480 }
481 }
482 }
483 }
484
InferCNode(const AnfNodePtr & node)485 void MindIREngine::InferCNode(const AnfNodePtr &node) {
486 auto cnode = node->cast<CNodePtr>();
487 MS_EXCEPTION_IF_NULL(cnode);
488 if (CheckCNodeNotReady(cnode)) {
489 MS_LOG(INFO) << "The node is not ready: " << cnode->DebugString();
490 return;
491 }
492 AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode);
493 if (possible_func->BuildType()->type_id() == kObjectTypeUndeterminedType) {
494 MS_LOG(EXCEPTION) << "EvalCNode eval Undetermined";
495 }
496 abstract::AbstractFunctionPtr func = dyn_cast<abstract::AbstractFunction>(possible_func);
497 if (func == nullptr) {
498 MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << ".";
499 MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code.";
500 }
501 abstract::AbstractFuncAtomPtrList abstractFuncList;
502 auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
503 abstractFuncList.push_back(poss);
504 };
505 func->Visit(build_fuction);
506 for (const auto &abstractFunc : abstractFuncList) {
507 EvalAbstractFunction(abstractFunc, cnode, nullptr);
508 }
509 }
510 } // namespace
InferMindir(const FuncGraphPtr & root,const AbstractBasePtrList & args)511 bool InferMindir(const FuncGraphPtr &root, const AbstractBasePtrList &args) {
512 auto engine = std::make_shared<MindIREngine>(root);
513 return engine->InferShape(args);
514 }
515 } // namespace mindspore
516