• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "load_mindir/infer_mindir.h"
17 #include <deque>
18 #include <set>
19 #include <map>
20 #include <memory>
21 #include <algorithm>
22 #include <string>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "ir/func_graph.h"
26 #include "abstract/abstract_function.h"
27 #include "abstract/abstract_value.h"
28 #include "utils/ms_context.h"
29 #include "abstract/ops/primitive_infer_map.h"
30 
31 namespace mindspore {
32 namespace {
33 class MindIREngine {
34  public:
MindIREngine(const FuncGraphPtr & root)35   explicit MindIREngine(const FuncGraphPtr &root) : func_graph_(root), nodeuser_map_(root->manager()->node_users()) {}
36   ~MindIREngine() = default;
37   MindIREngine(const MindIREngine &) = delete;
38   MindIREngine &operator=(const MindIREngine &) = delete;
39 
40   bool InferShape(const AbstractBasePtrList &args);
41 
SetException(bool flag)42   void SetException(bool flag) { raise_exception_ = flag; }
43 
44  private:
45   using AbstractBasePtrListPtr = std::shared_ptr<AbstractBasePtrList>;
46 
47   void Init(const AbstractBasePtrList &args);
48   AbstractBasePtr InferPrimitiveShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_abs_list) const;
49   void EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node, const AbstractBasePtrListPtr &args);
50   void EvalPartialPrimitive(const CNodePtr &node, const AbstractBasePtrListPtr &args);
51   void EvalReturnPrimitive(const CNodePtr &node);
52   void InferParameter(const AnfNodePtr &node);
53   void InferValueNode(const AnfNodePtr &node);
54   void InferCNode(const AnfNodePtr &node);
55   void EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &func, const CNodePtr &node,
56                             const AbstractBasePtrListPtr &args);
57   void EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
58                               const AbstractBasePtrListPtr &args);
59   void EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
60                               const AbstractBasePtrListPtr &args);
61   void EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
62                             const AbstractBasePtrListPtr &args);
63   bool CheckCNodeNotReady(const CNodePtr &node);
64   void UpdateReady(const AnfNodePtr &node);
65   void SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result);
66   AbstractBasePtr GetCNodeOperatorAbstract(const AnfNodePtr &node);
67 
68   FuncGraphPtr func_graph_;
69   std::map<AnfNodePtr, int> node_input_depends_;
70   std::map<AnfNodePtr, AbstractBasePtr> infer_result_;
71   std::map<std::string, AbstractBasePtr> func_graph_result_;
72   std::map<std::string, std::set<AnfNodePtr>> func_graph_visited_;
73   std::deque<AnfNodePtr> ready_;
74   std::set<AnfNodePtr> todo_;
75   NodeUsersMap nodeuser_map_;
76   bool raise_exception_ = false;
77 };
78 
79 // Infer the root function graph.
InferShape(const AbstractBasePtrList & args)80 bool MindIREngine::InferShape(const AbstractBasePtrList &args) {
81   Init(args);
82   while (!ready_.empty()) {
83     auto current = ready_.front();
84     MS_EXCEPTION_IF_NULL(current);
85     ready_.pop_front();
86     if (current->isa<CNode>()) {
87       InferCNode(current);
88     } else if (current->isa<ValueNode>()) {
89       InferValueNode(current);
90     } else if (current->isa<Parameter>()) {
91       InferParameter(current);
92     } else {
93       MS_LOG(WARNING) << " There is something changed. Please check the code.";
94     }
95   }
96 
97   // Set abstract of node.
98   for (const auto &item : infer_result_) {
99     item.first->set_abstract(item.second);
100   }
101 
102   if (todo_.empty()) {
103     MS_LOG(DEBUG) << "Finish to Infere.";
104     return true;
105   }
106   MS_LOG(INFO) << "Not finished to infer: " << todo_.size();
107   for (const auto &node : todo_) {
108     MS_LOG(DEBUG) << "Node uninfered: " << node->DebugString();
109   }
110   return false;
111 }
112 
Init(const AbstractBasePtrList & args)113 void MindIREngine::Init(const AbstractBasePtrList &args) {
114   MS_EXCEPTION_IF_NULL(func_graph_);
115   auto manager = func_graph_->manager();
116   MS_EXCEPTION_IF_NULL(manager);
117   for (const auto &node : manager->all_nodes()) {
118     MS_EXCEPTION_IF_NULL(node);
119     if (node->isa<CNode>()) {
120       auto cnode = node->cast<CNodePtr>();
121       MS_EXCEPTION_IF_NULL(cnode);
122       (void)todo_.insert(node);
123       node_input_depends_[node] = SizeToInt(cnode->size());
124     } else if (node->isa<Parameter>()) {
125       auto param = node->cast<ParameterPtr>();
126       MS_EXCEPTION_IF_NULL(param);
127       if (param->has_default()) {
128         node_input_depends_[node] = 0;
129         auto default_param = param->default_param();
130         MS_EXCEPTION_IF_NULL(default_param);
131         infer_result_[node] = default_param->ToAbstract();
132         ready_.push_back(node);
133       } else {
134         node_input_depends_[node] = 1;
135         (void)todo_.insert(node);
136       }
137     } else {
138       // Value Node
139       node_input_depends_[node] = 0;
140       ready_.push_back(node);
141     }
142   }
143 
144   auto inputs = func_graph_->get_inputs();
145   if (inputs.size() != args.size()) {
146     MS_LOG(EXCEPTION) << "The input number of parameters is not Compatible.\n"
147                       << "Mindir:" << inputs.size() << " inputs: " << args.size()
148                       << " FuncGraph:" << func_graph_->ToString() << "\n"
149                       << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
150   }
151   // Root Func Parameters
152   for (size_t i = 0; i < args.size(); ++i) {
153     this->SaveNodeInferResult(inputs[i], args[i]);
154   }
155   MS_LOG(DEBUG) << "Finish init. Size of nodes:" << manager->all_nodes().size();
156 }
157 
158 // Infer primitive using C++ implement.
InferPrimitiveShape(const PrimitivePtr & prim,const AbstractBasePtrList & args_abs_list) const159 AbstractBasePtr MindIREngine::InferPrimitiveShape(const PrimitivePtr &prim,
160                                                   const AbstractBasePtrList &args_abs_list) const {
161   MS_EXCEPTION_IF_NULL(prim);
162   try {
163     MS_LOG_TRY_CATCH_SCOPE;
164     // For Lite, the op is with old format, it will fail in new infer function, so skip it.
165 #ifndef BUILD_LITE
166     auto abstract_optional = abstract::InferAbstractByFuncImpl(prim, args_abs_list);
167     if (abstract_optional.has_value()) {
168       return abstract_optional.value();
169     }
170 #endif
171 
172     auto found = abstract::GetPrimitiveInferImpl(prim);
173     if (found.has_value()) {
174       auto infer = found.value();
175       if (infer.IsImplInferShapeAndType()) {
176         return infer.InferShapeAndType(nullptr, prim, args_abs_list);
177       }
178     }
179 
180     if (raise_exception_) {
181       MS_LOG(INTERNAL_EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
182                                  << " primitive type:" << prim->type_name()
183                                  << " It will keep the previous value with danger.";
184     } else {
185       MS_LOG(INFO) << "Get infer shape function failed, primitive name:" << prim->name()
186                    << " primitive type:" << prim->type_name() << " It will keep the previous value with danger.";
187     }
188   } catch (const std::exception &ex) {
189     if (raise_exception_) {
190       MS_LOG(INTERNAL_EXCEPTION) << "Catch primitive:" << prim->ToString()
191                                  << " InferPrimitiveShape exception:" << ex.what()
192                                  << " It will keep the previous value with danger.";
193     } else {
194       MS_LOG(INFO) << "Catch primitive:" << prim->ToString() << " InferPrimitiveShape exception:" << ex.what()
195                    << " It will keep the previous value with danger.";
196     }
197   }
198   return nullptr;
199 }
200 
EvalCommonPrimitive(const PrimitivePtr & prim,const CNodePtr & node,const AbstractBasePtrListPtr & args)201 void MindIREngine::EvalCommonPrimitive(const PrimitivePtr &prim, const CNodePtr &node,
202                                        const AbstractBasePtrListPtr &args) {
203   // Save MakeTuple cnode abstract by its own abstract when MakeTuple have an abstract of
204   // AbstractCSRTensor/AbstractCOOTensor that can not be inferred by its Infer Functions.
205   if (prim->name() == prim::kPrimMakeTuple->name()) {
206     if (node->abstract() != nullptr && (node->abstract()->isa<abstract::AbstractSparseTensor>())) {
207       MS_LOG(INFO) << "Save MakeTuple cnode abstract by its own abstract : " << node->abstract()->ToString();
208       SaveNodeInferResult(node, node->abstract());
209       return;
210     }
211   }
212 
213   AbstractBasePtrList args_abs_list;
214   // Args has been resolved by partial
215   if (args != nullptr) {
216     (void)args_abs_list.insert(args_abs_list.end(), args->begin(), args->end());
217   } else {
218     (void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(args_abs_list),
219                          [this](const AnfNodePtr &arg) { return infer_result_[arg]; });
220   }
221 
222   // Call C++ infer
223   auto result = InferPrimitiveShape(prim, args_abs_list);
224   if (result == nullptr) {
225     MS_LOG(INFO) << node->ToString()
226                  << " can't be inferred shape. It will keep the previous value with danger. Prim: " << prim->ToString();
227     if (node->abstract() == nullptr) {
228       MS_LOG(WARNING) << "The abstract of the node: " << node->ToString()
229                       << " is nullptr. And it can't be inferred shape. Prim: " << prim->ToString();
230     } else {
231       result = node->abstract()->Clone();
232     }
233   }
234   SaveNodeInferResult(node, result);
235 }
236 
EvalReturnPrimitive(const CNodePtr & node)237 void MindIREngine::EvalReturnPrimitive(const CNodePtr &node) {
238   constexpr auto min_size = 2;
239   if (node->size() < min_size) {
240     MS_LOG(INTERNAL_EXCEPTION) << node->DebugString() << " input size < 2";
241   }
242   auto result = infer_result_[node->inputs()[1]];
243   auto funcName = node->func_graph()->ToString();
244   auto it = func_graph_result_.find(funcName);
245   if (it != func_graph_result_.end()) {
246     try {
247       MS_LOG_TRY_CATCH_SCOPE;
248       result = result->Join(it->second);
249     } catch (const std::exception &e) {
250       MS_LOG(INFO) << "Join abstract for return node " << node->DebugString() << " failed, exception: " << e.what();
251     }
252   }
253   this->func_graph_result_[funcName] = result;
254   SaveNodeInferResult(node, result);
255   MS_LOG(DEBUG) << funcName << " result: " << result->ToString();
256 
257   // Set the result of the node whose Operator is this funcGraph
258   for (const auto &item : func_graph_visited_[funcName]) {
259     SaveNodeInferResult(item, result);
260   }
261 }
262 
EvalPartialPrimitive(const CNodePtr & node,const AbstractBasePtrListPtr & args)263 void MindIREngine::EvalPartialPrimitive(const CNodePtr &node, const AbstractBasePtrListPtr &args) {
264   // Args has  been resolved
265   if (args != nullptr) {
266     if (args->size() < 2) {
267       MS_LOG(INTERNAL_EXCEPTION) << node->DebugString() << " input size < 2";
268     }
269     auto real_func = (*args)[0]->cast<abstract::AbstractFuncAtomPtr>();
270     if (real_func == nullptr) {
271       MS_LOG(INTERNAL_EXCEPTION) << (*args)[0]->ToString() << " is not a function abstract.";
272     }
273     AbstractBasePtrList partial_args_list;
274     (void)partial_args_list.insert(partial_args_list.end(), args->begin() + 1, args->end());
275     auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
276     SaveNodeInferResult(node, partial_func);
277     return;
278   }
279   // Not Resolved.
280   constexpr size_t kSizeTwo = 2;
281   if (node->size() < kSizeTwo) {
282     MS_LOG(INTERNAL_EXCEPTION) << node->DebugString() << " input size < " << kSizeTwo;
283   }
284   auto &func = infer_result_[node->inputs()[1]];
285   auto real_func = func->cast<abstract::AbstractFuncAtomPtr>();
286   if (real_func == nullptr) {
287     MS_LOG(INTERNAL_EXCEPTION) << func->ToString() << " is not a function abstract.";
288   }
289   AbstractBasePtrList partial_args_list;
290   (void)std::transform(node->inputs().begin() + 2, node->inputs().end(), std::back_inserter(partial_args_list),
291                        [this](const AnfNodePtr &arg) { return infer_result_[arg]; });
292   auto partial_func = std::make_shared<abstract::PartialAbstractClosure>(real_func, partial_args_list, node);
293   SaveNodeInferResult(node, partial_func);
294 }
295 
EvalPartialAbastract(const abstract::PartialAbstractClosurePtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)296 void MindIREngine::EvalPartialAbastract(const abstract::PartialAbstractClosurePtr &func, const CNodePtr &node,
297                                         const AbstractBasePtrListPtr &args) {
298   AbstractBasePtrListPtr partial_args_list = std::make_shared<AbstractBasePtrList>();
299   // Join arguments in partial and the rest arguments from args_conf_list.
300   auto func_args = func->args();
301   (void)partial_args_list->insert(partial_args_list->end(), func_args.begin(), func_args.end());
302   if (args == nullptr) {
303     // Not Recursive
304     (void)std::transform(node->inputs().begin() + 1, node->inputs().end(), std::back_inserter(*partial_args_list),
305                          [this](const AnfNodePtr &arg) { return infer_result_[arg]; });
306   } else {
307     // Recursive
308     (void)partial_args_list->insert(partial_args_list->end(), args->begin(), args->end());
309   }
310 
311   // Get real function
312   abstract::AbstractFuncAtomPtrList abstractFuncList;
313   auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
314     abstractFuncList.push_back(poss);
315   };
316   func->fn()->Visit(build_fuction);
317   for (const auto &abstractFunc : abstractFuncList) {
318     EvalAbstractFunction(abstractFunc, node, partial_args_list);
319   }
320 }
321 
SaveNodeInferResult(const AnfNodePtr & node,const AbstractBasePtr & result)322 void MindIREngine::SaveNodeInferResult(const AnfNodePtr &node, const AbstractBasePtr &result) {
323   auto answer = result;
324   try {
325     MS_LOG_TRY_CATCH_SCOPE;
326     auto it = infer_result_.find(node);
327     if (it != infer_result_.end()) {
328       MS_LOG(DEBUG) << node->ToString() << " result: " << it->second->ToString();
329       answer = result->Join(it->second);
330       if (*answer == *(it->second)) {
331         MS_LOG(DEBUG) << node->ToString() << " The value is not changed.";
332         return;
333       }
334     }
335   } catch (const std::exception &e) {
336     MS_LOG(INFO) << "Join abstract for node " << node->DebugString() << " failed, exception: " << e.what();
337     return;
338   }
339 
340   MS_LOG(DEBUG) << node->ToString() << " result: " << answer->ToString();
341   infer_result_[node] = answer;
342   UpdateReady(node);
343 }
344 
EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)345 void MindIREngine::EvalPrimitiveAbastract(const abstract::PrimitiveAbstractClosurePtr &func, const CNodePtr &node,
346                                           const AbstractBasePtrListPtr &args) {
347   auto prim = func->prim();
348   // Return Primitive
349   if (prim->name() == prim::kPrimReturn->name()) {
350     EvalReturnPrimitive(node);
351     return;
352   }
353   // Partial Primitive
354   if (prim->name() == prim::kPrimPartial->name()) {
355     EvalPartialPrimitive(node, args);
356     return;
357   }
358   // common Primitive
359   EvalCommonPrimitive(prim, node, args);
360 }
361 
CheckCNodeNotReady(const CNodePtr & node)362 bool MindIREngine::CheckCNodeNotReady(const CNodePtr &node) {
363   int depend = 0;
364   for (auto &weak_input : node->weak_inputs()) {
365     auto input = weak_input.lock();
366     MS_EXCEPTION_IF_NULL(input);
367     depend += infer_result_.find(input) != infer_result_.end() ? 0 : 1;
368   }
369   this->node_input_depends_[node] = depend;
370   return depend != 0;
371 }
372 
EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)373 void MindIREngine::EvalFuncGraphAbastract(const abstract::FuncGraphAbstractClosurePtr &func, const CNodePtr &node,
374                                           const AbstractBasePtrListPtr &args) {
375   MS_EXCEPTION_IF_NULL(node);
376   MS_EXCEPTION_IF_NULL(func);
377   MS_EXCEPTION_IF_NULL(func->func_graph());
378   // Has Processd
379   MS_LOG(DEBUG) << node->ToString() << " FuncGraph: " << func->ToString();
380   auto funcName = func->func_graph()->ToString();
381   auto it = func_graph_result_.find(funcName);
382   if (it != func_graph_result_.end()) {
383     MS_LOG(DEBUG) << "The abstract of " << node->ToString() << " = abstract of " << func->ToString();
384     SaveNodeInferResult(node, it->second);
385 
386     // Process only one return valueNode function graph
387     auto func_inputs = func->func_graph()->parameters();
388     // args has been resolved in partial.
389     if (args != nullptr) {
390       if (func_inputs.size() != args->size()) {
391         MS_LOG(INTERNAL_EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
392                                    << " CNode:" << node->DebugString() << " input size:" << args->size();
393       }
394       for (size_t i = 0; i < func_inputs.size(); ++i) {
395         infer_result_[func_inputs[i]] =
396           (*args)[i];  // Not use SaveNodeInferResult because this function has been evaluated.
397         (void)todo_.erase(func_inputs[i]);
398       }
399       return;
400     }
401     // args is not resolved.
402     auto &cnode_inputs = node->inputs();
403     if (func_inputs.size() != cnode_inputs.size() - 1) {
404       MS_LOG(INTERNAL_EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
405                                  << " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size();
406     }
407     for (size_t i = 0; i < func_inputs.size(); ++i) {
408       infer_result_[func_inputs[i]] = infer_result_[cnode_inputs[i + 1]];
409       (void)todo_.erase(func_inputs[i]);
410     }
411     return;
412   }
413 
414   // Be handling
415   auto visitIt = func_graph_visited_.find(funcName);
416   if (visitIt != func_graph_visited_.end()) {
417     (void)visitIt->second.insert(node);
418     return;
419   }
420   func_graph_visited_[funcName] = std::set<AnfNodePtr>({node});
421 
422   // Call the funcGraph
423   auto func_inputs = func->func_graph()->parameters();
424 
425   // args has been resolved in partial.
426   if (args != nullptr) {
427     if (func_inputs.size() != args->size()) {
428       MS_LOG(INTERNAL_EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
429                                  << " CNode:" << node->DebugString() << " input size:" << args->size()
430                                  << " may have unsupported parameters.";
431     }
432     for (size_t i = 0; i < func_inputs.size(); ++i) {
433       SaveNodeInferResult(func_inputs[i], (*args)[i]);
434     }
435     return;
436   }
437   // args is not resolved.
438   auto &cnode_inputs = node->inputs();
439   if (func_inputs.size() != cnode_inputs.size() - 1) {
440     MS_LOG(INTERNAL_EXCEPTION) << func->func_graph()->ToString() << " input size:" << func_inputs.size()
441                                << " CNode:" << node->DebugString() << " input size:" << cnode_inputs.size()
442                                << " may have unsupported parameters.";
443   }
444 
445   for (size_t i = 0; i < func_inputs.size(); ++i) {
446     SaveNodeInferResult(func_inputs[i], infer_result_[cnode_inputs[i + 1]]);
447   }
448 }
449 
InferParameter(const AnfNodePtr & node)450 void MindIREngine::InferParameter(const AnfNodePtr &node) { UpdateReady(node); }
451 
InferValueNode(const AnfNodePtr & node)452 void MindIREngine::InferValueNode(const AnfNodePtr &node) {
453   MS_EXCEPTION_IF_NULL(node);
454   auto value_node = node->cast<ValueNodePtr>();
455   MS_EXCEPTION_IF_NULL(value_node);
456   auto value = GetValueNode(node);
457   MS_EXCEPTION_IF_NULL(value);
458   AbstractBasePtr result;
459   if (value->isa<FuncGraph>()) {
460     auto func_graph = value->cast<FuncGraphPtr>();
461     auto temp_context = abstract::AnalysisContext::DummyContext();
462     result = std::make_shared<abstract::FuncGraphAbstractClosure>(func_graph, temp_context, node);
463   } else if (value->isa<Primitive>()) {
464     auto prim = value->cast<PrimitivePtr>();
465     result = std::make_shared<abstract::PrimitiveAbstractClosure>(prim, node);
466   } else {
467     result = value->ToAbstract();
468   }
469 
470   SaveNodeInferResult(node, result);
471 }
472 
GetCNodeOperatorAbstract(const AnfNodePtr & node)473 AbstractBasePtr MindIREngine::GetCNodeOperatorAbstract(const AnfNodePtr &node) {
474   MS_EXCEPTION_IF_NULL(node);
475   auto cnode = node->cast<CNodePtr>();
476   MS_EXCEPTION_IF_NULL(cnode);
477   auto op = cnode->inputs()[0];
478   auto it = infer_result_.find(op);
479   if (it != infer_result_.end()) {
480     return it->second;
481   }
482   MS_LOG(INTERNAL_EXCEPTION) << "Can't get the abstract of Node:" << op->DebugString();
483 }
484 
485 // If args is nullPtr, it is called by InferCNode, else it is called recursively by EvalPartialAbastract.
EvalAbstractFunction(const abstract::AbstractFuncAtomPtr & func,const CNodePtr & node,const AbstractBasePtrListPtr & args)486 void MindIREngine::EvalAbstractFunction(const abstract::AbstractFuncAtomPtr &func, const CNodePtr &node,
487                                         const AbstractBasePtrListPtr &args) {
488   MS_EXCEPTION_IF_NULL(func);
489   if (func->isa<abstract::PrimitiveAbstractClosure>()) {
490     // C++ Primitive
491     auto prim = func->cast<abstract::PrimitiveAbstractClosurePtr>();
492     EvalPrimitiveAbastract(prim, node, args);
493   } else if (func->isa<abstract::FuncGraphAbstractClosure>()) {
494     // FuncGraph
495     auto funcGraph = func->cast<abstract::FuncGraphAbstractClosurePtr>();
496     EvalFuncGraphAbastract(funcGraph, node, args);
497   } else if (func->isa<abstract::PartialAbstractClosure>()) {
498     // Partial
499     auto partialPrim = func->cast<abstract::PartialAbstractClosurePtr>();
500     EvalPartialAbastract(partialPrim, node, args);
501   } else {
502     MS_LOG(INTERNAL_EXCEPTION) << "MindIR can't process the abstractFunc: " << func->DumpText();
503   }
504 }
505 
UpdateReady(const AnfNodePtr & node)506 void MindIREngine::UpdateReady(const AnfNodePtr &node) {
507   (void)todo_.erase(node);
508   auto it = nodeuser_map_.find(node);
509   if (it == nodeuser_map_.end()) {
510     return;
511   }
512   const auto &users = it->second;
513   MS_LOG(DEBUG) << node->ToString() << " has users: " << users.size();
514   for (const auto &user : users) {
515     int count = node_input_depends_[user.first];
516     node_input_depends_[user.first] = count - 1;
517     if (count <= 1) {
518       ready_.push_back(user.first);
519       MS_LOG(DEBUG) << "Node:" << user.first->ToString() << " is ready.";
520       if (count < 1) {
521         MS_LOG(INFO) << " There is something to do. Node:" << node->ToString() << " user:" << user.first->DebugString();
522       }
523     }
524   }
525 }
526 
InferCNode(const AnfNodePtr & node)527 void MindIREngine::InferCNode(const AnfNodePtr &node) {
528   auto cnode = node->cast<CNodePtr>();
529   MS_EXCEPTION_IF_NULL(cnode);
530   if (CheckCNodeNotReady(cnode)) {
531     MS_LOG(INFO) << "The node is not ready: " << cnode->DebugString();
532     return;
533   }
534   AbstractBasePtr possible_func = GetCNodeOperatorAbstract(cnode);
535   MS_EXCEPTION_IF_NULL(possible_func);
536   auto type = possible_func->BuildType();
537   MS_EXCEPTION_IF_NULL(type);
538   if (type->type_id() == kObjectTypeUndeterminedType) {
539     MS_LOG(INTERNAL_EXCEPTION) << "EvalCNode eval Undetermined";
540   }
541   abstract::AbstractFunctionPtr func = dyn_cast<abstract::AbstractFunction>(possible_func);
542   if (func == nullptr) {
543     MS_LOG(ERROR) << "Can not cast to a AbstractFunction: " << possible_func->ToString() << ".";
544     MS_EXCEPTION(ValueError) << "This may be not defined, and it can't be a operator. Please check code.";
545   }
546   abstract::AbstractFuncAtomPtrList abstractFuncList;
547   auto build_fuction = [&abstractFuncList](const abstract::AbstractFuncAtomPtr &poss) {
548     abstractFuncList.push_back(poss);
549   };
550   func->Visit(build_fuction);
551   for (const auto &abstractFunc : abstractFuncList) {
552     EvalAbstractFunction(abstractFunc, cnode, nullptr);
553   }
554 }
555 }  // namespace
InferMindir(const FuncGraphPtr & root,const AbstractBasePtrList & args,bool raise_exception)556 bool InferMindir(const FuncGraphPtr &root, const AbstractBasePtrList &args, bool raise_exception) {
557   auto engine = std::make_shared<MindIREngine>(root);
558   engine->SetException(raise_exception);
559   return engine->InferShape(args);
560 }
561 
ValidMindir(const FuncGraphPtr & root)562 bool ValidMindir(const FuncGraphPtr &root) {
563   MS_EXCEPTION_IF_NULL(root);
564   auto manager = root->manager();
565   if (manager == nullptr) {
566     manager = MakeManager();
567     manager->AddFuncGraph(root, true);
568   }
569   MS_LOG(DEBUG) << "Success to valid the mindir. " << root->ToString() << " : " << root.get();
570   return true;
571 }
572 
InferFuncGraphLoaded(const FuncGraphPtr & root)573 void InferFuncGraphLoaded(const FuncGraphPtr &root) {
574   abstract::AbstractBasePtrList func_args;
575   const auto &inputs = root->get_inputs();
576   (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(func_args),
577                        [](const AnfNodePtr &arg) -> AbstractBasePtr {
578                          MS_EXCEPTION_IF_NULL(arg);
579                          if (arg->abstract() == nullptr) {
580                            MS_LOG(EXCEPTION) << "The parameter's abstract is null:" << arg->DebugString();
581                          }
582                          return arg->abstract();
583                        });
584   (void)InferMindir(root, func_args);
585 }
586 }  // namespace mindspore
587