1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "load_mindir/infer_mindir.h"
17 #include <deque>
18 #include <set>
19 #include <map>
20 #include <memory>
21 #include <algorithm>
22 #include <string>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ir/func_graph.h"
26 #include "abstract/abstract_function.h"
27 #include "abstract/abstract_value.h"
28 #include "utils/ms_context.h"
29 #include "abstract/ops/primitive_infer_map.h"
30
31 namespace mindspore {
32 namespace {
33 class MindIREngine {
34 public:
MindIREngine(const FuncGraphPtr & root)35 explicit MindIREngine(const FuncGraphPtr &root) : func_graph_(root), nodeuser_map_(root->manager()->node_users()) {}
36 ~MindIREngine() = default;
37 MindIREngine(const MindIREngine &) = delete;
38 MindIREngine &operator=(const MindIREngine &) = delete;
39
40 bool InferShape(const AbstractBasePtrList &args);
41
SetException(bool flag)42 void SetException(bool flag) { raise_exception_ = flag; }
43
44 private:
45 using AbstractBasePtrListPtr = std::shared_ptr<AbstractBasePtrList>;
46
47 void Init(const AbstractBasePtrList &args);
48 AbstractBasePtr InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_abs_list) const;
49 void EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args);
50 void EvalPartialPrimitive(const CNodePtr &node, const AbstractBasePtrListPtr &args);
51 void EvalReturnPrimitive(const CNodePtr &node);
52 void InferParameter(const AnfNodePtr &node);
53 void InferValueNode(const AnfNodePtr &node);
54 void InferCNode(const AnfNodePtr &node);
55 void EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &func, const CNodePtr &node,
56 const AbstractBasePtrListPtr &args);
57 void EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
58 const AbstractBasePtrListPtr &args);
59 void EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
60 const AbstractBasePtrListPtr &args);
61 void EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
62 const AbstractBasePtrListPtr &args);
63 bool CheckCNodeNotReady(const CNodePtr &node);
64 void UpdateReady(const AnfNodePtr &node);
65 void SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result);
66 AbstractBasePtr GetCNodeOperatorAbstract(const AnfNodePtr &node);
67
68 FuncGraphPtr func_graph_;
69 std::map<AnfNodePtr, int> node_input_depends_;
70 std::map<AnfNodePtr, AbstractBasePtr> infer_result_;
71 std::map<std::string, AbstractBasePtr> func_graph_result_;
72 std::map<std::string, std::set<AnfNodePtr>> func_graph_visited_;
73 std::deque<AnfNodePtr> ready_;
74 std::set<AnfNodePtr> todo_;
75 NodeUsersMap nodeuser_map_;
76 bool raise_exception_ = false;
77 };
78
79 // Infer the root function graph.
InferShape(const AbstractBasePtrList & args)80 bool MindIREngine::InferShape(const AbstractBasePtrList &args) {
81 Init(args);
82 while (!ready_.empty()) {
83 auto current = ready_.front();
84 MS_EXCEPTION_IF_NULL(current);
85 ready_.pop_front();
86 if (current->isa<CNode>()) {
87 InferCNode(current);
88 } else if (current->isa<ValueNode>()) {
89 InferValueNode(current);
90 } else if (current->isa<Parameter>()) {
91 InferParameter(current);
92 } else {
93 MS_LOG(WARNING) << " There is something changed. Please check the code.";
94 }
95 }
96
97 // Set abstract of node.
98 for (const auto &item : infer_result_) {
99 item.first->set_abstract(item.second);
100 }
101
102 if (todo_.empty()) {
103 MS_LOG(DEBUG) << "Finish to Infere.";
104 return true;
105 }
106 MS_LOG(INFO) << "Not finished to infer: " << todo_.size();
107 for (const auto &node : todo_) {
108 MS_LOG(DEBUG) << "Node uninfered: " << node->DebugString();
109 }
110 return false;
111 }
112
Init(const AbstractBasePtrList & args)113 void MindIREngine::Init(const AbstractBasePtrList &args) {
114 MS_EXCEPTION_IF_NULL(func_graph_);
115 auto manager = func_graph_->manager();
116 MS_EXCEPTION_IF_NULL(manager);
117 for (const auto &node : manager->all_nodes()) {
118 MS_EXCEPTION_IF_NULL(node);
119 if (node->isa<CNode>()) {
120 auto cnode = node->cast<CNodePtr>();
121 MS_EXCEPTION_IF_NULL(cnode);
122 (void)todo_.insert(node);
123 node_input_depends_[node] = SizeToInt(cnode->size());
124 } else if (node->isa<Parameter>()) {
125 auto param = node->cast<ParameterPtr>();
126 MS_EXCEPTION_IF_NULL(param);
127 if (param->has_default()) {
128 node_input_depends_[node] = 0;
129 auto default_param = param->default_param();
130 MS_EXCEPTION_IF_NULL(default_param);
131 infer_result_[node] = default_param->ToAbstract();
132 ready_.push_back(node);
133 } else {
134 node_input_depends_[node] = 1;
135 (void)todo_.insert(node);
136 }
137 } else {
138 // Value Node
139 node_input_depends_[node] = 0;
140 ready_.push_back(node);
141 }
142 }
143
144 auto inputs = func_graph_->get_inputs();
145 if (inputs.size() != args.size()) {
146 MS_LOG(EXCEPTION) << "The input number of parameters is not Compatible.\n"
147 << "Mindir:" << inputs.size() << " inputs: " << args.size()
148 << " FuncGraph:" << func_graph_->ToString() << "\n"
149 << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
150 }
151 // Root Func Parameters
152 for (size_t i = 0; i < args.size(); ++i) {
153 this->SaveNodeInferResult(inputs[i], args[i]);
154 }
155 MS_LOG(DEBUG) << "Finish init. Size of nodes:" << manager->all_nodes().size();
156 }
157
158 // Infer primitive using C++ implement.
InferPrimitiveShape(const PrimitivePtr & prim,const AbstractBasePtrList & args_abs_list) const159 AbstractBasePtr MindIREngine::InferPrimitiveShape(const PrimitivePtr &prim,
160 const AbstractBasePtrList &args_abs_list) const {
161 MS_EXCEPTION_IF_NULL(prim);
162 try {
163 MS_LOG_TRY_CATCH_SCOPE;
164 // For Lite, the op is with old format, it will fail in new infer function, so skip it.
165 #ifndef BUILD_LITE
166 auto abstract_optional = abstract::InferAbstractByFuncImpl(prim, args_abs_list);
167 if (abstract_optional.has_value()) {
168 return abstract_optional.value();
169 }
170 #endif
171
172 auto found = abstract::GetPrimitiveInferImpl(prim);
173 if (found.has_value()) {
174 auto infer = found.value();
175 if (infer.IsImplInferShapeAndType()) {
176 return infer.InferShapeAndType(nullptr, prim, args_abs_list);
177 }
178 }
179
180 if (raise_exception_) {
181 MS_LOG(INTERNAL_EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
182 << " primitive type:" << prim->type_name()
183 << " It will keep the previous value with danger.";
184 } else {
185 MS_LOG(INFO) << "Get infer shape function failed, primitive name:" << prim->name()
186 << " primitive type:" << prim->type_name() << " It will keep the previous value with danger.";
187 }
188 } catch (const std::exception &ex) {
189 if (raise_exception_) {
190 MS_LOG(INTERNAL_EXCEPTION) << "Catch primitive:" << prim->ToString()
191 << " InferPrimitiveShape exception:" << ex.what()
192 << " It will keep the previous value with danger.";
193 } else {
194 MS_LOG(INFO) << "Catch primitive:" << prim->ToString() << " InferPrimitiveShape exception:" << ex.what()
195 << " It will keep the previous value with danger.";
196 }
197 }
198 return nullptr;
199 }
200
EvalCommonPrimitive(const PrimitivePtr & prim,const CNodePtr & node,const AbstractBasePtrListPtr & args)201 void MindIREngine::EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node,
202 const AbstractBasePtrListPtr &args) {
203 // Save MakeTuple cnode abstract by its own abstract when MakeTuple have an abstract of
204 // AbstractCSRTensor/AbstractCOOTensor that can not be inferred by its Infer Functions.
205 if (prim->name() == prim::kPrimMakeTuple->name()) {
206 if (node->abstract() != nullptr && (node->abstract()->isa<abstract::AbstractSparseTensor>())) {
207 MS_LOG(INFO) << "Save MakeTuple cnode abstract by its own abstract : " << node->abstract()->ToString();
208 SaveNodeInferResult(node, node->abstract());
209 return;
210 }
211 }
212
213 AbstractBasePtrList args_abs_list;
214 // Args has been resolved by partial
215 if (args != nullptr) {
216 (void)args_abs_list.insert(args_abs_list.end(), args->begin(), args->end());
217 } else {
218 (void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(args_abs_list),
219 [this](const AnfNodePtr &arg) { return infer_result_[arg]; });
220 }
221
222 // Call C++ infer
223 auto result = InferPrimitiveShape(prim, args_abs_list);
224 if (result == nullptr) {
225 MS_LOG(INFO) << node->ToString()
226 << " can't be inferred shape. It will keep the previous value with danger. Prim: " << prim->ToString();
227 if (node->abstract() == nullptr) {
228 MS_LOG(WARNING) << "The abstract of the node: " << node->ToString()
229 << " is nullptr. And it can't be inferred shape. Prim: " << prim->ToString();
230 } else {
231 result = node->abstract()->Clone();
232 }
233 }
234 SaveNodeInferResult(node, result);
235 }
236
EvalReturnPrimitive(const CNodePtr & node)237 void MindIREngine::EvalReturnPrimitive(const CNodePtr &node) {
238 constexpr auto min_size = 2;
239 if (node->size() < min_size) {
240 MS_LOG(INTERNAL_EXCEPTION) << node->DebugString() << " input size < 2";
241 }
242 auto result = infer_result_[node->inputs()[1]];
243 auto funcName = node->func_graph()->ToString();
244 auto it = func_graph_result_.find(funcName);
245 if (it != func_graph_result_.end()) {
246 try {
247 MS_LOG_TRY_CATCH_SCOPE;
248 result = result->Join(it->second);
249 } catch (const std::exception &e) {
250 MS_LOG(INFO) << "Join abstract for return node " << node->DebugString() << " failed, exception: " << e.what();
251 }
252 }
253 this->func_graph_result_[funcName] = result;
254 SaveNodeInferResult(node, result);
255 MS_LOG(DEBUG) << funcName << " result: " << result->ToString();
256
257 // Set the result of the node whose Operator is this funcGraph
258 for (const auto &item : func_graph_visited_[funcName]) {
259 SaveNodeInferResult(item, result);
260 }
261 }
262
EvalPartialPrimitive(const CNodePtr & node,const AbstractBasePtrListPtr & args)263 void MindIREngine::EvalPartialPrimitive(const CNodePtr &node, const AbstractBasePtrListPtr &args) {
264 // Args has been resolved
265 if (args != nullptr) {
266 if (args->size() < 2) {
267 MS_LOG(INTERNAL_EXCEPTION) << node->DebugString() << " input size < 2";
268 }
269 auto real_func = (*args)[0]->cast<abstract::AbstractFuncAtomPtr>();
270 if (real_func == nullptr) {
271 MS_LOG(INTERNAL_EXCEPTION) << (*args)[0]->ToString() << " is not a function abstract.";
272 }
273 AbstractBasePtrList partial_args_list;
274 (void)partial_args_list.insert(partial_args_list.end(), args->begin() + 1, args->end());
275 auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
276 SaveNodeInferResult(node, partial_func);
277 return;
278 }
279 // Not Resolved.
280 constexpr size_t kSizeTwo = 2;
281 if (node->size() < kSizeTwo) {
282 MS_LOG(INTERNAL_EXCEPTION) << node->DebugString() << " input size < " << kSizeTwo;
283 }
284 auto &func = infer_result_[node->inputs()[1]];
285 auto real_func = func->cast<abstract::AbstractFuncAtomPtr>();
286 if (real_func == nullptr) {
287 MS_LOG(INTERNAL_EXCEPTION) << func->ToString() << " is not a function abstract.";
288 }
289 AbstractBasePtrList partial_args_list;
290 (void)std::transform(node->inputs().begin() + 2, node->inputs().end(), std::back_inserter(partial_args_list),
291 [this](const AnfNodePtr &arg) { return infer_result_[arg]; });
292 auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
293 SaveNodeInferResult(node, partial_func);
294 }
295
EvalPartialAbastract(const abstract::PartialAbstractClosurePtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)296 void MindIREngine::EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
297 const AbstractBasePtrListPtr &args) {
298 AbstractBasePtrListPtr partial_args_list = std::make_shared<AbstractBasePtrList>();
299 // Join arguments in partial and the rest arguments from args_conf_list.
300 auto func_args = func->args();
301 (void)partial_args_list->insert(partial_args_list->end(), func_args.begin(), func_args.end());
302 if (args == nullptr) {
303 // Not Recursive
304 (void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(*partial_args_list),
305 [this](const AnfNodePtr &arg) { return infer_result_[arg]; });
306 } else {
307 // Recursive
308 (void)partial_args_list->insert(partial_args_list->end(), args->begin(), args->end());
309 }
310
311 // Get real function
312 abstract::AbstractFuncAtomPtrList abstractFuncList;
313 auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
314 abstractFuncList.push_back(poss);
315 };
316 func->fn()->Visit(build_fuction);
317 for (const auto &abstractFunc : abstractFuncList) {
318 EvalAbstractFunction(abstractFunc, node, partial_args_list);
319 }
320 }
321
SaveNodeInferResult(const AnfNodePtr & node,const AbstractBasePtr & result)322 void MindIREngine::SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result) {
323 auto answer = result;
324 try {
325 MS_LOG_TRY_CATCH_SCOPE;
326 auto it = infer_result_.find(node);
327 if (it != infer_result_.end()) {
328 MS_LOG(DEBUG) << node->ToString() << " result: " << it->second->ToString();
329 answer = result->Join(it->second);
330 if (*answer == *(it->second)) {
331 MS_LOG(DEBUG) << node->ToString() << " The value is not changed.";
332 return;
333 }
334 }
335 } catch (const std::exception &e) {
336 MS_LOG(INFO) << "Join abstract for node " << node->DebugString() << " failed, exception: " << e.what();
337 return;
338 }
339
340 MS_LOG(DEBUG) << node->ToString() << " result: " << answer->ToString();
341 infer_result_[node] = answer;
342 UpdateReady(node);
343 }
344
EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)345 void MindIREngine::EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
346 const AbstractBasePtrListPtr &args) {
347 auto prim = func->prim();
348 // Return Primitive
349 if (prim->name() == prim::kPrimReturn->name()) {
350 EvalReturnPrimitive(node);
351 return;
352 }
353 // Partial Primitive
354 if (prim->name() == prim::kPrimPartial->name()) {
355 EvalPartialPrimitive(node, args);
356 return;
357 }
358 // common Primitive
359 EvalCommonPrimitive(prim, node, args);
360 }
361
CheckCNodeNotReady(const CNodePtr & node)362 bool MindIREngine::CheckCNodeNotReady(const CNodePtr &node) {
363 int depend = 0;
364 for (auto &weak_input : node->weak_inputs()) {
365 auto input = weak_input.lock();
366 MS_EXCEPTION_IF_NULL(input);
367 depend += infer_result_.find(input) != infer_result_.end() ? 0 : 1;
368 }
369 this->node_input_depends_[node] = depend;
370 return depend != 0;
371 }
372
EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)373 void MindIREngine::EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
374 const AbstractBasePtrListPtr &args) {
375 MS_EXCEPTION_IF_NULL(node);
376 MS_EXCEPTION_IF_NULL(func);
377 MS_EXCEPTION_IF_NULL(func->func_graph());
378 // Has Processd
379 MS_LOG(DEBUG) << node->ToString() << " FuncGraph: " << func->ToString();
380 auto funcName = func->func_graph()->ToString();
381 auto it = func_graph_result_.find(funcName);
382 if (it != func_graph_result_.end()) {
383 MS_LOG(DEBUG) << "The abstract of " << node->ToString() << " = abstract of " << func->ToString();
384 SaveNodeInferResult(node, it->second);
385
386 // Process only one return valueNode function graph
387 auto func_inputs = func->func_graph()->parameters();
388 // args has been resolved in partial.
389 if (args != nullptr) {
390 if (func_inputs.size() != args->size()) {
391 MS_LOG(INTERNAL_EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
392 << " CNode:" << node->DebugString() << " input size:" << args->size();
393 }
394 for (size_t i = 0; i < func_inputs.size(); ++i) {
395 infer_result_[func_inputs[i]] =
396 (*args)[i]; // Not use SaveNodeInferResult because this function has been evaluated.
397 (void)todo_.erase(func_inputs[i]);
398 }
399 return;
400 }
401 // args is not resolved.
402 auto &cnode_inputs = node->inputs();
403 if (func_inputs.size() != cnode_inputs.size() - 1) {
404 MS_LOG(INTERNAL_EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
405 << " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size();
406 }
407 for (size_t i = 0; i < func_inputs.size(); ++i) {
408 infer_result_[func_inputs[i]] = infer_result_[cnode_inputs[i + 1]];
409 (void)todo_.erase(func_inputs[i]);
410 }
411 return;
412 }
413
414 // Be handling
415 auto visitIt = func_graph_visited_.find(funcName);
416 if (visitIt != func_graph_visited_.end()) {
417 (void)visitIt->second.insert(node);
418 return;
419 }
420 func_graph_visited_[funcName] = std::set<AnfNodePtr>({node});
421
422 // Call the funcGraph
423 auto func_inputs = func->func_graph()->parameters();
424
425 // args has been resolved in partial.
426 if (args != nullptr) {
427 if (func_inputs.size() != args->size()) {
428 MS_LOG(INTERNAL_EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
429 << " CNode:" << node->DebugString() << " input size:" << args->size()
430 << " may have unsupported parameters.";
431 }
432 for (size_t i = 0; i < func_inputs.size(); ++i) {
433 SaveNodeInferResult(func_inputs[i], (*args)[i]);
434 }
435 return;
436 }
437 // args is not resolved.
438 auto &cnode_inputs = node->inputs();
439 if (func_inputs.size() != cnode_inputs.size() - 1) {
440 MS_LOG(INTERNAL_EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
441 << " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size()
442 << " may have unsupported parameters.";
443 }
444
445 for (size_t i = 0; i < func_inputs.size(); ++i) {
446 SaveNodeInferResult(func_inputs[i], infer_result_[cnode_inputs[i + 1]]);
447 }
448 }
449
InferParameter(const AnfNodePtr & node)450 void MindIREngine::InferParameter(const AnfNodePtr &node) { UpdateReady(node); }
451
InferValueNode(const AnfNodePtr & node)452 void MindIREngine::InferValueNode(const AnfNodePtr &node) {
453 MS_EXCEPTION_IF_NULL(node);
454 auto value_node = node->cast<ValueNodePtr>();
455 MS_EXCEPTION_IF_NULL(value_node);
456 auto value = GetValueNode(node);
457 MS_EXCEPTION_IF_NULL(value);
458 AbstractBasePtr result;
459 if (value->isa<FuncGraph>()) {
460 auto func_graph = value->cast<FuncGraphPtr>();
461 auto temp_context = abstract::AnalysisContext::DummyContext();
462 result = std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, node);
463 } else if (value->isa<Primitive>()) {
464 auto prim = value->cast<PrimitivePtr>();
465 result = std::make_shared<abstract::PrimitiveAbstractClosure>(prim, node);
466 } else {
467 result = value->ToAbstract();
468 }
469
470 SaveNodeInferResult(node, result);
471 }
472
GetCNodeOperatorAbstract(const AnfNodePtr & node)473 AbstractBasePtr MindIREngine::GetCNodeOperatorAbstract(const AnfNodePtr &node) {
474 MS_EXCEPTION_IF_NULL(node);
475 auto cnode = node->cast<CNodePtr>();
476 MS_EXCEPTION_IF_NULL(cnode);
477 auto op = cnode->inputs()[0];
478 auto it = infer_result_.find(op);
479 if (it != infer_result_.end()) {
480 return it->second;
481 }
482 MS_LOG(INTERNAL_EXCEPTION) << "Can't get the abstract of Node:" << op->DebugString();
483 }
484
485 // If args is nullPtr, it is called by InferCNode, else it is called recursively by EvalPartialAbastract.
EvalAbstractFunction(const abstract::AbstractFuncAtomPtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)486 void MindIREngine::EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &func, const CNodePtr &node,
487 const AbstractBasePtrListPtr &args) {
488 MS_EXCEPTION_IF_NULL(func);
489 if (func->isa<abstract::PrimitiveAbstractClosure>()) {
490 // C++ Primitive
491 auto prim = func->cast<abstract::PrimitiveAbstractClosurePtr>();
492 EvalPrimitiveAbastract(prim, node, args);
493 } else if (func->isa<abstract::FuncGraphAbstractClosure>()) {
494 // FuncGraph
495 auto funcGraph = func->cast<abstract::FuncGraphAbstractClosurePtr>();
496 EvalFuncGraphAbastract(funcGraph, node, args);
497 } else if (func->isa<abstract::PartialAbstractClosure>()) {
498 // Partial
499 auto partialPrim = func->cast<abstract::PartialAbstractClosurePtr>();
500 EvalPartialAbastract(partialPrim, node, args);
501 } else {
502 MS_LOG(INTERNAL_EXCEPTION) << "MindIR can't process the abstractFunc: " << func->DumpText();
503 }
504 }
505
UpdateReady(const AnfNodePtr & node)506 void MindIREngine::UpdateReady(const AnfNodePtr &node) {
507 (void)todo_.erase(node);
508 auto it = nodeuser_map_.find(node);
509 if (it == nodeuser_map_.end()) {
510 return;
511 }
512 const auto &users = it->second;
513 MS_LOG(DEBUG) << node->ToString() << " has users: " << users.size();
514 for (const auto &user : users) {
515 int count = node_input_depends_[user.first];
516 node_input_depends_[user.first] = count - 1;
517 if (count <= 1) {
518 ready_.push_back(user.first);
519 MS_LOG(DEBUG) << "Node:" << user.first->ToString() << " is ready.";
520 if (count < 1) {
521 MS_LOG(INFO) << " There is something to do. Node:" << node->ToString() << " user:" << user.first->DebugString();
522 }
523 }
524 }
525 }
526
InferCNode(const AnfNodePtr & node)527 void MindIREngine::InferCNode(const AnfNodePtr &node) {
528 auto cnode = node->cast<CNodePtr>();
529 MS_EXCEPTION_IF_NULL(cnode);
530 if (CheckCNodeNotReady(cnode)) {
531 MS_LOG(INFO) << "The node is not ready: " << cnode->DebugString();
532 return;
533 }
534 AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode);
535 MS_EXCEPTION_IF_NULL(possible_func);
536 auto type = possible_func->BuildType();
537 MS_EXCEPTION_IF_NULL(type);
538 if (type->type_id() == kObjectTypeUndeterminedType) {
539 MS_LOG(INTERNAL_EXCEPTION) << "EvalCNode eval Undetermined";
540 }
541 abstract::AbstractFunctionPtr func = dyn_cast<abstract::AbstractFunction>(possible_func);
542 if (func == nullptr) {
543 MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << ".";
544 MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code.";
545 }
546 abstract::AbstractFuncAtomPtrList abstractFuncList;
547 auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
548 abstractFuncList.push_back(poss);
549 };
550 func->Visit(build_fuction);
551 for (const auto &abstractFunc : abstractFuncList) {
552 EvalAbstractFunction(abstractFunc, cnode, nullptr);
553 }
554 }
555 } // namespace
InferMindir(const FuncGraphPtr & root,const AbstractBasePtrList & args,bool raise_exception)556 bool InferMindir(const FuncGraphPtr &root, const AbstractBasePtrList &args, bool raise_exception) {
557 auto engine = std::make_shared<MindIREngine>(root);
558 engine->SetException(raise_exception);
559 return engine->InferShape(args);
560 }
561
ValidMindir(const FuncGraphPtr & root)562 bool ValidMindir(const FuncGraphPtr &root) {
563 MS_EXCEPTION_IF_NULL(root);
564 auto manager = root->manager();
565 if (manager == nullptr) {
566 manager = MakeManager();
567 manager->AddFuncGraph(root, true);
568 }
569 MS_LOG(DEBUG) << "Success to valid the mindir. " << root->ToString() << " : " << root.get();
570 return true;
571 }
572
InferFuncGraphLoaded(const FuncGraphPtr & root)573 void InferFuncGraphLoaded(const FuncGraphPtr &root) {
574 abstract::AbstractBasePtrList func_args;
575 const auto &inputs = root->get_inputs();
576 (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(func_args),
577 [](const AnfNodePtr &arg) -> AbstractBasePtr {
578 MS_EXCEPTION_IF_NULL(arg);
579 if (arg->abstract() == nullptr) {
580 MS_LOG(EXCEPTION) << "The parameter's abstract is null:" << arg->DebugString();
581 }
582 return arg->abstract();
583 });
584 (void)InferMindir(root, func_args);
585 }
586 } // namespace mindspore
587