• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/static_analysis/prim.h"
20 
21 #include <algorithm>
22 #include <limits>
23 #include <mutex>
24 #include <string>
25 #include <utility>
26 #include <unordered_set>
27 
28 #include "frontend/operator/cc_implementations.h"
29 #include "frontend/operator/ops.h"
30 #include "frontend/operator/composite/do_signature.h"
31 #include "frontend/operator/prim_to_function.h"
32 #include "abstract/utils.h"
33 #include "utils/symbolic.h"
34 #include "pipeline/jit/resource.h"
35 #include "pipeline/jit/parse/resolve.h"
36 #include "utils/convert_utils.h"
37 #include "utils/convert_utils_py.h"
38 #include "utils/ms_context.h"
39 #include "pipeline/jit/parse/data_converter.h"
40 #include "abstract/primitive_infer_map.h"
41 #include "abstract/param_validator.h"
42 #include "utils/ms_utils.h"
43 #include "utils/shape_utils.h"
44 #include "utils/parallel_node_check.h"
45 
46 namespace mindspore {
47 namespace abstract {
48 using mindspore::parse::PyObjectWrapper;
49 
// Names of primitives for which DoSignatureEvaluator::Run and StandardPrimEvaluator::EvalPrim
// do not consult AbstractEval() before evaluating the primitive itself.
std::unordered_set<std::string> prims_to_skip_undetermined_infer = {
  "MakeTuple",    //
  "make_list",    //
  "Switch",       //
  "env_setitem",  //
  "env_getitem",  //
  "Load",         //
  "UpdateState",  //
};
52 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)53 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
54                                         const AnfNodeConfigPtr &out_conf) {
55   MS_EXCEPTION_IF_NULL(engine);
56   MS_EXCEPTION_IF_NULL(out_conf);
57   AbstractBasePtrList args_spec_list;
58   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
59                        [](const ConfigPtr &ref) -> AbstractBasePtr {
60                          MS_EXCEPTION_IF_NULL(ref);
61                          MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
62                          return ref->ObtainEvalResult()->abstract();
63                        });
64   auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
65   MS_EXCEPTION_IF_NULL(do_signature);
66   auto &func = do_signature->function();
67   if (func->isa<Primitive>()) {
68     auto sig_prim = func->cast<PrimitivePtr>();
69     if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) {
70       auto ret_abstract = AbstractEval(args_spec_list);
71       if (ret_abstract != nullptr) {
72         MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined";
73         return ret_abstract;
74       }
75     }
76   }
77   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
78     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
79   }
80 
81   auto out_node = dyn_cast<CNode>(out_conf->node());
82   MS_EXCEPTION_IF_NULL(out_node);
83   const auto &out_node_inputs = out_node->inputs();
84   if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
85     MS_LOG(EXCEPTION) << "Op: " << func->ToString() << " args size should equal to inputs size minus 1, but args size "
86                       << args_conf_list.size() << ", inputs size " << out_node_inputs.size();
87   }
88   AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
89 
90   ScopePtr scope = kDefaultScope;
91   if (out_conf != nullptr) {
92     scope = out_conf->node()->scope();
93   }
94   ScopeGuard scope_guard(scope);
95 
96   AnfNodePtr new_node = nullptr;
97   if (bound_node() != nullptr) {
98     TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
99     new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
100   } else {
101     new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
102   }
103   AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
104 
105   if (out_node->isa<CNode>()) {
106     auto out_cnode = out_node->cast<CNodePtr>();
107     auto new_cnode = new_node->cast<CNodePtr>();
108     new_cnode->CloneCNodeInfo(out_cnode);
109   }
110 
111   return engine->ForwardConfig(out_conf, fn_conf);
112 }
113 
GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list,bool need_unpack)114 static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
115   // arg[0] is the func graph to unpack, ignore it
116   AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
117   AbstractBasePtrList graph_specialize_args;
118   if (need_unpack) {
119     for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
120       MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
121       if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
122         auto arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
123         std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
124                        std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
125       } else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
126         auto arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
127         auto dict_elems = arg_dict->elements();
128         (void)std::transform(
129           dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
130           [](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); });
131       } else {
132         MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got "
133                           << specialize_args_before_unpack[index]->ToString();
134       }
135     }
136   } else {
137     graph_specialize_args = specialize_args_before_unpack;
138   }
139   return graph_specialize_args;
140 }
141 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)142 EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
143                                         const AnfNodeConfigPtr &out_conf) {
144   MS_EXCEPTION_IF_NULL(engine);
145   MS_EXCEPTION_IF_NULL(out_conf);
146   MS_EXCEPTION_IF_NULL(out_conf->node());
147   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
148     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
149   }
150 
151   auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
152   MS_EXCEPTION_IF_NULL(unpack_graph);
153   auto out_node = out_conf->node()->cast<CNodePtr>();
154   MS_EXCEPTION_IF_NULL(out_node);
155   const auto &out_node_inputs = out_node->inputs();
156   if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
157     MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
158                       << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
159                       << ", inputs size " << out_node_inputs.size();
160   }
161   AbstractBasePtrList args_spec_list;
162   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
163                        [](const ConfigPtr &ref) -> AbstractBasePtr {
164                          MS_EXCEPTION_IF_NULL(ref);
165                          MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
166                          return ref->ObtainEvalResult()->abstract();
167                        });
168   // get the forward graph
169   if (args_spec_list.empty()) {
170     MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
171   }
172   MS_EXCEPTION_IF_NULL(args_spec_list[0]);
173   auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
174   if (fn == nullptr) {
175     MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
176   }
177   auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
178   MS_EXCEPTION_IF_NULL(real_fn);
179   FuncGraphPtr forward_graph = real_fn->func_graph();
180   MS_EXCEPTION_IF_NULL(forward_graph);
181   AbstractBasePtrList graph_specialize_args =
182     GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
183   AbstractBasePtrList graph_specialize_args_without_sens;
184   if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
185     MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
186   }
187   (void)std::transform(graph_specialize_args.begin(),
188                        graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
189                        std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
190   auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens);
191   engine->func_graph_manager()->AddFuncGraph(new_graph);
192   ScopePtr scope = kDefaultScope;
193   if (out_conf != nullptr) {
194     scope = out_conf->node()->scope();
195   }
196   ScopeGuard scope_guard(scope);
197   AnfNodePtr new_vnode = NewValueNode(new_graph);
198   AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context(), out_conf->func_graph());
199 
200   return engine->ForwardConfig(out_conf, fn_conf);
201 }
202 
MixedPrecisionCastHelper(const AnfNodePtr & source_node,const AbstractBasePtr & node_type,const AnfNodePtr & target_type,const FuncGraphPtr & func_graph)203 AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
204                                     const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
205   MS_EXCEPTION_IF_NULL(node_type);
206   MS_EXCEPTION_IF_NULL(func_graph);
207   AnfNodePtr target_node = source_node;
208   if (node_type->isa<AbstractTensor>()) {
209     auto x = node_type->cast<AbstractTensorPtr>();
210     if (x->element()->BuildType()->isa<Float>()) {
211       auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
212       MS_EXCEPTION_IF_NULL(cast);
213       target_node = func_graph->NewCNodeAfter(source_node, {NewValueNode(cast), source_node, target_type});
214     }
215   } else if (node_type->isa<AbstractTuple>()) {
216     auto x = node_type->cast<AbstractTuplePtr>();
217     auto &items = x->elements();
218     std::vector<AnfNodePtr> nodes;
219     nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
220     int64_t idx = 0;
221     for (const auto &item : items) {
222       AnfNodePtr tuple_node =
223         func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)});
224       AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, item, target_type, func_graph);
225       nodes.emplace_back(node);
226       ++idx;
227     }
228     target_node = func_graph->NewCNode(nodes);
229   } else if (node_type->isa<AbstractDictionary>()) {
230     auto x = node_type->cast<AbstractDictionaryPtr>();
231     auto &items = x->elements();
232     std::vector<AnfNodePtr> dict_key_nodes;
233     std::vector<AnfNodePtr> dict_value_nodes;
234     dict_key_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
235     dict_value_nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
236     for (const auto &item : items) {
237       AnfNodePtr dict_value_node =
238         func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(item.first)});
239       AnfNodePtr node = MixedPrecisionCastHelper(dict_value_node, item.second, target_type, func_graph);
240       dict_key_nodes.emplace_back(NewValueNode(item.first));
241       dict_value_nodes.emplace_back(node);
242     }
243     target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(dict_key_nodes),
244                                         func_graph->NewCNode(dict_value_nodes)});
245   } else if (node_type->isa<AbstractKeywordArg>()) {
246     auto x = node_type->cast<AbstractKeywordArgPtr>();
247     std::string kwarg_key = x->get_key();
248     AnfNodePtr kwarg_value_node =
249       func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
250     AnfNodePtr node = MixedPrecisionCastHelper(kwarg_value_node, x->get_arg(), target_type, func_graph);
251     target_node = func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), node});
252   }
253   return target_node;
254 }
255 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)256 EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
257                                                const AnfNodeConfigPtr &out_conf) {
258   MS_EXCEPTION_IF_NULL(engine);
259   AbstractBasePtrList args_spec_list;
260   MS_EXCEPTION_IF_NULL(out_conf);
261   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
262     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
263   }
264   auto out_node = out_conf->node()->cast<CNodePtr>();
265   MS_EXCEPTION_IF_NULL(out_node);
266   const auto &out_node_inputs = out_node->inputs();
267   if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
268     MS_LOG(EXCEPTION) << "MixedPrecisionCast"
269                       << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
270                       << ", inputs size " << out_node_inputs.size();
271   }
272   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
273                        [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
274 
275   ScopePtr scope = kDefaultScope;
276   scope = out_conf->node()->scope();
277   ScopeGuard scope_guard(scope);
278 
279   FuncGraphPtr func_graph = out_node->func_graph();
280   constexpr size_t source_node_index = 2;
281   if (out_node_inputs.size() <= source_node_index) {
282     MS_LOG(EXCEPTION) << "Input size:" << out_node_inputs.size() << " should bigger than 2.";
283   }
284 
285   AnfNodePtr new_node =
286     MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
287   AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
288 
289   if (new_node->isa<CNode>()) {
290     auto new_cnode = new_node->cast<CNodePtr>();
291     new_cnode->CloneCNodeInfo(out_node);
292   }
293   return engine->ForwardConfig(out_conf, fn_conf);
294 }
295 
296 namespace {
BuildValue(const ValuePtr & value_ptr)297 py::object BuildValue(const ValuePtr &value_ptr) {
298   if (value_ptr == nullptr) {
299     return py::none();
300   } else {
301     return ValueToPyData(value_ptr);
302   }
303 }
304 
AbstractTupleToPython(const AbstractBasePtr & abs_base)305 py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
306   auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
307   MS_EXCEPTION_IF_NULL(arg_tuple);
308   size_t len = arg_tuple->size();
309   py::tuple shape_tuple(len);
310   py::tuple dtype_tuple(len);
311   py::tuple value_tuple(len);
312   py::tuple min_value_tuple(len);
313   py::tuple max_value_tuple(len);
314   py::tuple min_shape_tuple(len);
315   py::tuple max_shape_tuple(len);
316   bool dyn_shape = false;
317   bool dyn_value = false;
318 
319   for (size_t i = 0; i < len; i++) {
320     auto arg = arg_tuple->elements()[i];
321     py::dict out = ConvertAbstractToPython(arg);
322     shape_tuple[i] = out[ATTR_SHAPE];
323     dtype_tuple[i] = out[ATTR_DTYPE];
324     value_tuple[i] = out[ATTR_VALUE];
325 
326     // Elements in tuple is tensor shape value.
327     if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) {
328       min_value_tuple[i] = out[ATTR_MIN_VALUE];
329       max_value_tuple[i] = out[ATTR_MAX_VALUE];
330       dyn_value = true;
331     }
332 
333     // Elements in tuple is tensor, which shape is dynamic.
334     if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
335       min_shape_tuple[i] = out[ATTR_MIN_SHAPE];
336       max_shape_tuple[i] = out[ATTR_MAX_SHAPE];
337       dyn_shape = true;
338     }
339   }
340   auto dic = py::dict();
341   dic[ATTR_SHAPE] = shape_tuple;
342   dic[ATTR_DTYPE] = dtype_tuple;
343   MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
344   if (arg_tuple->BuildValue()->isa<AnyValue>()) {
345     dic[ATTR_VALUE] = py::none();
346   } else {
347     dic[ATTR_VALUE] = value_tuple;
348   }
349 
350   if (dyn_value) {
351     dic[ATTR_MIN_VALUE] = min_value_tuple;
352     dic[ATTR_MAX_VALUE] = max_value_tuple;
353   }
354   if (dyn_shape) {
355     dic[ATTR_MIN_SHAPE] = min_shape_tuple;
356     dic[ATTR_MAX_SHAPE] = max_shape_tuple;
357   }
358 
359   return dic;
360 }
361 
AbstractListToPython(const AbstractBasePtr & abs_base)362 py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
363   auto arg_list = dyn_cast<AbstractList>(abs_base);
364   MS_EXCEPTION_IF_NULL(arg_list);
365   size_t len = arg_list->size();
366   py::list shape_list(len);
367   py::list dtype_list(len);
368   py::list value_list(len);
369   py::list min_shape_list(len);
370   py::list max_shape_list(len);
371   bool dyn_shape = false;
372 
373   for (size_t i = 0; i < len; i++) {
374     py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
375     shape_list[i] = out[ATTR_SHAPE];
376     dtype_list[i] = out[ATTR_DTYPE];
377     value_list[i] = out[ATTR_VALUE];
378 
379     // Elements in list is tensor, which shape is dynamic.
380     if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
381       min_shape_list[i] = out[ATTR_MIN_SHAPE];
382       max_shape_list[i] = out[ATTR_MAX_SHAPE];
383       dyn_shape = true;
384     }
385   }
386   auto dic = py::dict();
387   dic[ATTR_SHAPE] = shape_list;
388   dic[ATTR_DTYPE] = dtype_list;
389   MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
390   if (arg_list->BuildValue()->isa<AnyValue>()) {
391     dic[ATTR_VALUE] = py::none();
392   } else {
393     dic[ATTR_VALUE] = value_list;
394   }
395 
396   if (dyn_shape) {
397     dic[ATTR_MIN_SHAPE] = min_shape_list;
398     dic[ATTR_MAX_SHAPE] = max_shape_list;
399   }
400 
401   return dic;
402 }
403 
ConvertAbstractTensorToPython(const AbstractBasePtr & abs_base,py::dict * dic)404 void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
405   auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
406   MS_EXCEPTION_IF_NULL(dic);
407   MS_EXCEPTION_IF_NULL(arg_tensor);
408   MS_EXCEPTION_IF_NULL(arg_tensor->shape());
409   (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
410   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
411     const auto &min_shape = arg_tensor->shape()->min_shape();
412     const auto &max_shape = arg_tensor->shape()->max_shape();
413     if (!min_shape.empty() && !max_shape.empty()) {
414       (*dic)[ATTR_MIN_SHAPE] = min_shape;
415       (*dic)[ATTR_MAX_SHAPE] = max_shape;
416     }
417   }
418 
419   auto min_value = arg_tensor->get_min_value();
420   auto max_value = arg_tensor->get_max_value();
421   if (min_value != nullptr && max_value != nullptr) {
422     (*dic)[ATTR_MIN_VALUE] = BuildValue(min_value);
423     (*dic)[ATTR_MAX_VALUE] = BuildValue(max_value);
424   }
425 
426   (*dic)[ATTR_DTYPE] = arg_tensor->BuildType();
427   (*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
428 }
429 
ConvertAbstractFunctionToPython(const AbstractBasePtr & abs_base,py::dict * dic)430 void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
431   MS_EXCEPTION_IF_NULL(dic);
432   MS_EXCEPTION_IF_NULL(abs_base);
433   (*dic)[ATTR_SHAPE] = py::none();
434   (*dic)[ATTR_DTYPE] = abs_base->BuildType();
435   (*dic)[ATTR_VALUE] = py::none();
436   if (abs_base->isa<PartialAbstractClosure>()) {
437     AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
438     if (!args.empty()) {
439       MS_EXCEPTION_IF_NULL(args[0]->BuildValue());
440       auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
441       if (value != nullptr) {
442         (*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
443         (*dic)[ATTR_VALUE] = value->obj();
444       }
445     }
446   }
447 }
448 }  // end anonymous namespace
449 
ConvertAbstractToPython(const AbstractBasePtr & abs_base)450 py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
451   MS_EXCEPTION_IF_NULL(abs_base);
452   auto dic = py::dict();
453   if (abs_base->isa<AbstractTensor>()) {
454     ConvertAbstractTensorToPython(abs_base, &dic);
455   } else if (abs_base->isa<AbstractRowTensor>()) {
456     auto arg = dyn_cast<AbstractRowTensor>(abs_base);
457     dic[ATTR_SHAPE] = arg->shape()->shape();
458     dic[ATTR_DTYPE] = arg->BuildType();
459     dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
460   } else if (abs_base->isa<AbstractSparseTensor>()) {
461     auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
462     dic[ATTR_SHAPE] = arg->shape()->shape();
463     dic[ATTR_DTYPE] = arg->BuildType();
464     dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
465   } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
466     ShapeVector shape;
467     dic[ATTR_SHAPE] = shape;
468     dic[ATTR_DTYPE] = abs_base->BuildType();
469     dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
470   } else if (abs_base->isa<AbstractSlice>()) {
471     auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
472     ShapeVector shape;
473     dic[ATTR_SHAPE] = shape;
474     dic[ATTR_DTYPE] = arg_slice->BuildType();
475     dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
476   } else if (abs_base->isa<AbstractEllipsis>()) {
477     dic[ATTR_SHAPE] = py::none();
478     dic[ATTR_DTYPE] = py::ellipsis();
479     dic[ATTR_VALUE] = py::ellipsis();
480   } else if (abs_base->isa<AbstractTuple>()) {
481     return AbstractTupleToPython(abs_base);
482   } else if (abs_base->isa<AbstractList>()) {
483     return AbstractListToPython(abs_base);
484   } else if (abs_base->isa<AbstractNone>()) {
485     dic[ATTR_SHAPE] = py::none();
486     dic[ATTR_DTYPE] = py::none();
487     dic[ATTR_VALUE] = py::none();
488   } else if (abs_base->isa<AbstractFunction>()) {
489     ConvertAbstractFunctionToPython(abs_base, &dic);
490   } else if (abs_base->isa<AbstractUndetermined>()) {
491     auto arg = dyn_cast<AbstractUndetermined>(abs_base);
492     dic[ATTR_SHAPE] = py::none();
493     dic[ATTR_DTYPE] = arg->BuildType();
494     dic[ATTR_VALUE] = py::none();
495   } else if (abs_base->isa<AbstractMonad>()) {
496     dic[ATTR_SHAPE] = py::none();
497     dic[ATTR_DTYPE] = abs_base->BuildType();
498     dic[ATTR_VALUE] = py::none();
499   } else {
500     auto value = abs_base->BuildValue();
501     MS_EXCEPTION_IF_NULL(value);
502     if ((*value == *kAnyValue)) {
503       auto value_desc = abs_base->value_desc();
504       MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
505                               << " for python primitive." << abs_base->ToString();
506     }
507     MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
508                             << value->ToString();
509   }
510   return dic;
511 }
512 
513 namespace {
PreparePyInputs(const PrimitivePyPtr &,const AbstractBasePtrList & args)514 py::tuple PreparePyInputs(const PrimitivePyPtr &, const AbstractBasePtrList &args) {
515   // The monad parameter is defined at the end of the parameter and needs to be ignored
516   std::size_t size_args = args.size() - GetAbstractMonadNum(args);
517   py::tuple py_args(size_args);
518   for (size_t i = 0; i < size_args; i++) {
519     auto arg_i = (args)[i];
520     py_args[i] = ConvertAbstractToPython(arg_i);
521   }
522   return py_args;
523 }
524 
CheckCustomPrimOutputInferResult(const PrimitivePtr & prim,const AbstractBasePtr & res_spec)525 void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
526   MS_EXCEPTION_IF_NULL(prim);
527   MS_EXCEPTION_IF_NULL(res_spec);
528   const string kOutputNum = "output_num";
529   if (prim->IsCustomPrim()) {
530     // Raise error if output_num is not match the infer result.
531     auto output_num_value = prim->GetAttr(kOutputNum);
532     if (output_num_value == nullptr) {
533       MS_LOG(DEBUG) << "The output num may no need to check";
534       return;
535     }
536     int64_t output_num = GetValue<int64_t>(output_num_value);
537     if (res_spec->isa<AbstractTensor>() && output_num != 1) {
538       MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
539                         << "]'s attribute[output_num]:" << output_num << " not matches the infer result "
540                         << res_spec->ToString();
541     } else if (res_spec->isa<AbstractTuple>() &&
542                (res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
543       MS_LOG(EXCEPTION) << "Custom primitive[" << prim->ToString() << "]'s attribute[output_num]:" << output_num
544                         << " not matches the infer result " << res_spec->ToString();
545     }
546   }
547 }
548 
PyInferRes2Abstract(const PrimitivePyPtr & prim_py,const py::dict & output)549 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
550   // Convert to AbstractValue based on type and shape
551   auto out_dtype = output[ATTR_DTYPE];
552   if (output[ATTR_VALUE].is_none()) {
553     auto out_shape = output[ATTR_SHAPE];
554     return MakePyInferRes2Abstract(out_shape, out_dtype, output);
555   }
556   // Convert pyobject to Value, then to AbstractValue
557   ValuePtr converted_ret = nullptr;
558   TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
559   bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
560   if (!converted) {
561     MS_LOG(EXCEPTION) << "Convert data failed";
562   }
563   auto res_spec = FromValue(converted_ret);
564   MS_EXCEPTION_IF_NULL(res_spec);
565   if (res_spec->isa<AbstractTensor>()) {
566     // Replace to tensor constant node in specialize
567     auto res_tensor = res_spec->cast<AbstractTensorPtr>();
568     res_tensor->set_value(converted_ret);
569     SetValueRange(res_tensor, output);
570   }
571   CheckCustomPrimOutputInferResult(prim_py, res_spec);
572   return res_spec;
573 }
574 }  // end anonymous namespace
575 
RunPyInferValue(const AnalysisEnginePtr & engine,const AbstractBasePtr & abs_base,const AbstractBasePtrList & args)576 EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
577                                                      const AbstractBasePtrList &args) {
578   auto prim_py = dyn_cast<PrimitivePy>(prim_);
579   if (prim_py == nullptr) {
580     MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
581   }
582   // Call checking method 'infer_value' for python primitive
583   MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
584   auto py_args = PreparePyInputs(prim_py, args);
585   py::tuple py_vals(py_args.size());
586   auto added_attrs = prim_->evaluate_added_attrs();
587   for (size_t i = 0; i < py_args.size(); ++i) {
588     py_vals[i] = py_args[i][ATTR_VALUE];
589   }
590   py::object py_ret = prim_py->RunInferValue(py_vals);
591   if (py::isinstance<py::none>(py_ret)) {
592     return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
593   }
594   // Convert pyobject to Value, then to AbstractValue
595   ValuePtr converted_ret = nullptr;
596   TypePtr dtype = abs_base->BuildType();
597   bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
598   if (!converted) {
599     MS_LOG(EXCEPTION) << "Convert data failed";
600   }
601   auto res_spec = FromValue(converted_ret);
602   MS_EXCEPTION_IF_NULL(res_spec);
603   if (res_spec->isa<AbstractTensor>()) {
604     // Replace to tensor constant node in specialize
605     auto res_tensor = res_spec->cast<AbstractTensorPtr>();
606     res_tensor->set_value(converted_ret);
607   }
608   return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
609 }
610 
EvalPyCheckPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)611 EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
612   auto prim_py = dyn_cast<PrimitivePy>(prim_);
613   if (prim_py == nullptr) {
614     MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
615   }
616   // Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
617   MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
618   auto py_args = PreparePyInputs(prim_py, args);
619   prim_py->RunCheck(py_args);
620 
621   prim_->BeginRecordAddAttr();
622   AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
623   prim_->EndRecordAddAttr();
624   auto added_attrs = prim_->evaluate_added_attrs();
625 
626   if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
627     return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
628   }
629   // Call method 'infer_value' for primitive with this method for constant propagation
630   return RunPyInferValue(engine, abs_base, args);
631 }
632 
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)633 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
634   if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
635     auto ret_abstract = AbstractEval(args);
636     if (ret_abstract != nullptr) {
637       MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
638       return ret_abstract;
639     }
640   }
641   if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
642     return EvalPyCheckPrim(engine, args);
643   }
644   auto context = MsContext::GetInstance();
645   MS_EXCEPTION_IF_NULL(context);
646   bool need_infer_value = !eval_impl_.in_white_list_;
647   if (need_infer_value == false) {
648     need_infer_value = ((context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
649                        std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
650                          MS_EXCEPTION_IF_NULL(abs);
651                          auto value = abs->BuildValue();
652                          return (value != nullptr && !value->isa<AnyValue>() && !value->isa<None>() &&
653                                  !value->isa<Monad>() && !value->isa<FuncGraph>());
654                        });
655   }
656   AbstractBasePtr abs_base = nullptr;
657   ValuePtr value = nullptr;
658   prim_->BeginRecordAddAttr();
659   if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) {
660     value = eval_impl_.infer_value_impl_(prim_, args);
661     if (value != nullptr) {
662       abs_base = value->ToAbstract();
663       prim_->EndRecordAddAttr();
664       auto added_attrs = prim_->evaluate_added_attrs();
665       return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
666     }
667   }
668   abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
669   prim_->EndRecordAddAttr();
670   auto added_attrs = prim_->evaluate_added_attrs();
671   auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
672   return eval_result;
673 }
674 
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args)675 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
676   auto ret_abstract = AbstractEval(args);
677   if (ret_abstract != nullptr) {
678     MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
679     return ret_abstract;
680   }
681   MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
682 
683   const auto eval_result = evaluator_cache_mgr_->GetValue(args);
684   if (eval_result != nullptr) {
685     return eval_result;
686   }
687 
688   auto py_args = PreparePyInputs(prim_py_, args);
689   prim_py_->BeginRecordAddAttr();
690   py::dict output = prim_py_->RunInfer(py_args);
691   prim_py_->EndRecordAddAttr();
692   auto added_attrs = prim_py_->evaluate_added_attrs();
693   MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
694   auto res_spec = PyInferRes2Abstract(prim_py_, output);
695   MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
696   auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
697   evaluator_cache_mgr_->SetValue(args, infer_result);
698   return infer_result;
699 }
700 
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args)701 EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
702   auto ret_abstract = AbstractEval(args);
703   if (ret_abstract != nullptr) {
704     MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
705     return ret_abstract;
706   }
707   // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
708   if (nargs_ != args.size()) {
709     MS_LOG(EXCEPTION) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
710   }
711   TypePtr ret_value_type = return_value_type_;
712   ValuePtrList value_list;
713   for (const auto &arg : args) {
714     // Check if all arguments are scalar type.
715     MS_EXCEPTION_IF_NULL(arg);
716     if (arg->isa<AbstractScalar>()) {
717       auto arg_scalar = dyn_cast<AbstractScalar>(arg);
718       auto arg_value = arg_scalar->GetValueTrack();
719       value_list.push_back(arg_value);
720     } else {
721       // Raise TypeError Expected Scalar.
722       MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives.";
723     }
724   }
725   for (const auto &item : type_map_) {
726     TypePtrList selections;
727     MS_EXCEPTION_IF_NULL(item.second);
728     (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
729                          [&args](size_t arg_idx) -> TypePtr {
730                            if (arg_idx >= args.size()) {
731                              MS_LOG(EXCEPTION) << "Index:" << arg_idx << " out of range:" << args.size();
732                            }
733                            MS_EXCEPTION_IF_NULL(args[arg_idx]);
734                            return args[arg_idx]->GetTypeTrack();
735                          });
736     TypePtr res = CheckTypeList(item.first, selections);
737     MS_EXCEPTION_IF_NULL(return_value_type_);
738     MS_EXCEPTION_IF_NULL(item.first);
739     if (*return_value_type_ == *(item.first)) {
740       ret_value_type = res;
741     }
742   }
743 
744   ValuePtr evaluated_value = RunImpl(value_list);
745   if (!(*evaluated_value == *kAnyValue)) {
746     ret_value_type = evaluated_value->type();
747   }
748   // for comparison primitives , return type shall have be specified to be bool.
749   if (specify_out_type_ != nullptr) {
750     ret_value_type = specify_out_type_;
751   }
752 
753   AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
754   return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
755 }
756 
RunImpl(const ValuePtrList & args) const757 ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
758   if (!eval_value_) {
759     return kAnyValue;
760   } else {
761     if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
762           MS_EXCEPTION_IF_NULL(arg);
763           return arg->isa<AnyValue>();
764         })) {
765       return kAnyValue;
766     }
767     return impl_(args);
768   }
769 }
770 
771 // Primitive implementation
772 // static function start
773 namespace {
InitStandardPrimEvaluator(PrimitivePtr primitive,const StandardPrimitiveImplReg eval_impl)774 EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveImplReg eval_impl) {
775   EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
776   return prim_evaluator;
777 }
778 
InitUniformPrimEvaluator(const PrimitivePtr & primitive,PrimitiveImpl prim_impl,bool eval_value,const TypePtr & specify_out_type)779 EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
780                                       const TypePtr &specify_out_type) {
781   FunctionPtr func = nullptr;
782   (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
783   MS_EXCEPTION_IF_NULL(func);
784 
785   EvaluatorPtr uniform_primitive_evaluator =
786     std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
787   return uniform_primitive_evaluator;
788 }
789 
PyObjToGraph(const AnalysisEnginePtr & engine,const ValuePtr & method)790 FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
791   MS_EXCEPTION_IF_NULL(engine);
792   MS_EXCEPTION_IF_NULL(method);
793   if (!method->isa<parse::PyObjectWrapper>()) {
794     MS_LOG(EXCEPTION) << "Method type error: " << method->ToString();
795   }
796 
797   std::shared_ptr<PyObjectWrapper> obj = method->cast<std::shared_ptr<PyObjectWrapper>>();
798   FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj());
799   if (func_graph == nullptr) {
800     MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed";
801   }
802 
803   FuncGraphManagerPtr manager = engine->func_graph_manager();
804   manager->AddFuncGraph(func_graph);
805   return func_graph;
806 }
807 
AddToManager(const AnalysisEnginePtr & engine,const FuncGraphPtr func_graph)808 inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) {
809   MS_EXCEPTION_IF_NULL(engine);
810   FuncGraphManagerPtr manager = engine->func_graph_manager();
811   manager->AddFuncGraph(func_graph);
812 }
813 
814 enum class REQUIRE_TYPE { ATTR, METHOD };
815 
StaticGetterInferred(const ValuePtr & value,const ConfigPtr & data_conf,const AnfNodeConfigPtr & old_conf,REQUIRE_TYPE require_type=REQUIRE_TYPE::METHOD)816 EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
817                                    REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
818   MS_EXCEPTION_IF_NULL(old_conf);
819 
820   AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
821   AbstractFunctionPtr abs_func = dyn_cast<abstract::AbstractFunction>(abs_ptr);
822   MS_EXCEPTION_IF_NULL(abs_func);
823 
824   // Create new cnode
825   std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
826   auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abs_func);
827   if (func_graph_func != nullptr) {
828     FuncGraphPtr fg = func_graph_func->func_graph();
829     input.push_back(NewValueNode(fg));
830   } else {
831     auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abs_func);
832     MS_EXCEPTION_IF_NULL(prim_func);
833     PrimitivePtr prim = prim_func->prim();
834     input.push_back(NewValueNode(prim));
835   }
836 
837   AnfNodeConfigPtr conf = dyn_cast<abstract::AnfNodeConfig>(data_conf);
838   MS_EXCEPTION_IF_NULL(conf);
839   input.push_back(conf->node());
840   MS_EXCEPTION_IF_NULL(old_conf);
841   FuncGraphPtr func_graph = old_conf->node()->func_graph();
842   MS_EXCEPTION_IF_NULL(func_graph);
843   CNodePtr new_cnode = func_graph->NewCNode(input);
844   if (require_type == REQUIRE_TYPE::ATTR) {
845     new_cnode = func_graph->NewCNode({new_cnode});
846   }
847   AnalysisEnginePtr eng = old_conf->engine();
848   AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context(), old_conf->func_graph());
849   return eng->ForwardConfig(old_conf, fn_conf);
850 }
851 
GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &,const AbstractBasePtrList & args_spec_list,const AnfNodeConfigPtr & out_conf)852 EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
853                                                   const AnfNodeConfigPtr &out_conf) {
854   // args_spec_list: same as StaticGetter
855   if (args_spec_list.size() < 2) {
856     MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
857   }
858   MS_EXCEPTION_IF_NULL(out_conf);
859   // An external type.
860   MS_EXCEPTION_IF_NULL(args_spec_list[0]);
861   MS_EXCEPTION_IF_NULL(args_spec_list[1]);
862   MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
863   MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
864   auto data_v = args_spec_list[0]->BuildValue();
865   MS_EXCEPTION_IF_NULL(data_v);
866   if (!data_v->isa<parse::NameSpace>()) {
867     MS_LOG(EXCEPTION) << "Not supported to get attribute for " << data_v->ToString()
868                       << "\nThe data should be a NameSpace.";
869   }
870 
871   auto item_value = args_spec_list[1]->BuildValue();
872   MS_EXCEPTION_IF_NULL(item_value);
873   if (item_value->isa<StringImm>()) {
874     item_value = std::make_shared<parse::Symbol>(item_value->cast<StringImmPtr>()->value());
875   }
876 
877   if (!item_value->isa<parse::Symbol>()) {
878     MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
879   }
880 
881   // item_name to func addr from obj_map
882   parse::SymbolPtr symbol = item_value->cast<parse::SymbolPtr>();
883   parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
884   MS_EXCEPTION_IF_NULL(out_conf);
885   auto out_node = out_conf->node();
886   FuncGraphPtr func_graph = out_node->func_graph();
887   MS_EXCEPTION_IF_NULL(func_graph);
888   auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
889   if (new_node == nullptr) {
890     MS_LOG(EXCEPTION) << "Resolve node failed";
891   }
892 
893   // Replace old node with the resolved new node in order list.
894   func_graph->ReplaceInOrder(out_node, new_node);
895 
896   AnalysisEnginePtr eng = out_conf->engine();
897   MS_EXCEPTION_IF_NULL(eng);
898   AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
899   return eng->ForwardConfig(out_conf, fn_conf);
900 }
901 
GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ValuePtr & item_value,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)902 EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
903                                                     const AbstractBasePtrList &args_spec_list,
904                                                     const ValuePtr &item_value, const ConfigPtr &data_conf,
905                                                     const AnfNodeConfigPtr &out_conf) {
906   if (args_spec_list.empty()) {
907     MS_LOG(EXCEPTION) << "args_spec_list is empty";
908   }
909   AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
910 
911   // If item_value is an attribute, get abstract value from AbstractClass
912   MS_EXCEPTION_IF_NULL(item_value);
913   if (!item_value->isa<StringImm>()) {
914     MS_LOG(EXCEPTION) << "Attribute type error";
915   }
916   std::string item_name = item_value->cast<StringImmPtr>()->value();
917   MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
918   MS_LOG(DEBUG) << "Resolve item: " << item_name;
919   MS_EXCEPTION_IF_NULL(cls);
920   AbstractBasePtr attr = cls->GetAttribute(item_name);
921   if (attr != nullptr) {
922     return std::make_shared<EvalResult>(attr, nullptr);
923   }
924 
925   ValuePtr method = cls->GetMethod(item_name);
926   if (method->isa<AnyValue>()) {
927     MS_EXCEPTION_IF_NULL(args_spec_list[0]);
928     MS_EXCEPTION_IF_NULL(args_spec_list[0]->BuildType());
929     MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
930                                  << ", item value: " << item_value->ToString();
931   }
932 
933   // Infer class method
934   ValuePtr converted_value = PyObjToGraph(engine, method);
935   return StaticGetterInferred(converted_value, data_conf, out_conf);
936 }
937 
GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr & engine,const ValuePtr & item_value,const TypePtr & data_type,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)938 EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
939                                                           const TypePtr &data_type, const ConfigPtr &data_conf,
940                                                           const AnfNodeConfigPtr &out_conf) {
941   MS_EXCEPTION_IF_NULL(item_value);
942   MS_EXCEPTION_IF_NULL(data_type);
943   // The method maybe a Primitive or Composite
944   if (!item_value->isa<StringImm>()) {
945     MS_LOG(EXCEPTION) << "Error item is not string";
946   }
947 
948   std::string item_name = item_value->cast<StringImmPtr>()->value();
949   REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
950   Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
951   if (require.empty()) {
952     require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
953     if (require.empty()) {
954       MS_LOG(EXCEPTION) << "Not supported to get attribute item name:\'" << item_name << "\' of a type["
955                         << data_type->ToString() << "]";
956     }
957     require_type = REQUIRE_TYPE::ATTR;
958   }
959 
960   ValuePtr converted_value = nullptr;
961   if (require.is<std::string>()) {
962     // composite registered in standard_method_map go to this branch
963     converted_value = prim::GetPythonOps(require.cast<std::string>());
964     MS_EXCEPTION_IF_NULL(converted_value);
965     if (!converted_value->isa<Primitive>()) {
966       AddToManager(engine, converted_value->cast<FuncGraphPtr>());
967     }
968   } else if (require.is<PrimitivePtr>()) {
969     converted_value = require.cast<PrimitivePtr>();
970   } else {
971     MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
972   }
973   return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
974 }
975 
976 enum ResolveType : int64_t {
977   kResolveTypeUserDefineClass = 1,
978   kResolveTypeBuiltInType,
979   kResolveTypeFunction,
980 };
981 
GetResolveType(const TypePtr & data_type)982 int64_t GetResolveType(const TypePtr &data_type) {
983   MS_EXCEPTION_IF_NULL(data_type);
984   if (data_type->type_id() == kObjectTypeClass) {
985     return kResolveTypeUserDefineClass;
986   }
987   // Try to search method map, if not found, the data_type should be External type.
988   if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
989     return kResolveTypeBuiltInType;
990   }
991   return kResolveTypeFunction;
992 }
993 
StaticGetter(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)994 EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
995                            const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
996   // Inputs: namespace and its static function; or class and its member function
997   CheckArgsSize("StaticGetter", args_spec_list, 2);
998 
999   MS_EXCEPTION_IF_NULL(args_spec_list[0]);
1000   MS_EXCEPTION_IF_NULL(args_spec_list[1]);
1001   TypePtr data_type = args_spec_list[0]->BuildType();
1002   ValuePtr item_value = args_spec_list[1]->BuildValue();
1003   ScopePtr scope = kDefaultScope;
1004   if (out_conf != nullptr) {
1005     scope = out_conf->node()->scope();
1006   }
1007   ScopeGuard scope_guard(scope);
1008   MS_EXCEPTION_IF_NULL(item_value);
1009   if (item_value->isa<AnyValue>()) {
1010     MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
1011   }
1012 
1013   int64_t resolve_type = GetResolveType(data_type);
1014   if (resolve_type == kResolveTypeUserDefineClass) {
1015     return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
1016   } else if (resolve_type == kResolveTypeBuiltInType) {
1017     return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
1018   } else {
1019     return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
1020   }
1021 }
1022 }  // end anonymous namespace
1023 
1024 namespace {
1025 class EmbedEvaluator : public SymbolicPrimEvaluator {
1026  public:
EmbedEvaluator()1027   EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
1028   ~EmbedEvaluator() override = default;
1029   MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
EvalPrim(const ConfigPtrList & args_conf_list)1030   EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
1031     // arg: free variable to be embedded
1032     if (args_conf_list.size() != 1) {
1033       MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
1034     }
1035     AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
1036     MS_EXCEPTION_IF_NULL(node_conf);
1037     MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
1038     AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract();
1039     x = SensitivityTransform(x);
1040     SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
1041     AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
1042     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
1043   }
1044 };
1045 
FindParameterNodeByString(const FuncGraphManagerPtr & manager,const std::string & name)1046 static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
1047   MS_EXCEPTION_IF_NULL(manager);
1048   auto root_g_set = manager->roots();
1049   if (root_g_set.size() != 1) {
1050     return nullptr;
1051   }
1052   const FuncGraphPtr &root_g = root_g_set.back();
1053   for (auto &param_node : root_g->parameters()) {
1054     auto param = param_node->cast<ParameterPtr>();
1055     if (param && name == param->name()) {
1056       return param;
1057     }
1058   }
1059   return nullptr;
1060 }
1061 
1062 class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
1063  public:
RefToEmbedEvaluator()1064   RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
1065   ~RefToEmbedEvaluator() override = default;
1066   MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
EvalPrim(const ConfigPtrList & args_conf_list)1067   EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
1068     if (args_conf_list.size() != 1) {
1069       MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
1070       return nullptr;
1071     }
1072     static TypePtr type = std::make_shared<SymbolicKeyType>();
1073     auto node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
1074     if (node_conf == nullptr) {
1075       MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
1076       return nullptr;
1077     }
1078     MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
1079     AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
1080     MS_EXCEPTION_IF_NULL(abs);
1081     AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
1082     if (ref_abs == nullptr) {
1083       MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
1084       return nullptr;
1085     }
1086     auto key_abs = ref_abs->ref_key();
1087     if (key_abs == nullptr) {
1088       MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr.";
1089       return nullptr;
1090     }
1091     auto key_value = key_abs->BuildValue();
1092     if (key_value == nullptr) {
1093       MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
1094       return nullptr;
1095     }
1096     auto refkey = key_value->cast<RefKeyPtr>();
1097     if (refkey == nullptr) {
1098       auto ret = std::make_shared<AbstractScalar>(type);
1099       auto ref_value = ref_abs->ref();
1100       MS_EXCEPTION_IF_NULL(ref_value);
1101       return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
1102     }
1103 
1104     std::string name = refkey->tag();
1105     MS_EXCEPTION_IF_NULL(node_conf->node());
1106     if (node_conf->node()->func_graph() == nullptr) {
1107       MS_LOG(EXCEPTION) << "Should not evaluate a ValueNode, node: " << node_conf->node()->DebugString();
1108     }
1109     const auto &manager = node_conf->node()->func_graph()->manager();
1110     auto node = FindParameterNodeByString(manager, name);
1111     if (node == nullptr) {
1112       MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
1113       return nullptr;
1114     }
1115     AbstractBasePtr x = ref_abs->ref();
1116     x = SensitivityTransform(x);
1117     std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
1118     std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
1119     return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
1120   }
1121 };
1122 
1123 class GetAttrEvaluator : public TransitionPrimEvaluator {
1124  public:
GetAttrEvaluator()1125   GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
1126   ~GetAttrEvaluator() override = default;
1127   MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr & in_conf0,const AnfNodeConfigPtr & out_conf)1128   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
1129                          const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
1130     constexpr auto kGetAttrArgSize = 2;
1131     auto ret_abstract = AbstractEval(args_spec_list);
1132     if (ret_abstract != nullptr) {
1133       MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
1134       return ret_abstract;
1135     }
1136     // Inputs: data, item
1137     if (args_spec_list.size() != kGetAttrArgSize) {
1138       MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
1139     }
1140     EvalResultPtr ret = nullptr;
1141     if (bound_node() != nullptr) {
1142       TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
1143       ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
1144     } else {
1145       ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
1146     }
1147     // don't lookup from cache, as different out_conf with same node but different context
1148     // may add different entry to anfnode_config_map, like getattr primitive;
1149     evaluator_cache_mgr_->SetValue(args_spec_list, ret);
1150     return ret;
1151   }
1152 };
1153 
1154 class ResolveEvaluator : public TransitionPrimEvaluator {
1155  public:
ResolveEvaluator()1156   ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
1157   ~ResolveEvaluator() override = default;
1158   MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr & in_conf0,const AnfNodeConfigPtr & out_conf)1159   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
1160                          const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
1161     constexpr auto kResolveArgSize = 2;
1162     // Inputs: namespace, symbol
1163     if (args_spec_list.size() != kResolveArgSize) {
1164       MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
1165     }
1166     EvalResultPtr ret = nullptr;
1167     if (bound_node() != nullptr) {
1168       TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
1169       ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
1170     } else {
1171       ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
1172     }
1173     return ret;
1174   }
1175 };
1176 
1177 class CreateInstanceEvaluator : public TransitionPrimEvaluator {
1178  public:
CreateInstanceEvaluator()1179   CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
1180   ~CreateInstanceEvaluator() override = default;
1181   MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)1182   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
1183                          const AnfNodeConfigPtr &out_conf) override {
1184     if (args_spec_list.empty()) {
1185       MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
1186     }
1187 
1188     // Get the type parameter.
1189     MS_EXCEPTION_IF_NULL(args_spec_list[0]);
1190     TypePtr type = args_spec_list[0]->GetTypeTrack();
1191     MS_EXCEPTION_IF_NULL(type);
1192     if (type->type_id() != kMetaTypeTypeType) {
1193       MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
1194                         << type->ToString();
1195     }
1196 
1197     ValuePtr value_track = args_spec_list[0]->GetValueTrack();
1198     MS_EXCEPTION_IF_NULL(value_track);
1199 
1200     std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
1201     if (type_obj == nullptr) {
1202       MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
1203     }
1204 
1205     if (!type_obj->isa<parse::ClassType>()) {
1206       MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got "
1207                         << type_obj->ToString() << ".";
1208     }
1209 
1210     auto class_type = type_obj->obj();
1211     MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
1212 
1213     // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
1214     py::tuple params = GetParameters(args_spec_list);
1215 
1216     // Create class instance.
1217     auto obj = parse::data_converter::CreatePythonObject(class_type, params);
1218     if (py::isinstance<py::none>(obj)) {
1219       MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
1220                         << "` failed, only support to create \'Cell\' or \'Primitive\' object.";
1221     }
1222 
1223     // Process the object.
1224     ValuePtr converted_ret = nullptr;
1225     bool converted = parse::ConvertData(obj, &converted_ret, true);
1226     if (!converted) {
1227       MS_LOG(EXCEPTION) << "Convert the python object failed";
1228     }
1229     MS_EXCEPTION_IF_NULL(converted_ret);
1230 
1231     if (converted_ret->isa<FuncGraph>()) {
1232       AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
1233     }
1234 
1235     AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
1236     auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
1237     evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
1238     return infer_result;
1239   }
1240 
GetParameters(const AbstractBasePtrList & args_spec_list) const1241   py::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
1242     // Exclude class type by minus 1;
1243     std::size_t params_size = args_spec_list.size() - 1;
1244     auto params = py::tuple(params_size);
1245     if (params_size > params.size()) {
1246       MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size << ", params.size():" << params.size();
1247     }
1248     if (params_size > 0) {
1249       for (size_t i = 0; i < params_size; i++) {
1250         // Only support the Scalar parameters type. Bypass class type by offset with 1.
1251         auto arg = args_spec_list[i + 1];
1252         MS_EXCEPTION_IF_NULL(arg);
1253         // Because the Tensor's AbstractTensor can't get value from GetValueTrack.
1254         ValuePtr param_value = arg->BuildValue();
1255         py::object param = ValueToPyData(param_value);
1256         params[i] = param;
1257       }
1258     }
1259     return params;
1260   }
1261 };
1262 
1263 class PyInterpretEvaluator : public TransitionPrimEvaluator {
1264  public:
PyInterpretEvaluator()1265   PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
1266   ~PyInterpretEvaluator() override = default;
1267   MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)1268   EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
1269                          const AnfNodeConfigPtr &out_conf) override {
1270     if (args_spec_list.empty()) {
1271       MS_LOG(ERROR) << "'args_spec_list' should not be empty";
1272     }
1273 
1274     // Get the type parameter.
1275     MS_EXCEPTION_IF_NULL(args_spec_list[0]);
1276     ValuePtr value_track = args_spec_list[0]->GetValueTrack();
1277     MS_EXCEPTION_IF_NULL(value_track);
1278 
1279     std::shared_ptr<parse::Script> script_obj = dyn_cast<parse::Script>(value_track);
1280     if (script_obj == nullptr) {
1281       MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
1282     }
1283 
1284     // Make global and local parameters.
1285     py::tuple params = MakeParameters(args_spec_list);
1286 
1287     // Call python script string.
1288     MS_LOG(DEBUG) << "Call script: " << script_obj->script() << ", params: " << py::str(params);
1289     auto obj = parse::data_converter::CallPythonScript(py::str(script_obj->script()), params);
1290     if (py::isinstance<py::none>(obj)) {
1291       MS_LOG(EXCEPTION) << "Failed to call python script: `" << script_obj->script() << "`";
1292     }
1293 
1294     ValuePtr converted_val = nullptr;
1295     bool converted = parse::ConvertData(obj, &converted_val, true);
1296     if (!converted) {
1297       MS_LOG(EXCEPTION) << "Convert the python object failed";
1298     }
1299     MS_EXCEPTION_IF_NULL(converted_val);
1300 
1301     AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
1302     auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
1303     evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
1304     return infer_result;
1305   }
1306 
MakeParameters(const AbstractBasePtrList & args_spec_list) const1307   py::tuple MakeParameters(const AbstractBasePtrList &args_spec_list) const {
1308     constexpr int params_size = 3;
1309     if (params_size != args_spec_list.size()) {
1310       MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size
1311                         << ", not equal to arguments.size:" << args_spec_list.size();
1312     }
1313     // The first argument is script string, ignore it.
1314     auto params = py::tuple(params_size - 1);
1315 
1316     // Make the global parameters.
1317     auto global_dict = dyn_cast<AbstractDictionary>(args_spec_list[1]);  // Global parameters dict.
1318     MS_EXCEPTION_IF_NULL(global_dict);
1319     MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString() << ", [" << global_dict->type_name() << "]";
1320     ValuePtr global_dict_value = global_dict->BuildValue();
1321     py::object global_params_dict = ValueToPyData(global_dict_value);
1322     MS_LOG(DEBUG) << "arg_1, python global_params_dict: " << py::str(global_params_dict);
1323     params[0] = global_params_dict;
1324 
1325     // Make the local parameters.
1326     auto local_dict = dyn_cast<AbstractDictionary>(args_spec_list[2]);  // Local parameters dict.
1327     MS_EXCEPTION_IF_NULL(local_dict);
1328     MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString() << ", [" << local_dict->type_name() << "]";
1329     ValuePtr local_dict_value = local_dict->BuildValue();
1330     py::object local_params_dict = ValueToPyData(local_dict_value);
1331     MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << py::str(local_params_dict);
1332     params[1] = local_params_dict;
1333     return params;
1334   }
1335 };
1336 
1337 class PartialEvaluator : public Evaluator {
1338  public:
PartialEvaluator()1339   PartialEvaluator() : Evaluator("PartialEvaluator") {}
1340   ~PartialEvaluator() override = default;
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)1341   EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
1342                     const AnfNodeConfigPtr &out_conf) override {
1343     if (args_conf_list.size() == 0) {
1344       MS_LOG(EXCEPTION) << "Args size should be greater than 0";
1345     }
1346 
1347     MS_EXCEPTION_IF_NULL(out_conf);
1348     MS_EXCEPTION_IF_NULL(out_conf->node());
1349     MS_EXCEPTION_IF_NULL(args_conf_list[0]);
1350     MS_EXCEPTION_IF_NULL(args_conf_list[0]->ObtainEvalResult());
1351     auto arg0_value = args_conf_list[0]->ObtainEvalResult()->abstract();
1352     MS_EXCEPTION_IF_NULL(arg0_value);
1353     AbstractBasePtrList args_spec_list{arg0_value};
1354     // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
1355     if (arg0_value->isa<AbstractError>()) {
1356       MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack());
1357       auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
1358       MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
1359                     << " as func is: " << arg0_value->ToString();
1360       auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
1361       evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
1362       return eval_result;
1363     }
1364     auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
1365     // Sometimes, node[0] in out_conf becomes phi0;
1366     if (func->isa<PrimitiveAbstractClosure>()) {
1367       auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
1368       MS_EXCEPTION_IF_NULL(prim_func->prim());
1369       if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
1370         prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(prim_func->prim());
1371         return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
1372       }
1373     }
1374 
1375     (void)std::transform(
1376       args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
1377       [](const ConfigPtr &config) -> AbstractBasePtr { return config->ObtainEvalResult()->abstract(); });
1378     AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
1379 
1380     auto cnode = out_conf->node()->cast<CNodePtr>();
1381     MS_EXCEPTION_IF_NULL(cnode);
1382     if (cnode->size() != (args_conf_list.size() + 1)) {
1383       MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString()
1384                         << ", args_conf_list: " << mindspore::ToString(args_conf_list);
1385     }
1386     AbstractFuncAtomPtrList partial_funcs_list;
1387     auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
1388       auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
1389       partial_funcs_list.push_back(new_func);
1390     };
1391     func->Visit(build_partial);
1392 
1393     auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
1394     auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
1395     evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
1396     return eval_result;
1397   }
1398 
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)1399   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
1400     MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
1401   }
1402 
HandleDoSignature(const AnalysisEnginePtr & engine,const ValuePtr & signature_value,const AnfNodeConfigPtr & out_conf) const1403   EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
1404                                   const AnfNodeConfigPtr &out_conf) const {
1405     MS_EXCEPTION_IF_NULL(engine);
1406     MS_EXCEPTION_IF_NULL(out_conf);
1407     MS_EXCEPTION_IF_NULL(out_conf->node());
1408     auto cnode = out_conf->node()->cast<CNodePtr>();
1409     if (cnode == nullptr) {
1410       MS_LOG(EXCEPTION) << "Cnode is nullptr";
1411     }
1412     std::vector<AnfNodePtr> new_nodes_inputs = cnode->inputs();
1413     auto new_signature_value = std::make_shared<prim::DoSignatureMetaFuncGraph>("signature", signature_value);
1414     new_nodes_inputs[1] = NewValueNode(new_signature_value);
1415     FuncGraphPtr func_graph = cnode->func_graph();
1416 
1417     ScopePtr scope = out_conf->node()->scope();
1418     ScopeGuard scope_guard(scope);
1419     MS_EXCEPTION_IF_NULL(func_graph);
1420     CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
1421     AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
1422     return engine->ForwardConfig(out_conf, fn_conf);
1423   }
1424 };
1425 
1426 struct PrimitiveImplInferValue {
1427   PrimitiveImpl impl_;        // implement function of primitive
1428   bool eval_value_;           // whether evaluate value
1429   TypePtr specify_out_type_;  // whether specify return type
1430   bool in_white_list_;        // true if this Primitive in white list, else false.
1431 };
1432 
1433 using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
GetUniformPrimitiveToImplMap()1434 PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
1435   static PrimitiveToImplMap uniform_prim_implement_map = {
1436     {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
1437     {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
1438     {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
1439     {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
1440     {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
1441     {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}},
1442     {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}},
1443     {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
1444     {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
1445     {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
1446     {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}},
1447     {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}},
1448     {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}},
1449     {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}},
1450     {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}},
1451     {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}},
1452     {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}},
1453     {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
1454     {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
1455     {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
1456   };
1457   return uniform_prim_implement_map;
1458 }
1459 
1460 PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
1461 std::mutex PrimEvaluatorConstructorMutex;
1462 
InitPrimEvaluatorConstructors()1463 void InitPrimEvaluatorConstructors() {
1464   PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
1465 
1466   for (const auto &iter : GetPrimitiveToEvalImplMap()) {
1467     constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second);
1468   }
1469 
1470   for (const auto &iter : GetUniformPrimitiveToImplMap()) {
1471     constructor[iter.first] =
1472       InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
1473   }
1474   constructor[prim::kPrimEmbed] = std::make_shared<EmbedEvaluator>();
1475   constructor[prim::kPrimRefToEmbed] = std::make_shared<RefToEmbedEvaluator>();
1476   constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
1477   constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
1478   constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
1479   constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
1480   constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
1481 }
1482 }  // namespace
1483 
ClearPrimEvaluatorMap()1484 void ClearPrimEvaluatorMap() {
1485   PrimEvaluatorConstructors.clear();
1486   GetPrimitiveToEvalImplMap().clear();
1487   GetUniformPrimitiveToImplMap().clear();
1488 }
1489 
IsInWhiteList(const PrimitivePtr & primitive)1490 bool IsInWhiteList(const PrimitivePtr &primitive) {
1491   MS_EXCEPTION_IF_NULL(primitive);
1492 
1493   auto iter = GetPrimitiveToEvalImplMap().find(primitive);
1494   if (iter != GetPrimitiveToEvalImplMap().end()) {
1495     return iter->second.in_white_list_;
1496   }
1497 
1498   auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
1499   if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
1500     return uni_iter->second.in_white_list_;
1501   }
1502 
1503   return false;
1504 }
1505 
GetPrimEvaluatorConstructors()1506 PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
1507   PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
1508   if (!constructor.empty()) {
1509     return constructor;
1510   }
1511   std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
1512   if (constructor.empty()) {
1513     InitPrimEvaluatorConstructors();
1514   }
1515 
1516   return constructor;
1517 }
1518 
1519 namespace {
IsSubtypeTuple(const AbstractBasePtr x,const TypePtr model)1520 bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
1521   MS_EXCEPTION_IF_NULL(x);
1522   MS_EXCEPTION_IF_NULL(model);
1523   auto x_tuple = dyn_cast<AbstractTuple>(x);
1524   auto model_tuple = dyn_cast<Tuple>(model);
1525 
1526   if (x_tuple == nullptr || model_tuple == nullptr) {
1527     return false;
1528   }
1529 
1530   if (model->IsGeneric()) {
1531     return true;
1532   }
1533 
1534   if (x_tuple->size() != model_tuple->size()) {
1535     return false;
1536   }
1537 
1538   for (size_t i = 0; i < x_tuple->size(); i++) {
1539     bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
1540     if (!is_subtype) {
1541       return false;
1542     }
1543   }
1544   return true;
1545 }
1546 
IsSubtypeArray(const AbstractBasePtr x,const TypePtr model)1547 bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
1548   MS_EXCEPTION_IF_NULL(x);
1549   MS_EXCEPTION_IF_NULL(model);
1550   auto x_tensor = dyn_cast<AbstractTensor>(x);
1551   auto model_tensor = dyn_cast<TensorType>(model);
1552 
1553   if (x_tensor == nullptr || model_tensor == nullptr) {
1554     return false;
1555   }
1556 
1557   if (model->IsGeneric()) {
1558     return true;
1559   }
1560 
1561   return IsSubtype(x_tensor->element(), model_tensor->element());
1562 }
1563 
IsSubtypeList(const AbstractBasePtr x,const TypePtr model)1564 bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
1565   MS_EXCEPTION_IF_NULL(x);
1566   MS_EXCEPTION_IF_NULL(model);
1567   auto x_list = dyn_cast<AbstractList>(x);
1568   auto model_list = dyn_cast<List>(model);
1569 
1570   if (x_list == nullptr || model_list == nullptr) {
1571     return false;
1572   }
1573 
1574   if (model->IsGeneric()) {
1575     return true;
1576   }
1577 
1578   if (x_list->size() != model_list->size()) {
1579     return false;
1580   }
1581 
1582   bool is_subtype = true;
1583   for (size_t i = 0; i < x_list->size(); i++) {
1584     is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
1585     if (!is_subtype) {
1586       return false;
1587     }
1588   }
1589   return is_subtype;
1590 }
1591 
IsSubtypeClass(const AbstractBasePtr x,const TypePtr model)1592 bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
1593   MS_EXCEPTION_IF_NULL(x);
1594   MS_EXCEPTION_IF_NULL(model);
1595   auto x_class = dyn_cast<AbstractClass>(x);
1596   auto model_class = dyn_cast<Class>(model);
1597   if (x_class == nullptr) {
1598     return false;
1599   }
1600   if (model->IsGeneric()) {
1601     return true;
1602   }
1603   MS_EXCEPTION_IF_NULL(model_class);
1604   if (x_class->tag() == model_class->tag()) {
1605     auto m_attributes = model_class->GetAttributes();
1606     auto x_attributes = x_class->attributes();
1607     if (m_attributes.size() != x_attributes.size()) {
1608       return false;
1609     }
1610 
1611     for (size_t i = 0; i < m_attributes.size(); i++) {
1612       if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) {
1613         return false;
1614       }
1615     }
1616     return true;
1617   }
1618 
1619   return false;
1620 }
1621 
IsSubtypeScalar(const AbstractBasePtr x,const TypePtr model)1622 inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
1623   MS_EXCEPTION_IF_NULL(x);
1624   MS_EXCEPTION_IF_NULL(model);
1625   if (dyn_cast<AbstractScalar>(x) == nullptr) {
1626     return false;
1627   }
1628   TypePtr x_type = x->GetTypeTrack();
1629   return IsSubType(x_type, model);
1630 }
1631 }  // namespace
1632 
IsSubtype(const AbstractBasePtr x,const TypePtr model)1633 bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
1634   MS_EXCEPTION_IF_NULL(x);
1635   MS_EXCEPTION_IF_NULL(model);
1636   TypeId model_typeid = model->type_id();
1637   switch (model_typeid) {
1638     case kMetaTypeObject:
1639       return true;
1640     case kObjectTypeTuple:
1641       return IsSubtypeTuple(x, model);
1642     case kObjectTypeTensorType:
1643       return IsSubtypeArray(x, model);
1644     case kObjectTypeList:
1645       return IsSubtypeList(x, model);
1646     case kObjectTypeClass:
1647       return IsSubtypeClass(x, model);
1648     default:
1649       if (IsSubType(model, std::make_shared<Number>())) {
1650         return IsSubtypeScalar(x, model);
1651       }
1652       MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << ".";
1653   }
1654 }
1655 }  // namespace abstract
1656 }  // namespace mindspore
1657