1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/static_analysis/prim.h"
20
21 #include <algorithm>
22 #include <limits>
23 #include <mutex>
24 #include <string>
25 #include <utility>
26 #include <unordered_set>
27
28 #include "frontend/operator/cc_implementations.h"
29 #include "frontend/operator/ops.h"
30 #include "frontend/operator/composite/do_signature.h"
31 #include "frontend/operator/prim_to_function.h"
32 #include "abstract/utils.h"
33 #include "utils/symbolic.h"
34 #include "pipeline/jit/resource.h"
35 #include "pipeline/jit/parse/resolve.h"
36 #include "utils/convert_utils.h"
37 #include "utils/convert_utils_py.h"
38 #include "utils/ms_context.h"
39 #include "pipeline/jit/parse/data_converter.h"
40 #include "abstract/primitive_infer_map.h"
41 #include "abstract/param_validator.h"
42 #include "utils/ms_utils.h"
43 #include "utils/shape_utils.h"
44 #include "utils/parallel_node_check.h"
45
46 namespace mindspore {
47 namespace abstract {
48 using mindspore::parse::PyObjectWrapper;
49
// Names of the primitives for which the evaluators in this file do not run the AbstractEval()
// pre-check and instead go straight to their regular inference.
std::unordered_set<std::string> prims_to_skip_undetermined_infer = {
  "MakeTuple",
  "make_list",
  "Switch",
  "env_setitem",
  "env_getitem",
  "Load",
  "UpdateState",
};
52
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)53 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
54 const AnfNodeConfigPtr &out_conf) {
55 MS_EXCEPTION_IF_NULL(engine);
56 MS_EXCEPTION_IF_NULL(out_conf);
57 AbstractBasePtrList args_spec_list;
58 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
59 [](const ConfigPtr &ref) -> AbstractBasePtr {
60 MS_EXCEPTION_IF_NULL(ref);
61 MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
62 return ref->ObtainEvalResult()->abstract();
63 });
64 auto do_signature = prim_->cast<prim::DoSignaturePrimitivePtr>();
65 MS_EXCEPTION_IF_NULL(do_signature);
66 auto &func = do_signature->function();
67 if (func->isa<Primitive>()) {
68 auto sig_prim = func->cast<PrimitivePtr>();
69 if (prims_to_skip_undetermined_infer.find(sig_prim->name()) == prims_to_skip_undetermined_infer.end()) {
70 auto ret_abstract = AbstractEval(args_spec_list);
71 if (ret_abstract != nullptr) {
72 MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined";
73 return ret_abstract;
74 }
75 }
76 }
77 if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
78 MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
79 }
80
81 auto out_node = dyn_cast<CNode>(out_conf->node());
82 MS_EXCEPTION_IF_NULL(out_node);
83 const auto &out_node_inputs = out_node->inputs();
84 if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
85 MS_LOG(EXCEPTION) << "Op: " << func->ToString() << " args size should equal to inputs size minus 1, but args size "
86 << args_conf_list.size() << ", inputs size " << out_node_inputs.size();
87 }
88 AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
89
90 ScopePtr scope = kDefaultScope;
91 if (out_conf != nullptr) {
92 scope = out_conf->node()->scope();
93 }
94 ScopeGuard scope_guard(scope);
95
96 AnfNodePtr new_node = nullptr;
97 if (bound_node() != nullptr) {
98 TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
99 new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
100 } else {
101 new_node = prim::GenerateCNode(out_node->func_graph(), prim_->ToString(), func, args_spec_list, args_inputs);
102 }
103 AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
104
105 if (out_node->isa<CNode>()) {
106 auto out_cnode = out_node->cast<CNodePtr>();
107 auto new_cnode = new_node->cast<CNodePtr>();
108 new_cnode->CloneCNodeInfo(out_cnode);
109 }
110
111 return engine->ForwardConfig(out_conf, fn_conf);
112 }
113
GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list,bool need_unpack)114 static AbstractBasePtrList GetUnpackGraphSpecArgsList(AbstractBasePtrList args_spec_list, bool need_unpack) {
115 // arg[0] is the func graph to unpack, ignore it
116 AbstractBasePtrList specialize_args_before_unpack(args_spec_list.begin() + 1, args_spec_list.end());
117 AbstractBasePtrList graph_specialize_args;
118 if (need_unpack) {
119 for (size_t index = 0; index < specialize_args_before_unpack.size(); index++) {
120 MS_EXCEPTION_IF_NULL(specialize_args_before_unpack[index]);
121 if (specialize_args_before_unpack[index]->isa<AbstractTuple>()) {
122 auto arg_tuple = specialize_args_before_unpack[index]->cast<AbstractTuplePtr>();
123 std::transform(arg_tuple->elements().begin(), arg_tuple->elements().end(),
124 std::back_inserter(graph_specialize_args), [](AbstractBasePtr abs) { return abs; });
125 } else if (specialize_args_before_unpack[index]->isa<AbstractDictionary>()) {
126 auto arg_dict = specialize_args_before_unpack[index]->cast<AbstractDictionaryPtr>();
127 auto dict_elems = arg_dict->elements();
128 (void)std::transform(
129 dict_elems.begin(), dict_elems.end(), std::back_inserter(graph_specialize_args),
130 [](const AbstractAttribute &item) { return std::make_shared<AbstractKeywordArg>(item.first, item.second); });
131 } else {
132 MS_LOG(EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got "
133 << specialize_args_before_unpack[index]->ToString();
134 }
135 }
136 } else {
137 graph_specialize_args = specialize_args_before_unpack;
138 }
139 return graph_specialize_args;
140 }
141
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)142 EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
143 const AnfNodeConfigPtr &out_conf) {
144 MS_EXCEPTION_IF_NULL(engine);
145 MS_EXCEPTION_IF_NULL(out_conf);
146 MS_EXCEPTION_IF_NULL(out_conf->node());
147 if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
148 MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
149 }
150
151 auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
152 MS_EXCEPTION_IF_NULL(unpack_graph);
153 auto out_node = out_conf->node()->cast<CNodePtr>();
154 MS_EXCEPTION_IF_NULL(out_node);
155 const auto &out_node_inputs = out_node->inputs();
156 if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
157 MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
158 << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
159 << ", inputs size " << out_node_inputs.size();
160 }
161 AbstractBasePtrList args_spec_list;
162 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
163 [](const ConfigPtr &ref) -> AbstractBasePtr {
164 MS_EXCEPTION_IF_NULL(ref);
165 MS_EXCEPTION_IF_NULL(ref->ObtainEvalResult());
166 return ref->ObtainEvalResult()->abstract();
167 });
168 // get the forward graph
169 if (args_spec_list.empty()) {
170 MS_LOG(EXCEPTION) << "args_spec_list can't be empty.";
171 }
172 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
173 auto fn = args_spec_list[0]->cast<AbstractFunctionPtr>();
174 if (fn == nullptr) {
175 MS_LOG(EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but " << args_spec_list[0]->ToString();
176 }
177 auto real_fn = fn->cast<FuncGraphAbstractClosurePtr>();
178 MS_EXCEPTION_IF_NULL(real_fn);
179 FuncGraphPtr forward_graph = real_fn->func_graph();
180 MS_EXCEPTION_IF_NULL(forward_graph);
181 AbstractBasePtrList graph_specialize_args =
182 GetUnpackGraphSpecArgsList(args_spec_list, unpack_graph->need_unpack_args());
183 AbstractBasePtrList graph_specialize_args_without_sens;
184 if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
185 MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
186 }
187 (void)std::transform(graph_specialize_args.begin(),
188 graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
189 std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
190 auto new_graph = forward_graph->GenerateGraph(graph_specialize_args_without_sens);
191 engine->func_graph_manager()->AddFuncGraph(new_graph);
192 ScopePtr scope = kDefaultScope;
193 if (out_conf != nullptr) {
194 scope = out_conf->node()->scope();
195 }
196 ScopeGuard scope_guard(scope);
197 AnfNodePtr new_vnode = NewValueNode(new_graph);
198 AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_vnode, out_conf->context(), out_conf->func_graph());
199
200 return engine->ForwardConfig(out_conf, fn_conf);
201 }
202
// Builds the node that applies the mixed-precision cast to `source_node`.
//
// source_node  Node producing the value to cast.
// node_type    Abstract of `source_node`; it selects how the value is traversed.
// target_type  Node passed to the python `cast` op as its target type.
// func_graph   Graph in which the new nodes are created.
//
// Handling by abstract kind:
//   - tensor with a Float element type: a `cast(source_node, target_type)` node inserted after
//     `source_node`;
//   - tuple: every element is fetched with TupleGetItem, processed recursively and repacked with
//     MakeTuple;
//   - dictionary: every value is fetched with DictGetItem, processed recursively and repacked with
//     MakeDict under the original keys;
//   - keyword argument: the value is extracted, processed recursively and wrapped in MakeKeywordArg;
//   - anything else (including tensors with other element types): `source_node` is returned as is.
AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
                                    const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node_type);
  MS_EXCEPTION_IF_NULL(func_graph);

  if (node_type->isa<AbstractTensor>()) {
    auto tensor_abs = node_type->cast<AbstractTensorPtr>();
    if (!tensor_abs->element()->BuildType()->isa<Float>()) {
      return source_node;
    }
    auto cast_op = prim::GetPythonOps("cast", "mindspore.ops.functional");
    MS_EXCEPTION_IF_NULL(cast_op);
    return func_graph->NewCNodeAfter(source_node, {NewValueNode(cast_op), source_node, target_type});
  }

  if (node_type->isa<AbstractTuple>()) {
    auto tuple_abs = node_type->cast<AbstractTuplePtr>();
    auto &elements = tuple_abs->elements();
    std::vector<AnfNodePtr> make_tuple_inputs;
    make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    for (size_t pos = 0; pos < elements.size(); ++pos) {
      const int64_t idx = static_cast<int64_t>(pos);
      AnfNodePtr get_item =
        func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(idx)});
      make_tuple_inputs.emplace_back(MixedPrecisionCastHelper(get_item, elements[pos], target_type, func_graph));
    }
    return func_graph->NewCNode(make_tuple_inputs);
  }

  if (node_type->isa<AbstractDictionary>()) {
    auto dict_abs = node_type->cast<AbstractDictionaryPtr>();
    auto &entries = dict_abs->elements();
    std::vector<AnfNodePtr> key_inputs;
    std::vector<AnfNodePtr> value_inputs;
    key_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    value_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    for (const auto &entry : entries) {
      AnfNodePtr get_item =
        func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(entry.first)});
      AnfNodePtr casted_value = MixedPrecisionCastHelper(get_item, entry.second, target_type, func_graph);
      key_inputs.emplace_back(NewValueNode(entry.first));
      value_inputs.emplace_back(casted_value);
    }
    return func_graph->NewCNode(
      {NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(key_inputs), func_graph->NewCNode(value_inputs)});
  }

  if (node_type->isa<AbstractKeywordArg>()) {
    auto kwarg_abs = node_type->cast<AbstractKeywordArgPtr>();
    std::string key = kwarg_abs->get_key();
    AnfNodePtr extracted =
      func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(key), source_node});
    AnfNodePtr casted_value = MixedPrecisionCastHelper(extracted, kwarg_abs->get_arg(), target_type, func_graph);
    return func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(key), casted_value});
  }

  return source_node;
}
255
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)256 EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
257 const AnfNodeConfigPtr &out_conf) {
258 MS_EXCEPTION_IF_NULL(engine);
259 AbstractBasePtrList args_spec_list;
260 MS_EXCEPTION_IF_NULL(out_conf);
261 if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
262 MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
263 }
264 auto out_node = out_conf->node()->cast<CNodePtr>();
265 MS_EXCEPTION_IF_NULL(out_node);
266 const auto &out_node_inputs = out_node->inputs();
267 if (out_node->inputs().empty() || (out_node_inputs.size() - 1) != args_conf_list.size()) {
268 MS_LOG(EXCEPTION) << "MixedPrecisionCast"
269 << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
270 << ", inputs size " << out_node_inputs.size();
271 }
272 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
273 [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->ObtainEvalResult()->abstract(); });
274
275 ScopePtr scope = kDefaultScope;
276 scope = out_conf->node()->scope();
277 ScopeGuard scope_guard(scope);
278
279 FuncGraphPtr func_graph = out_node->func_graph();
280 constexpr size_t source_node_index = 2;
281 if (out_node_inputs.size() <= source_node_index) {
282 MS_LOG(EXCEPTION) << "Input size:" << out_node_inputs.size() << " should bigger than 2.";
283 }
284
285 AnfNodePtr new_node =
286 MixedPrecisionCastHelper(out_node_inputs[source_node_index], args_spec_list[1], out_node_inputs[1], func_graph);
287 AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
288
289 if (new_node->isa<CNode>()) {
290 auto new_cnode = new_node->cast<CNodePtr>();
291 new_cnode->CloneCNodeInfo(out_node);
292 }
293 return engine->ForwardConfig(out_conf, fn_conf);
294 }
295
296 namespace {
BuildValue(const ValuePtr & value_ptr)297 py::object BuildValue(const ValuePtr &value_ptr) {
298 if (value_ptr == nullptr) {
299 return py::none();
300 } else {
301 return ValueToPyData(value_ptr);
302 }
303 }
304
AbstractTupleToPython(const AbstractBasePtr & abs_base)305 py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base) {
306 auto arg_tuple = dyn_cast<AbstractTuple>(abs_base);
307 MS_EXCEPTION_IF_NULL(arg_tuple);
308 size_t len = arg_tuple->size();
309 py::tuple shape_tuple(len);
310 py::tuple dtype_tuple(len);
311 py::tuple value_tuple(len);
312 py::tuple min_value_tuple(len);
313 py::tuple max_value_tuple(len);
314 py::tuple min_shape_tuple(len);
315 py::tuple max_shape_tuple(len);
316 bool dyn_shape = false;
317 bool dyn_value = false;
318
319 for (size_t i = 0; i < len; i++) {
320 auto arg = arg_tuple->elements()[i];
321 py::dict out = ConvertAbstractToPython(arg);
322 shape_tuple[i] = out[ATTR_SHAPE];
323 dtype_tuple[i] = out[ATTR_DTYPE];
324 value_tuple[i] = out[ATTR_VALUE];
325
326 // Elements in tuple is tensor shape value.
327 if (out.contains(py::str(ATTR_MIN_VALUE)) && out.contains(py::str(ATTR_MAX_VALUE))) {
328 min_value_tuple[i] = out[ATTR_MIN_VALUE];
329 max_value_tuple[i] = out[ATTR_MAX_VALUE];
330 dyn_value = true;
331 }
332
333 // Elements in tuple is tensor, which shape is dynamic.
334 if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
335 min_shape_tuple[i] = out[ATTR_MIN_SHAPE];
336 max_shape_tuple[i] = out[ATTR_MAX_SHAPE];
337 dyn_shape = true;
338 }
339 }
340 auto dic = py::dict();
341 dic[ATTR_SHAPE] = shape_tuple;
342 dic[ATTR_DTYPE] = dtype_tuple;
343 MS_EXCEPTION_IF_NULL(arg_tuple->BuildValue());
344 if (arg_tuple->BuildValue()->isa<AnyValue>()) {
345 dic[ATTR_VALUE] = py::none();
346 } else {
347 dic[ATTR_VALUE] = value_tuple;
348 }
349
350 if (dyn_value) {
351 dic[ATTR_MIN_VALUE] = min_value_tuple;
352 dic[ATTR_MAX_VALUE] = max_value_tuple;
353 }
354 if (dyn_shape) {
355 dic[ATTR_MIN_SHAPE] = min_shape_tuple;
356 dic[ATTR_MAX_SHAPE] = max_shape_tuple;
357 }
358
359 return dic;
360 }
361
AbstractListToPython(const AbstractBasePtr & abs_base)362 py::dict AbstractListToPython(const AbstractBasePtr &abs_base) {
363 auto arg_list = dyn_cast<AbstractList>(abs_base);
364 MS_EXCEPTION_IF_NULL(arg_list);
365 size_t len = arg_list->size();
366 py::list shape_list(len);
367 py::list dtype_list(len);
368 py::list value_list(len);
369 py::list min_shape_list(len);
370 py::list max_shape_list(len);
371 bool dyn_shape = false;
372
373 for (size_t i = 0; i < len; i++) {
374 py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
375 shape_list[i] = out[ATTR_SHAPE];
376 dtype_list[i] = out[ATTR_DTYPE];
377 value_list[i] = out[ATTR_VALUE];
378
379 // Elements in list is tensor, which shape is dynamic.
380 if (out.contains(py::str(ATTR_MIN_SHAPE)) && out.contains(py::str(ATTR_MAX_SHAPE))) {
381 min_shape_list[i] = out[ATTR_MIN_SHAPE];
382 max_shape_list[i] = out[ATTR_MAX_SHAPE];
383 dyn_shape = true;
384 }
385 }
386 auto dic = py::dict();
387 dic[ATTR_SHAPE] = shape_list;
388 dic[ATTR_DTYPE] = dtype_list;
389 MS_EXCEPTION_IF_NULL(arg_list->BuildValue());
390 if (arg_list->BuildValue()->isa<AnyValue>()) {
391 dic[ATTR_VALUE] = py::none();
392 } else {
393 dic[ATTR_VALUE] = value_list;
394 }
395
396 if (dyn_shape) {
397 dic[ATTR_MIN_SHAPE] = min_shape_list;
398 dic[ATTR_MAX_SHAPE] = max_shape_list;
399 }
400
401 return dic;
402 }
403
ConvertAbstractTensorToPython(const AbstractBasePtr & abs_base,py::dict * dic)404 void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
405 auto arg_tensor = dyn_cast<AbstractTensor>(abs_base);
406 MS_EXCEPTION_IF_NULL(dic);
407 MS_EXCEPTION_IF_NULL(arg_tensor);
408 MS_EXCEPTION_IF_NULL(arg_tensor->shape());
409 (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
410 if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
411 const auto &min_shape = arg_tensor->shape()->min_shape();
412 const auto &max_shape = arg_tensor->shape()->max_shape();
413 if (!min_shape.empty() && !max_shape.empty()) {
414 (*dic)[ATTR_MIN_SHAPE] = min_shape;
415 (*dic)[ATTR_MAX_SHAPE] = max_shape;
416 }
417 }
418
419 auto min_value = arg_tensor->get_min_value();
420 auto max_value = arg_tensor->get_max_value();
421 if (min_value != nullptr && max_value != nullptr) {
422 (*dic)[ATTR_MIN_VALUE] = BuildValue(min_value);
423 (*dic)[ATTR_MAX_VALUE] = BuildValue(max_value);
424 }
425
426 (*dic)[ATTR_DTYPE] = arg_tensor->BuildType();
427 (*dic)[ATTR_VALUE] = BuildValue(arg_tensor->BuildValue());
428 }
429
ConvertAbstractFunctionToPython(const AbstractBasePtr & abs_base,py::dict * dic)430 void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
431 MS_EXCEPTION_IF_NULL(dic);
432 MS_EXCEPTION_IF_NULL(abs_base);
433 (*dic)[ATTR_SHAPE] = py::none();
434 (*dic)[ATTR_DTYPE] = abs_base->BuildType();
435 (*dic)[ATTR_VALUE] = py::none();
436 if (abs_base->isa<PartialAbstractClosure>()) {
437 AbstractBasePtrList args = abs_base->cast<PartialAbstractClosurePtr>()->args();
438 if (!args.empty()) {
439 MS_EXCEPTION_IF_NULL(args[0]->BuildValue());
440 auto value = args[0]->BuildValue()->cast<parse::ClassTypePtr>();
441 if (value != nullptr) {
442 (*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
443 (*dic)[ATTR_VALUE] = value->obj();
444 }
445 }
446 }
447 }
448 } // end anonymous namespace
449
ConvertAbstractToPython(const AbstractBasePtr & abs_base)450 py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
451 MS_EXCEPTION_IF_NULL(abs_base);
452 auto dic = py::dict();
453 if (abs_base->isa<AbstractTensor>()) {
454 ConvertAbstractTensorToPython(abs_base, &dic);
455 } else if (abs_base->isa<AbstractRowTensor>()) {
456 auto arg = dyn_cast<AbstractRowTensor>(abs_base);
457 dic[ATTR_SHAPE] = arg->shape()->shape();
458 dic[ATTR_DTYPE] = arg->BuildType();
459 dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
460 } else if (abs_base->isa<AbstractSparseTensor>()) {
461 auto arg = dyn_cast<AbstractSparseTensor>(abs_base);
462 dic[ATTR_SHAPE] = arg->shape()->shape();
463 dic[ATTR_DTYPE] = arg->BuildType();
464 dic[ATTR_VALUE] = BuildValue(arg->BuildValue());
465 } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>() || abs_base->isa<AbstractRefKey>()) {
466 ShapeVector shape;
467 dic[ATTR_SHAPE] = shape;
468 dic[ATTR_DTYPE] = abs_base->BuildType();
469 dic[ATTR_VALUE] = BuildValue(abs_base->BuildValue());
470 } else if (abs_base->isa<AbstractSlice>()) {
471 auto arg_slice = dyn_cast<AbstractSlice>(abs_base);
472 ShapeVector shape;
473 dic[ATTR_SHAPE] = shape;
474 dic[ATTR_DTYPE] = arg_slice->BuildType();
475 dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
476 } else if (abs_base->isa<AbstractEllipsis>()) {
477 dic[ATTR_SHAPE] = py::none();
478 dic[ATTR_DTYPE] = py::ellipsis();
479 dic[ATTR_VALUE] = py::ellipsis();
480 } else if (abs_base->isa<AbstractTuple>()) {
481 return AbstractTupleToPython(abs_base);
482 } else if (abs_base->isa<AbstractList>()) {
483 return AbstractListToPython(abs_base);
484 } else if (abs_base->isa<AbstractNone>()) {
485 dic[ATTR_SHAPE] = py::none();
486 dic[ATTR_DTYPE] = py::none();
487 dic[ATTR_VALUE] = py::none();
488 } else if (abs_base->isa<AbstractFunction>()) {
489 ConvertAbstractFunctionToPython(abs_base, &dic);
490 } else if (abs_base->isa<AbstractUndetermined>()) {
491 auto arg = dyn_cast<AbstractUndetermined>(abs_base);
492 dic[ATTR_SHAPE] = py::none();
493 dic[ATTR_DTYPE] = arg->BuildType();
494 dic[ATTR_VALUE] = py::none();
495 } else if (abs_base->isa<AbstractMonad>()) {
496 dic[ATTR_SHAPE] = py::none();
497 dic[ATTR_DTYPE] = abs_base->BuildType();
498 dic[ATTR_VALUE] = py::none();
499 } else {
500 auto value = abs_base->BuildValue();
501 MS_EXCEPTION_IF_NULL(value);
502 if ((*value == *kAnyValue)) {
503 auto value_desc = abs_base->value_desc();
504 MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
505 << " for python primitive." << abs_base->ToString();
506 }
507 MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
508 << value->ToString();
509 }
510 return dic;
511 }
512
513 namespace {
PreparePyInputs(const PrimitivePyPtr &,const AbstractBasePtrList & args)514 py::tuple PreparePyInputs(const PrimitivePyPtr &, const AbstractBasePtrList &args) {
515 // The monad parameter is defined at the end of the parameter and needs to be ignored
516 std::size_t size_args = args.size() - GetAbstractMonadNum(args);
517 py::tuple py_args(size_args);
518 for (size_t i = 0; i < size_args; i++) {
519 auto arg_i = (args)[i];
520 py_args[i] = ConvertAbstractToPython(arg_i);
521 }
522 return py_args;
523 }
524
CheckCustomPrimOutputInferResult(const PrimitivePtr & prim,const AbstractBasePtr & res_spec)525 void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
526 MS_EXCEPTION_IF_NULL(prim);
527 MS_EXCEPTION_IF_NULL(res_spec);
528 const string kOutputNum = "output_num";
529 if (prim->IsCustomPrim()) {
530 // Raise error if output_num is not match the infer result.
531 auto output_num_value = prim->GetAttr(kOutputNum);
532 if (output_num_value == nullptr) {
533 MS_LOG(DEBUG) << "The output num may no need to check";
534 return;
535 }
536 int64_t output_num = GetValue<int64_t>(output_num_value);
537 if (res_spec->isa<AbstractTensor>() && output_num != 1) {
538 MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
539 << "]'s attribute[output_num]:" << output_num << " not matches the infer result "
540 << res_spec->ToString();
541 } else if (res_spec->isa<AbstractTuple>() &&
542 (res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
543 MS_LOG(EXCEPTION) << "Custom primitive[" << prim->ToString() << "]'s attribute[output_num]:" << output_num
544 << " not matches the infer result " << res_spec->ToString();
545 }
546 }
547 }
548
PyInferRes2Abstract(const PrimitivePyPtr & prim_py,const py::dict & output)549 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
550 // Convert to AbstractValue based on type and shape
551 auto out_dtype = output[ATTR_DTYPE];
552 if (output[ATTR_VALUE].is_none()) {
553 auto out_shape = output[ATTR_SHAPE];
554 return MakePyInferRes2Abstract(out_shape, out_dtype, output);
555 }
556 // Convert pyobject to Value, then to AbstractValue
557 ValuePtr converted_ret = nullptr;
558 TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
559 bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
560 if (!converted) {
561 MS_LOG(EXCEPTION) << "Convert data failed";
562 }
563 auto res_spec = FromValue(converted_ret);
564 MS_EXCEPTION_IF_NULL(res_spec);
565 if (res_spec->isa<AbstractTensor>()) {
566 // Replace to tensor constant node in specialize
567 auto res_tensor = res_spec->cast<AbstractTensorPtr>();
568 res_tensor->set_value(converted_ret);
569 SetValueRange(res_tensor, output);
570 }
571 CheckCustomPrimOutputInferResult(prim_py, res_spec);
572 return res_spec;
573 }
574 } // end anonymous namespace
575
RunPyInferValue(const AnalysisEnginePtr & engine,const AbstractBasePtr & abs_base,const AbstractBasePtrList & args)576 EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &engine, const AbstractBasePtr &abs_base,
577 const AbstractBasePtrList &args) {
578 auto prim_py = dyn_cast<PrimitivePy>(prim_);
579 if (prim_py == nullptr) {
580 MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
581 }
582 // Call checking method 'infer_value' for python primitive
583 MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
584 auto py_args = PreparePyInputs(prim_py, args);
585 py::tuple py_vals(py_args.size());
586 auto added_attrs = prim_->evaluate_added_attrs();
587 for (size_t i = 0; i < py_args.size(); ++i) {
588 py_vals[i] = py_args[i][ATTR_VALUE];
589 }
590 py::object py_ret = prim_py->RunInferValue(py_vals);
591 if (py::isinstance<py::none>(py_ret)) {
592 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
593 }
594 // Convert pyobject to Value, then to AbstractValue
595 ValuePtr converted_ret = nullptr;
596 TypePtr dtype = abs_base->BuildType();
597 bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
598 if (!converted) {
599 MS_LOG(EXCEPTION) << "Convert data failed";
600 }
601 auto res_spec = FromValue(converted_ret);
602 MS_EXCEPTION_IF_NULL(res_spec);
603 if (res_spec->isa<AbstractTensor>()) {
604 // Replace to tensor constant node in specialize
605 auto res_tensor = res_spec->cast<AbstractTensorPtr>();
606 res_tensor->set_value(converted_ret);
607 }
608 return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
609 }
610
EvalPyCheckPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)611 EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
612 auto prim_py = dyn_cast<PrimitivePy>(prim_);
613 if (prim_py == nullptr) {
614 MS_LOG(EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
615 }
616 // Call checking method '__check__' for subclass of 'PrimitiveWithCheck'
617 MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
618 auto py_args = PreparePyInputs(prim_py, args);
619 prim_py->RunCheck(py_args);
620
621 prim_->BeginRecordAddAttr();
622 AbstractBasePtr abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
623 prim_->EndRecordAddAttr();
624 auto added_attrs = prim_->evaluate_added_attrs();
625
626 if (!py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
627 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
628 }
629 // Call method 'infer_value' for primitive with this method for constant propagation
630 return RunPyInferValue(engine, abs_base, args);
631 }
632
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)633 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
634 if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
635 auto ret_abstract = AbstractEval(args);
636 if (ret_abstract != nullptr) {
637 MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
638 return ret_abstract;
639 }
640 }
641 if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
642 return EvalPyCheckPrim(engine, args);
643 }
644 auto context = MsContext::GetInstance();
645 MS_EXCEPTION_IF_NULL(context);
646 bool need_infer_value = !eval_impl_.in_white_list_;
647 if (need_infer_value == false) {
648 need_infer_value = ((context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode)) &&
649 std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
650 MS_EXCEPTION_IF_NULL(abs);
651 auto value = abs->BuildValue();
652 return (value != nullptr && !value->isa<AnyValue>() && !value->isa<None>() &&
653 !value->isa<Monad>() && !value->isa<FuncGraph>());
654 });
655 }
656 AbstractBasePtr abs_base = nullptr;
657 ValuePtr value = nullptr;
658 prim_->BeginRecordAddAttr();
659 if (need_infer_value && eval_impl_.infer_value_impl_ != nullptr) {
660 value = eval_impl_.infer_value_impl_(prim_, args);
661 if (value != nullptr) {
662 abs_base = value->ToAbstract();
663 prim_->EndRecordAddAttr();
664 auto added_attrs = prim_->evaluate_added_attrs();
665 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
666 }
667 }
668 abs_base = eval_impl_.infer_shape_impl_(engine, prim_, args);
669 prim_->EndRecordAddAttr();
670 auto added_attrs = prim_->evaluate_added_attrs();
671 auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
672 return eval_result;
673 }
674
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args)675 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
676 auto ret_abstract = AbstractEval(args);
677 if (ret_abstract != nullptr) {
678 MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
679 return ret_abstract;
680 }
681 MS_LOG(DEBUG) << "Eval for:" << prim_py_->ToString();
682
683 const auto eval_result = evaluator_cache_mgr_->GetValue(args);
684 if (eval_result != nullptr) {
685 return eval_result;
686 }
687
688 auto py_args = PreparePyInputs(prim_py_, args);
689 prim_py_->BeginRecordAddAttr();
690 py::dict output = prim_py_->RunInfer(py_args);
691 prim_py_->EndRecordAddAttr();
692 auto added_attrs = prim_py_->evaluate_added_attrs();
693 MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);
694 auto res_spec = PyInferRes2Abstract(prim_py_, output);
695 MS_LOG(DEBUG) << "Python InferTensor result spec: " << res_spec->ToString() << ".";
696 auto infer_result = std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
697 evaluator_cache_mgr_->SetValue(args, infer_result);
698 return infer_result;
699 }
700
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args)701 EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
702 auto ret_abstract = AbstractEval(args);
703 if (ret_abstract != nullptr) {
704 MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
705 return ret_abstract;
706 }
707 // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
708 if (nargs_ != args.size()) {
709 MS_LOG(EXCEPTION) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size() << " inputs";
710 }
711 TypePtr ret_value_type = return_value_type_;
712 ValuePtrList value_list;
713 for (const auto &arg : args) {
714 // Check if all arguments are scalar type.
715 MS_EXCEPTION_IF_NULL(arg);
716 if (arg->isa<AbstractScalar>()) {
717 auto arg_scalar = dyn_cast<AbstractScalar>(arg);
718 auto arg_value = arg_scalar->GetValueTrack();
719 value_list.push_back(arg_value);
720 } else {
721 // Raise TypeError Expected Scalar.
722 MS_LOG(EXCEPTION) << "Expect scalar arguments for uniform primitives.";
723 }
724 }
725 for (const auto &item : type_map_) {
726 TypePtrList selections;
727 MS_EXCEPTION_IF_NULL(item.second);
728 (void)std::transform(item.second->begin(), item.second->end(), std::back_inserter(selections),
729 [&args](size_t arg_idx) -> TypePtr {
730 if (arg_idx >= args.size()) {
731 MS_LOG(EXCEPTION) << "Index:" << arg_idx << " out of range:" << args.size();
732 }
733 MS_EXCEPTION_IF_NULL(args[arg_idx]);
734 return args[arg_idx]->GetTypeTrack();
735 });
736 TypePtr res = CheckTypeList(item.first, selections);
737 MS_EXCEPTION_IF_NULL(return_value_type_);
738 MS_EXCEPTION_IF_NULL(item.first);
739 if (*return_value_type_ == *(item.first)) {
740 ret_value_type = res;
741 }
742 }
743
744 ValuePtr evaluated_value = RunImpl(value_list);
745 if (!(*evaluated_value == *kAnyValue)) {
746 ret_value_type = evaluated_value->type();
747 }
748 // for comparison primitives , return type shall have be specified to be bool.
749 if (specify_out_type_ != nullptr) {
750 ret_value_type = specify_out_type_;
751 }
752
753 AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, ret_value_type);
754 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
755 }
756
RunImpl(const ValuePtrList & args) const757 ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
758 if (!eval_value_) {
759 return kAnyValue;
760 } else {
761 if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
762 MS_EXCEPTION_IF_NULL(arg);
763 return arg->isa<AnyValue>();
764 })) {
765 return kAnyValue;
766 }
767 return impl_(args);
768 }
769 }
770
771 // Primitive implementation
772 // static function start
773 namespace {
InitStandardPrimEvaluator(PrimitivePtr primitive,const StandardPrimitiveImplReg eval_impl)774 EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveImplReg eval_impl) {
775 EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
776 return prim_evaluator;
777 }
778
InitUniformPrimEvaluator(const PrimitivePtr & primitive,PrimitiveImpl prim_impl,bool eval_value,const TypePtr & specify_out_type)779 EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
780 const TypePtr &specify_out_type) {
781 FunctionPtr func = nullptr;
782 (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
783 MS_EXCEPTION_IF_NULL(func);
784
785 EvaluatorPtr uniform_primitive_evaluator =
786 std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
787 return uniform_primitive_evaluator;
788 }
789
PyObjToGraph(const AnalysisEnginePtr & engine,const ValuePtr & method)790 FuncGraphPtr PyObjToGraph(const AnalysisEnginePtr &engine, const ValuePtr &method) {
791 MS_EXCEPTION_IF_NULL(engine);
792 MS_EXCEPTION_IF_NULL(method);
793 if (!method->isa<parse::PyObjectWrapper>()) {
794 MS_LOG(EXCEPTION) << "Method type error: " << method->ToString();
795 }
796
797 std::shared_ptr<PyObjectWrapper> obj = method->cast<std::shared_ptr<PyObjectWrapper>>();
798 FuncGraphPtr func_graph = mindspore::parse::ConvertToFuncGraph(obj->obj());
799 if (func_graph == nullptr) {
800 MS_LOG(EXCEPTION) << "Parse python object: " << method->ToString() << " failed";
801 }
802
803 FuncGraphManagerPtr manager = engine->func_graph_manager();
804 manager->AddFuncGraph(func_graph);
805 return func_graph;
806 }
807
AddToManager(const AnalysisEnginePtr & engine,const FuncGraphPtr func_graph)808 inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) {
809 MS_EXCEPTION_IF_NULL(engine);
810 FuncGraphManagerPtr manager = engine->func_graph_manager();
811 manager->AddFuncGraph(func_graph);
812 }
813
// Tells StaticGetterInferred what kind of member was requested.
enum class REQUIRE_TYPE {
  ATTR,    // the item was found in the attribute map
  METHOD,  // the item was found in the method map (the default)
};
815
StaticGetterInferred(const ValuePtr & value,const ConfigPtr & data_conf,const AnfNodeConfigPtr & old_conf,REQUIRE_TYPE require_type=REQUIRE_TYPE::METHOD)816 EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
817 REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
818 MS_EXCEPTION_IF_NULL(old_conf);
819
820 AbstractBasePtr abs_ptr = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
821 AbstractFunctionPtr abs_func = dyn_cast<abstract::AbstractFunction>(abs_ptr);
822 MS_EXCEPTION_IF_NULL(abs_func);
823
824 // Create new cnode
825 std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
826 auto func_graph_func = dyn_cast<abstract::FuncGraphAbstractClosure>(abs_func);
827 if (func_graph_func != nullptr) {
828 FuncGraphPtr fg = func_graph_func->func_graph();
829 input.push_back(NewValueNode(fg));
830 } else {
831 auto prim_func = dyn_cast<abstract::PrimitiveAbstractClosure>(abs_func);
832 MS_EXCEPTION_IF_NULL(prim_func);
833 PrimitivePtr prim = prim_func->prim();
834 input.push_back(NewValueNode(prim));
835 }
836
837 AnfNodeConfigPtr conf = dyn_cast<abstract::AnfNodeConfig>(data_conf);
838 MS_EXCEPTION_IF_NULL(conf);
839 input.push_back(conf->node());
840 MS_EXCEPTION_IF_NULL(old_conf);
841 FuncGraphPtr func_graph = old_conf->node()->func_graph();
842 MS_EXCEPTION_IF_NULL(func_graph);
843 CNodePtr new_cnode = func_graph->NewCNode(input);
844 if (require_type == REQUIRE_TYPE::ATTR) {
845 new_cnode = func_graph->NewCNode({new_cnode});
846 }
847 AnalysisEnginePtr eng = old_conf->engine();
848 AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context(), old_conf->func_graph());
849 return eng->ForwardConfig(old_conf, fn_conf);
850 }
851
GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &,const AbstractBasePtrList & args_spec_list,const AnfNodeConfigPtr & out_conf)852 EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &, const AbstractBasePtrList &args_spec_list,
853 const AnfNodeConfigPtr &out_conf) {
854 // args_spec_list: same as StaticGetter
855 if (args_spec_list.size() < 2) {
856 MS_LOG(EXCEPTION) << "Size of args_spec_list is less than 2";
857 }
858 MS_EXCEPTION_IF_NULL(out_conf);
859 // An external type.
860 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
861 MS_EXCEPTION_IF_NULL(args_spec_list[1]);
862 MS_LOG(DEBUG) << "Args[0]: " << args_spec_list[0]->ToString();
863 MS_LOG(DEBUG) << "Args[1]: " << args_spec_list[1]->ToString();
864 auto data_v = args_spec_list[0]->BuildValue();
865 MS_EXCEPTION_IF_NULL(data_v);
866 if (!data_v->isa<parse::NameSpace>()) {
867 MS_LOG(EXCEPTION) << "Not supported to get attribute for " << data_v->ToString()
868 << "\nThe data should be a NameSpace.";
869 }
870
871 auto item_value = args_spec_list[1]->BuildValue();
872 MS_EXCEPTION_IF_NULL(item_value);
873 if (item_value->isa<StringImm>()) {
874 item_value = std::make_shared<parse::Symbol>(item_value->cast<StringImmPtr>()->value());
875 }
876
877 if (!item_value->isa<parse::Symbol>()) {
878 MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
879 }
880
881 // item_name to func addr from obj_map
882 parse::SymbolPtr symbol = item_value->cast<parse::SymbolPtr>();
883 parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
884 MS_EXCEPTION_IF_NULL(out_conf);
885 auto out_node = out_conf->node();
886 FuncGraphPtr func_graph = out_node->func_graph();
887 MS_EXCEPTION_IF_NULL(func_graph);
888 auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
889 if (new_node == nullptr) {
890 MS_LOG(EXCEPTION) << "Resolve node failed";
891 }
892
893 // Replace old node with the resolved new node in order list.
894 func_graph->ReplaceInOrder(out_node, new_node);
895
896 AnalysisEnginePtr eng = out_conf->engine();
897 MS_EXCEPTION_IF_NULL(eng);
898 AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
899 return eng->ForwardConfig(out_conf, fn_conf);
900 }
901
GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ValuePtr & item_value,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)902 EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &engine,
903 const AbstractBasePtrList &args_spec_list,
904 const ValuePtr &item_value, const ConfigPtr &data_conf,
905 const AnfNodeConfigPtr &out_conf) {
906 if (args_spec_list.empty()) {
907 MS_LOG(EXCEPTION) << "args_spec_list is empty";
908 }
909 AbstractClassPtr cls = CheckArg<AbstractClass>("__FUNC__", args_spec_list, 0);
910
911 // If item_value is an attribute, get abstract value from AbstractClass
912 MS_EXCEPTION_IF_NULL(item_value);
913 if (!item_value->isa<StringImm>()) {
914 MS_LOG(EXCEPTION) << "Attribute type error";
915 }
916 std::string item_name = item_value->cast<StringImmPtr>()->value();
917 MS_LOG(DEBUG) << "Resolve name: " << cls->tag().name();
918 MS_LOG(DEBUG) << "Resolve item: " << item_name;
919 MS_EXCEPTION_IF_NULL(cls);
920 AbstractBasePtr attr = cls->GetAttribute(item_name);
921 if (attr != nullptr) {
922 return std::make_shared<EvalResult>(attr, nullptr);
923 }
924
925 ValuePtr method = cls->GetMethod(item_name);
926 if (method->isa<AnyValue>()) {
927 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
928 MS_EXCEPTION_IF_NULL(args_spec_list[0]->BuildType());
929 MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
930 << ", item value: " << item_value->ToString();
931 }
932
933 // Infer class method
934 ValuePtr converted_value = PyObjToGraph(engine, method);
935 return StaticGetterInferred(converted_value, data_conf, out_conf);
936 }
937
GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr & engine,const ValuePtr & item_value,const TypePtr & data_type,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)938 EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine, const ValuePtr &item_value,
939 const TypePtr &data_type, const ConfigPtr &data_conf,
940 const AnfNodeConfigPtr &out_conf) {
941 MS_EXCEPTION_IF_NULL(item_value);
942 MS_EXCEPTION_IF_NULL(data_type);
943 // The method maybe a Primitive or Composite
944 if (!item_value->isa<StringImm>()) {
945 MS_LOG(EXCEPTION) << "Error item is not string";
946 }
947
948 std::string item_name = item_value->cast<StringImmPtr>()->value();
949 REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
950 Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
951 if (require.empty()) {
952 require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
953 if (require.empty()) {
954 MS_LOG(EXCEPTION) << "Not supported to get attribute item name:\'" << item_name << "\' of a type["
955 << data_type->ToString() << "]";
956 }
957 require_type = REQUIRE_TYPE::ATTR;
958 }
959
960 ValuePtr converted_value = nullptr;
961 if (require.is<std::string>()) {
962 // composite registered in standard_method_map go to this branch
963 converted_value = prim::GetPythonOps(require.cast<std::string>());
964 MS_EXCEPTION_IF_NULL(converted_value);
965 if (!converted_value->isa<Primitive>()) {
966 AddToManager(engine, converted_value->cast<FuncGraphPtr>());
967 }
968 } else if (require.is<PrimitivePtr>()) {
969 converted_value = require.cast<PrimitivePtr>();
970 } else {
971 MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
972 }
973 return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
974 }
975
// Classification returned by GetResolveType and dispatched on by StaticGetter.
enum ResolveType : int64_t {
  kResolveTypeUserDefineClass = 1,  // the data type id is kObjectTypeClass
  kResolveTypeBuiltInType = 2,      // the data type id is present in the built-in map
  kResolveTypeFunction = 3,         // everything else
};
981
GetResolveType(const TypePtr & data_type)982 int64_t GetResolveType(const TypePtr &data_type) {
983 MS_EXCEPTION_IF_NULL(data_type);
984 if (data_type->type_id() == kObjectTypeClass) {
985 return kResolveTypeUserDefineClass;
986 }
987 // Try to search method map, if not found, the data_type should be External type.
988 if (pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id())) {
989 return kResolveTypeBuiltInType;
990 }
991 return kResolveTypeFunction;
992 }
993
StaticGetter(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)994 EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
995 const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
996 // Inputs: namespace and its static function; or class and its member function
997 CheckArgsSize("StaticGetter", args_spec_list, 2);
998
999 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
1000 MS_EXCEPTION_IF_NULL(args_spec_list[1]);
1001 TypePtr data_type = args_spec_list[0]->BuildType();
1002 ValuePtr item_value = args_spec_list[1]->BuildValue();
1003 ScopePtr scope = kDefaultScope;
1004 if (out_conf != nullptr) {
1005 scope = out_conf->node()->scope();
1006 }
1007 ScopeGuard scope_guard(scope);
1008 MS_EXCEPTION_IF_NULL(item_value);
1009 if (item_value->isa<AnyValue>()) {
1010 MS_LOG(EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
1011 }
1012
1013 int64_t resolve_type = GetResolveType(data_type);
1014 if (resolve_type == kResolveTypeUserDefineClass) {
1015 return GetEvaluatedValueForClassAttrOrMethod(engine, args_spec_list, item_value, data_conf, out_conf);
1016 } else if (resolve_type == kResolveTypeBuiltInType) {
1017 return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, item_value, data_type, data_conf, out_conf);
1018 } else {
1019 return GetEvaluatedValueForNameSpaceString(engine, args_spec_list, out_conf);
1020 }
1021 }
1022 } // end anonymous namespace
1023
1024 namespace {
1025 class EmbedEvaluator : public SymbolicPrimEvaluator {
1026 public:
EmbedEvaluator()1027 EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
1028 ~EmbedEvaluator() override = default;
1029 MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
EvalPrim(const ConfigPtrList & args_conf_list)1030 EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
1031 // arg: free variable to be embedded
1032 if (args_conf_list.size() != 1) {
1033 MS_LOG(EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
1034 }
1035 AnfNodeConfigPtr node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
1036 MS_EXCEPTION_IF_NULL(node_conf);
1037 MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
1038 AbstractBasePtr x = node_conf->ObtainEvalResult()->abstract();
1039 x = SensitivityTransform(x);
1040 SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
1041 AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
1042 return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
1043 }
1044 };
1045
FindParameterNodeByString(const FuncGraphManagerPtr & manager,const std::string & name)1046 static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
1047 MS_EXCEPTION_IF_NULL(manager);
1048 auto root_g_set = manager->roots();
1049 if (root_g_set.size() != 1) {
1050 return nullptr;
1051 }
1052 const FuncGraphPtr &root_g = root_g_set.back();
1053 for (auto ¶m_node : root_g->parameters()) {
1054 auto param = param_node->cast<ParameterPtr>();
1055 if (param && name == param->name()) {
1056 return param;
1057 }
1058 }
1059 return nullptr;
1060 }
1061
1062 class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
1063 public:
RefToEmbedEvaluator()1064 RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
1065 ~RefToEmbedEvaluator() override = default;
1066 MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
EvalPrim(const ConfigPtrList & args_conf_list)1067 EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
1068 if (args_conf_list.size() != 1) {
1069 MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
1070 return nullptr;
1071 }
1072 static TypePtr type = std::make_shared<SymbolicKeyType>();
1073 auto node_conf = dyn_cast<AnfNodeConfig>(args_conf_list[0]);
1074 if (node_conf == nullptr) {
1075 MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
1076 return nullptr;
1077 }
1078 MS_EXCEPTION_IF_NULL(node_conf->ObtainEvalResult());
1079 AbstractBasePtr abs = node_conf->ObtainEvalResult()->abstract();
1080 MS_EXCEPTION_IF_NULL(abs);
1081 AbstractRefPtr ref_abs = abs->cast<AbstractRefPtr>();
1082 if (ref_abs == nullptr) {
1083 MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
1084 return nullptr;
1085 }
1086 auto key_abs = ref_abs->ref_key();
1087 if (key_abs == nullptr) {
1088 MS_LOG(ERROR) << "RefToEmbed input Ref key is nullptr.";
1089 return nullptr;
1090 }
1091 auto key_value = key_abs->BuildValue();
1092 if (key_value == nullptr) {
1093 MS_LOG(ERROR) << "RefToEmbed input Ref key value is nullptr.";
1094 return nullptr;
1095 }
1096 auto refkey = key_value->cast<RefKeyPtr>();
1097 if (refkey == nullptr) {
1098 auto ret = std::make_shared<AbstractScalar>(type);
1099 auto ref_value = ref_abs->ref();
1100 MS_EXCEPTION_IF_NULL(ref_value);
1101 return std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
1102 }
1103
1104 std::string name = refkey->tag();
1105 MS_EXCEPTION_IF_NULL(node_conf->node());
1106 if (node_conf->node()->func_graph() == nullptr) {
1107 MS_LOG(EXCEPTION) << "Should not evaluate a ValueNode, node: " << node_conf->node()->DebugString();
1108 }
1109 const auto &manager = node_conf->node()->func_graph()->manager();
1110 auto node = FindParameterNodeByString(manager, name);
1111 if (node == nullptr) {
1112 MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
1113 return nullptr;
1114 }
1115 AbstractBasePtr x = ref_abs->ref();
1116 x = SensitivityTransform(x);
1117 std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
1118 std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
1119 return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
1120 }
1121 };
1122
1123 class GetAttrEvaluator : public TransitionPrimEvaluator {
1124 public:
GetAttrEvaluator()1125 GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
1126 ~GetAttrEvaluator() override = default;
1127 MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr & in_conf0,const AnfNodeConfigPtr & out_conf)1128 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
1129 const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
1130 constexpr auto kGetAttrArgSize = 2;
1131 auto ret_abstract = AbstractEval(args_spec_list);
1132 if (ret_abstract != nullptr) {
1133 MS_LOG(DEBUG) << "GetAttrEvaluator eval Undetermined";
1134 return ret_abstract;
1135 }
1136 // Inputs: data, item
1137 if (args_spec_list.size() != kGetAttrArgSize) {
1138 MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
1139 }
1140 EvalResultPtr ret = nullptr;
1141 if (bound_node() != nullptr) {
1142 TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
1143 ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
1144 } else {
1145 ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
1146 }
1147 // don't lookup from cache, as different out_conf with same node but different context
1148 // may add different entry to anfnode_config_map, like getattr primitive;
1149 evaluator_cache_mgr_->SetValue(args_spec_list, ret);
1150 return ret;
1151 }
1152 };
1153
1154 class ResolveEvaluator : public TransitionPrimEvaluator {
1155 public:
ResolveEvaluator()1156 ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
1157 ~ResolveEvaluator() override = default;
1158 MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr & in_conf0,const AnfNodeConfigPtr & out_conf)1159 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list,
1160 const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
1161 constexpr auto kResolveArgSize = 2;
1162 // Inputs: namespace, symbol
1163 if (args_spec_list.size() != kResolveArgSize) {
1164 MS_LOG(EXCEPTION) << "Expected args_spec_list size = 2, but has size:" << args_spec_list.size();
1165 }
1166 EvalResultPtr ret = nullptr;
1167 if (bound_node() != nullptr) {
1168 TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
1169 ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
1170 } else {
1171 ret = StaticGetter(engine, args_spec_list, in_conf0, out_conf);
1172 }
1173 return ret;
1174 }
1175 };
1176
1177 class CreateInstanceEvaluator : public TransitionPrimEvaluator {
1178 public:
CreateInstanceEvaluator()1179 CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
1180 ~CreateInstanceEvaluator() override = default;
1181 MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)1182 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
1183 const AnfNodeConfigPtr &out_conf) override {
1184 if (args_spec_list.empty()) {
1185 MS_LOG(EXCEPTION) << "'args_spec_list' should not be empty";
1186 }
1187
1188 // Get the type parameter.
1189 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
1190 TypePtr type = args_spec_list[0]->GetTypeTrack();
1191 MS_EXCEPTION_IF_NULL(type);
1192 if (type->type_id() != kMetaTypeTypeType) {
1193 MS_LOG(EXCEPTION) << "CreateInstanceEvaluator require first parameter should be an object of TypeType, but got "
1194 << type->ToString();
1195 }
1196
1197 ValuePtr value_track = args_spec_list[0]->GetValueTrack();
1198 MS_EXCEPTION_IF_NULL(value_track);
1199
1200 std::shared_ptr<parse::PyObjectWrapper> type_obj = dyn_cast<parse::PyObjectWrapper>(value_track);
1201 if (type_obj == nullptr) {
1202 MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
1203 }
1204
1205 if (!type_obj->isa<parse::ClassType>()) {
1206 MS_LOG(EXCEPTION) << "CreateInstanceEvaluator the type_obj should be an object of ClassType, but got "
1207 << type_obj->ToString() << ".";
1208 }
1209
1210 auto class_type = type_obj->obj();
1211 MS_LOG(DEBUG) << "Get class type is " << type_obj->ToString() << ".";
1212
1213 // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
1214 py::tuple params = GetParameters(args_spec_list);
1215
1216 // Create class instance.
1217 auto obj = parse::data_converter::CreatePythonObject(class_type, params);
1218 if (py::isinstance<py::none>(obj)) {
1219 MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_type)
1220 << "` failed, only support to create \'Cell\' or \'Primitive\' object.";
1221 }
1222
1223 // Process the object.
1224 ValuePtr converted_ret = nullptr;
1225 bool converted = parse::ConvertData(obj, &converted_ret, true);
1226 if (!converted) {
1227 MS_LOG(EXCEPTION) << "Convert the python object failed";
1228 }
1229 MS_EXCEPTION_IF_NULL(converted_ret);
1230
1231 if (converted_ret->isa<FuncGraph>()) {
1232 AddToManager(engine, converted_ret->cast<FuncGraphPtr>());
1233 }
1234
1235 AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
1236 auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
1237 evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
1238 return infer_result;
1239 }
1240
GetParameters(const AbstractBasePtrList & args_spec_list) const1241 py::tuple GetParameters(const AbstractBasePtrList &args_spec_list) const {
1242 // Exclude class type by minus 1;
1243 std::size_t params_size = args_spec_list.size() - 1;
1244 auto params = py::tuple(params_size);
1245 if (params_size > params.size()) {
1246 MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size << ", params.size():" << params.size();
1247 }
1248 if (params_size > 0) {
1249 for (size_t i = 0; i < params_size; i++) {
1250 // Only support the Scalar parameters type. Bypass class type by offset with 1.
1251 auto arg = args_spec_list[i + 1];
1252 MS_EXCEPTION_IF_NULL(arg);
1253 // Because the Tensor's AbstractTensor can't get value from GetValueTrack.
1254 ValuePtr param_value = arg->BuildValue();
1255 py::object param = ValueToPyData(param_value);
1256 params[i] = param;
1257 }
1258 }
1259 return params;
1260 }
1261 };
1262
1263 class PyInterpretEvaluator : public TransitionPrimEvaluator {
1264 public:
PyInterpretEvaluator()1265 PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
1266 ~PyInterpretEvaluator() override = default;
1267 MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_spec_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)1268 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list, const ConfigPtr &,
1269 const AnfNodeConfigPtr &out_conf) override {
1270 if (args_spec_list.empty()) {
1271 MS_LOG(ERROR) << "'args_spec_list' should not be empty";
1272 }
1273
1274 // Get the type parameter.
1275 MS_EXCEPTION_IF_NULL(args_spec_list[0]);
1276 ValuePtr value_track = args_spec_list[0]->GetValueTrack();
1277 MS_EXCEPTION_IF_NULL(value_track);
1278
1279 std::shared_ptr<parse::Script> script_obj = dyn_cast<parse::Script>(value_track);
1280 if (script_obj == nullptr) {
1281 MS_LOG(EXCEPTION) << "Cast value failed, not PyObjectWrapper:" << value_track->ToString() << ".";
1282 }
1283
1284 // Make global and local parameters.
1285 py::tuple params = MakeParameters(args_spec_list);
1286
1287 // Call python script string.
1288 MS_LOG(DEBUG) << "Call script: " << script_obj->script() << ", params: " << py::str(params);
1289 auto obj = parse::data_converter::CallPythonScript(py::str(script_obj->script()), params);
1290 if (py::isinstance<py::none>(obj)) {
1291 MS_LOG(EXCEPTION) << "Failed to call python script: `" << script_obj->script() << "`";
1292 }
1293
1294 ValuePtr converted_val = nullptr;
1295 bool converted = parse::ConvertData(obj, &converted_val, true);
1296 if (!converted) {
1297 MS_LOG(EXCEPTION) << "Convert the python object failed";
1298 }
1299 MS_EXCEPTION_IF_NULL(converted_val);
1300
1301 AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
1302 auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
1303 evaluator_cache_mgr_->SetValue(args_spec_list, infer_result);
1304 return infer_result;
1305 }
1306
MakeParameters(const AbstractBasePtrList & args_spec_list) const1307 py::tuple MakeParameters(const AbstractBasePtrList &args_spec_list) const {
1308 constexpr int params_size = 3;
1309 if (params_size != args_spec_list.size()) {
1310 MS_LOG(EXCEPTION) << "Unexpected params_size: " << params_size
1311 << ", not equal to arguments.size:" << args_spec_list.size();
1312 }
1313 // The first argument is script string, ignore it.
1314 auto params = py::tuple(params_size - 1);
1315
1316 // Make the global parameters.
1317 auto global_dict = dyn_cast<AbstractDictionary>(args_spec_list[1]); // Global parameters dict.
1318 MS_EXCEPTION_IF_NULL(global_dict);
1319 MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString() << ", [" << global_dict->type_name() << "]";
1320 ValuePtr global_dict_value = global_dict->BuildValue();
1321 py::object global_params_dict = ValueToPyData(global_dict_value);
1322 MS_LOG(DEBUG) << "arg_1, python global_params_dict: " << py::str(global_params_dict);
1323 params[0] = global_params_dict;
1324
1325 // Make the local parameters.
1326 auto local_dict = dyn_cast<AbstractDictionary>(args_spec_list[2]); // Local parameters dict.
1327 MS_EXCEPTION_IF_NULL(local_dict);
1328 MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString() << ", [" << local_dict->type_name() << "]";
1329 ValuePtr local_dict_value = local_dict->BuildValue();
1330 py::object local_params_dict = ValueToPyData(local_dict_value);
1331 MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << py::str(local_params_dict);
1332 params[1] = local_params_dict;
1333 return params;
1334 }
1335 };
1336
1337 class PartialEvaluator : public Evaluator {
1338 public:
PartialEvaluator()1339 PartialEvaluator() : Evaluator("PartialEvaluator") {}
1340 ~PartialEvaluator() override = default;
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)1341 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
1342 const AnfNodeConfigPtr &out_conf) override {
1343 if (args_conf_list.size() == 0) {
1344 MS_LOG(EXCEPTION) << "Args size should be greater than 0";
1345 }
1346
1347 MS_EXCEPTION_IF_NULL(out_conf);
1348 MS_EXCEPTION_IF_NULL(out_conf->node());
1349 MS_EXCEPTION_IF_NULL(args_conf_list[0]);
1350 MS_EXCEPTION_IF_NULL(args_conf_list[0]->ObtainEvalResult());
1351 auto arg0_value = args_conf_list[0]->ObtainEvalResult()->abstract();
1352 MS_EXCEPTION_IF_NULL(arg0_value);
1353 AbstractBasePtrList args_spec_list{arg0_value};
1354 // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
1355 if (arg0_value->isa<AbstractError>()) {
1356 MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack());
1357 auto ret = std::make_shared<AbstractError>(arg0_value->GetValueTrack()->cast<StringImmPtr>(), out_conf->node());
1358 MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString()
1359 << " as func is: " << arg0_value->ToString();
1360 auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
1361 evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
1362 return eval_result;
1363 }
1364 auto func = CheckArg<AbstractFunction>("partial", args_spec_list, 0);
1365 // Sometimes, node[0] in out_conf becomes phi0;
1366 if (func->isa<PrimitiveAbstractClosure>()) {
1367 auto prim_func = dyn_cast<PrimitiveAbstractClosure>(func);
1368 MS_EXCEPTION_IF_NULL(prim_func->prim());
1369 if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
1370 prim::DoSignaturePrimitivePtr do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(prim_func->prim());
1371 return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
1372 }
1373 }
1374
1375 (void)std::transform(
1376 args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
1377 [](const ConfigPtr &config) -> AbstractBasePtr { return config->ObtainEvalResult()->abstract(); });
1378 AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
1379
1380 auto cnode = out_conf->node()->cast<CNodePtr>();
1381 MS_EXCEPTION_IF_NULL(cnode);
1382 if (cnode->size() != (args_conf_list.size() + 1)) {
1383 MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString()
1384 << ", args_conf_list: " << mindspore::ToString(args_conf_list);
1385 }
1386 AbstractFuncAtomPtrList partial_funcs_list;
1387 auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
1388 auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
1389 partial_funcs_list.push_back(new_func);
1390 };
1391 func->Visit(build_partial);
1392
1393 auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
1394 auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
1395 evaluator_cache_mgr_->SetValue(args_spec_list, eval_result);
1396 return eval_result;
1397 }
1398
// Eval() is not a supported entry point for this evaluator: the body unconditionally
// reports an EXCEPTION-level log telling the caller to use Run() instead.
// All three parameters are unnamed and unused.
EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
  MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
}
1402
HandleDoSignature(const AnalysisEnginePtr & engine,const ValuePtr & signature_value,const AnfNodeConfigPtr & out_conf) const1403 EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
1404 const AnfNodeConfigPtr &out_conf) const {
1405 MS_EXCEPTION_IF_NULL(engine);
1406 MS_EXCEPTION_IF_NULL(out_conf);
1407 MS_EXCEPTION_IF_NULL(out_conf->node());
1408 auto cnode = out_conf->node()->cast<CNodePtr>();
1409 if (cnode == nullptr) {
1410 MS_LOG(EXCEPTION) << "Cnode is nullptr";
1411 }
1412 std::vector<AnfNodePtr> new_nodes_inputs = cnode->inputs();
1413 auto new_signature_value = std::make_shared<prim::DoSignatureMetaFuncGraph>("signature", signature_value);
1414 new_nodes_inputs[1] = NewValueNode(new_signature_value);
1415 FuncGraphPtr func_graph = cnode->func_graph();
1416
1417 ScopePtr scope = out_conf->node()->scope();
1418 ScopeGuard scope_guard(scope);
1419 MS_EXCEPTION_IF_NULL(func_graph);
1420 CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
1421 AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
1422 return engine->ForwardConfig(out_conf, fn_conf);
1423 }
1424 };
1425
// One entry of the uniform-primitive table returned by GetUniformPrimitiveToImplMap().
// The first three fields are passed to InitUniformPrimEvaluator() by
// InitPrimEvaluatorConstructors(); in_white_list_ is read by IsInWhiteList().
// Entries are brace-initialized positionally, so the member order matters.
struct PrimitiveImplInferValue {
  PrimitiveImpl impl_;        // implementation function of the primitive
  bool eval_value_;           // whether to evaluate the value
  TypePtr specify_out_type_;  // explicitly specified return type; nullptr when the table entry gives none
  bool in_white_list_;        // true if this Primitive is in the white list, else false.
};
1432
// Hash map from a primitive to its PrimitiveImplInferValue entry, hashed and compared
// with PrimitiveHasher / PrimitiveEqual.
using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
// Returns a mutable reference to the function-local static table of scalar and bool
// primitives. Every entry sets eval_value_ and in_white_list_ to true; the comparison
// and bool primitives specify Bool as the output type, the arithmetic ones give nullptr.
// NOTE(review): ClearPrimEvaluatorMap() clears this map and nothing in this part of the
// file refills it, so calls made after a clear appear to see an empty table -- confirm
// that this is intended.
PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
  static PrimitiveToImplMap uniform_prim_implement_map = {
    // Scalar arithmetic: no explicit output type.
    {prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
    {prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
    {prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
    {prim::kPrimScalarDiv, {prim::ScalarDiv, true, nullptr, true}},
    {prim::kPrimScalarMod, {prim::ScalarMod, true, nullptr, true}},
    {prim::kPrimScalarPow, {prim::ScalarPow, true, nullptr, true}},
    {prim::kPrimScalarFloordiv, {prim::ScalarFloordiv, true, nullptr, true}},
    {prim::kPrimScalarUadd, {prim::ScalarUAdd, true, nullptr, true}},
    {prim::kPrimScalarUsub, {prim::ScalarUSub, true, nullptr, true}},
    {prim::kPrimScalarLog, {prim::ScalarLog, true, nullptr, true}},
    // Scalar comparisons: output type is Bool.
    {prim::kPrimScalarEq, {prim::ScalarEq, true, std::make_shared<Bool>(), true}},
    {prim::kPrimScalarLt, {prim::ScalarLt, true, std::make_shared<Bool>(), true}},
    {prim::kPrimScalarGt, {prim::ScalarGt, true, std::make_shared<Bool>(), true}},
    {prim::kPrimScalarNe, {prim::ScalarNe, true, std::make_shared<Bool>(), true}},
    {prim::kPrimScalarLe, {prim::ScalarLe, true, std::make_shared<Bool>(), true}},
    {prim::kPrimScalarGe, {prim::ScalarGe, true, std::make_shared<Bool>(), true}},
    // Bool operations: output type is Bool.
    {prim::kPrimBoolNot, {prim::BoolNot, true, std::make_shared<Bool>(), true}},
    {prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
    {prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
    {prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
  };
  return uniform_prim_implement_map;
}
1459
// Registry of primitive evaluators. Filled by InitPrimEvaluatorConstructors(), handed out
// by GetPrimEvaluatorConstructors() and emptied by ClearPrimEvaluatorMap().
PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
// Held by GetPrimEvaluatorConstructors() while it populates the registry above.
std::mutex PrimEvaluatorConstructorMutex;
1462
InitPrimEvaluatorConstructors()1463 void InitPrimEvaluatorConstructors() {
1464 PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
1465
1466 for (const auto &iter : GetPrimitiveToEvalImplMap()) {
1467 constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second);
1468 }
1469
1470 for (const auto &iter : GetUniformPrimitiveToImplMap()) {
1471 constructor[iter.first] =
1472 InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
1473 }
1474 constructor[prim::kPrimEmbed] = std::make_shared<EmbedEvaluator>();
1475 constructor[prim::kPrimRefToEmbed] = std::make_shared<RefToEmbedEvaluator>();
1476 constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
1477 constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
1478 constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
1479 constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
1480 constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
1481 }
1482 } // namespace
1483
// Empties the evaluator registry and both primitive-implementation maps.
// Note: PrimEvaluatorConstructorMutex is not taken here, unlike the population path in
// GetPrimEvaluatorConstructors().
void ClearPrimEvaluatorMap() {
  PrimEvaluatorConstructors.clear();
  GetPrimitiveToEvalImplMap().clear();
  GetUniformPrimitiveToImplMap().clear();
}
1489
IsInWhiteList(const PrimitivePtr & primitive)1490 bool IsInWhiteList(const PrimitivePtr &primitive) {
1491 MS_EXCEPTION_IF_NULL(primitive);
1492
1493 auto iter = GetPrimitiveToEvalImplMap().find(primitive);
1494 if (iter != GetPrimitiveToEvalImplMap().end()) {
1495 return iter->second.in_white_list_;
1496 }
1497
1498 auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
1499 if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
1500 return uni_iter->second.in_white_list_;
1501 }
1502
1503 return false;
1504 }
1505
GetPrimEvaluatorConstructors()1506 PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
1507 PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
1508 if (!constructor.empty()) {
1509 return constructor;
1510 }
1511 std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
1512 if (constructor.empty()) {
1513 InitPrimEvaluatorConstructors();
1514 }
1515
1516 return constructor;
1517 }
1518
1519 namespace {
IsSubtypeTuple(const AbstractBasePtr x,const TypePtr model)1520 bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
1521 MS_EXCEPTION_IF_NULL(x);
1522 MS_EXCEPTION_IF_NULL(model);
1523 auto x_tuple = dyn_cast<AbstractTuple>(x);
1524 auto model_tuple = dyn_cast<Tuple>(model);
1525
1526 if (x_tuple == nullptr || model_tuple == nullptr) {
1527 return false;
1528 }
1529
1530 if (model->IsGeneric()) {
1531 return true;
1532 }
1533
1534 if (x_tuple->size() != model_tuple->size()) {
1535 return false;
1536 }
1537
1538 for (size_t i = 0; i < x_tuple->size(); i++) {
1539 bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
1540 if (!is_subtype) {
1541 return false;
1542 }
1543 }
1544 return true;
1545 }
1546
IsSubtypeArray(const AbstractBasePtr x,const TypePtr model)1547 bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
1548 MS_EXCEPTION_IF_NULL(x);
1549 MS_EXCEPTION_IF_NULL(model);
1550 auto x_tensor = dyn_cast<AbstractTensor>(x);
1551 auto model_tensor = dyn_cast<TensorType>(model);
1552
1553 if (x_tensor == nullptr || model_tensor == nullptr) {
1554 return false;
1555 }
1556
1557 if (model->IsGeneric()) {
1558 return true;
1559 }
1560
1561 return IsSubtype(x_tensor->element(), model_tensor->element());
1562 }
1563
IsSubtypeList(const AbstractBasePtr x,const TypePtr model)1564 bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
1565 MS_EXCEPTION_IF_NULL(x);
1566 MS_EXCEPTION_IF_NULL(model);
1567 auto x_list = dyn_cast<AbstractList>(x);
1568 auto model_list = dyn_cast<List>(model);
1569
1570 if (x_list == nullptr || model_list == nullptr) {
1571 return false;
1572 }
1573
1574 if (model->IsGeneric()) {
1575 return true;
1576 }
1577
1578 if (x_list->size() != model_list->size()) {
1579 return false;
1580 }
1581
1582 bool is_subtype = true;
1583 for (size_t i = 0; i < x_list->size(); i++) {
1584 is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
1585 if (!is_subtype) {
1586 return false;
1587 }
1588 }
1589 return is_subtype;
1590 }
1591
IsSubtypeClass(const AbstractBasePtr x,const TypePtr model)1592 bool IsSubtypeClass(const AbstractBasePtr x, const TypePtr model) {
1593 MS_EXCEPTION_IF_NULL(x);
1594 MS_EXCEPTION_IF_NULL(model);
1595 auto x_class = dyn_cast<AbstractClass>(x);
1596 auto model_class = dyn_cast<Class>(model);
1597 if (x_class == nullptr) {
1598 return false;
1599 }
1600 if (model->IsGeneric()) {
1601 return true;
1602 }
1603 MS_EXCEPTION_IF_NULL(model_class);
1604 if (x_class->tag() == model_class->tag()) {
1605 auto m_attributes = model_class->GetAttributes();
1606 auto x_attributes = x_class->attributes();
1607 if (m_attributes.size() != x_attributes.size()) {
1608 return false;
1609 }
1610
1611 for (size_t i = 0; i < m_attributes.size(); i++) {
1612 if (!IsSubtype(x_attributes[i].second, m_attributes[i].second)) {
1613 return false;
1614 }
1615 }
1616 return true;
1617 }
1618
1619 return false;
1620 }
1621
IsSubtypeScalar(const AbstractBasePtr x,const TypePtr model)1622 inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
1623 MS_EXCEPTION_IF_NULL(x);
1624 MS_EXCEPTION_IF_NULL(model);
1625 if (dyn_cast<AbstractScalar>(x) == nullptr) {
1626 return false;
1627 }
1628 TypePtr x_type = x->GetTypeTrack();
1629 return IsSubType(x_type, model);
1630 }
1631 } // namespace
1632
IsSubtype(const AbstractBasePtr x,const TypePtr model)1633 bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
1634 MS_EXCEPTION_IF_NULL(x);
1635 MS_EXCEPTION_IF_NULL(model);
1636 TypeId model_typeid = model->type_id();
1637 switch (model_typeid) {
1638 case kMetaTypeObject:
1639 return true;
1640 case kObjectTypeTuple:
1641 return IsSubtypeTuple(x, model);
1642 case kObjectTypeTensorType:
1643 return IsSubtypeArray(x, model);
1644 case kObjectTypeList:
1645 return IsSubtypeList(x, model);
1646 case kObjectTypeClass:
1647 return IsSubtypeClass(x, model);
1648 default:
1649 if (IsSubType(model, std::make_shared<Number>())) {
1650 return IsSubtypeScalar(x, model);
1651 }
1652 MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << ".";
1653 }
1654 }
1655 } // namespace abstract
1656 } // namespace mindspore
1657