1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2024 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/ps/static_analysis/prim.h"
20
21 #include <algorithm>
22 #include <limits>
23 #include <map>
24 #include <mutex>
25 #include <string>
26 #include <utility>
27
28 #include "abstract/abstract_value.h"
29 #include "abstract/ops/primitive_infer_map.h"
30 #include "abstract/param_validator.h"
31 #include "abstract/utils.h"
32 #include "frontend/operator/cc_implementations.h"
33 #include "frontend/operator/composite/do_signature.h"
34 #include "frontend/operator/ops.h"
35 #include "frontend/operator/ops_front_infer_function.h"
36 #include "frontend/operator/prim_to_function.h"
37 #include "frontend/operator/composite/unpack_call.h"
38 #include "include/common/fallback.h"
39 #include "include/common/utils/convert_utils.h"
40 #include "include/common/utils/convert_utils_py.h"
41 #include "include/common/utils/primfunc_utils.h"
42 #include "ir/anf.h"
43 #include "ir/cell.h"
44 #include "ops/arithmetic_ops.h"
45 #include "ops/comparison_ops.h"
46 #include "ops/framework_ops.h"
47 #include "ops/other_ops.h"
48 #include "ops/sequence_ops.h"
49 #include "ops/structure_ops.h"
50 #include "ops/array_op_name.h"
51 #include "ops/op_utils.h"
52 #include "pipeline/jit/ps/debug/trace.h"
53 #include "pipeline/jit/ps/fallback.h"
54 #include "pipeline/jit/ps/parse/data_converter.h"
55 #include "pipeline/jit/ps/parse/parse_base.h"
56 #include "pipeline/jit/ps/parse/resolve.h"
57 #include "pipeline/jit/ps/pipeline.h"
58 #include "pipeline/jit/ps/resource.h"
59 #include "pipeline/jit/ps/static_analysis/evaluator.h"
60 #include "pipeline/jit/ps/static_analysis/builtin_prim.h"
61 #include "pipeline/jit/ps/static_analysis/static_analysis.h"
62 #include "utils/check_convert_utils.h"
63 #include "utils/hash_set.h"
64 #include "utils/log_adapter.h"
65 #include "utils/ms_context.h"
66 #include "utils/ms_utils.h"
67 #include "utils/parallel_node_check.h"
68 #include "utils/shape_utils.h"
69 #include "utils/symbolic.h"
70 #include "utils/compile_config.h"
71
72 namespace mindspore {
// Shared-pointer alias for parse::ClassType.
using ClassTypePtr = std::shared_ptr<parse::ClassType>;
74 namespace abstract {
75 using mindspore::parse::PyObjectWrapper;
76
// Names of the primitives for which DoSignatureEvaluator::Run() does not call EvalUndeterminedArgs()
// before generating the new node.
mindspore::HashSet<std::string> prims_to_skip_undetermined_infer{kMakeTupleOpName, kMakeListOpName, kSwitchOpName,
                                                                 kEnvironSetOpName, kEnvironGetOpName, kLoadOpName,
                                                                 kUpdateStateOpName};
80
// The Python primitives that visit tuple/list elements, but do not consume all of the elements.
// Including:
//   - Consume no element. For instance, MakeTuple.
//   - Consume partial elements, not all. For instance, TupleGetItem.
// Map{"primitive name", {vector<int>: "indexes to pass transparently, -1 means all elements"}}
mindspore::HashMap<std::string, std::vector<int>> prims_transparent_pass_sequence{
  {kReturnOpName, std::vector({0})}, {kDependOpName, std::vector({0})}, {kidentityOpName, std::vector({0})},
  {kMakeTupleOpName, std::vector({-1})}, {kMakeListOpName, std::vector({-1})}, {kListAppendOpName, std::vector({0})},
  {kTupleGetItemOpName, std::vector({0})}, {kListGetItemOpName, std::vector({0})}};
90
OpDtypeToInt(ops::OP_DTYPE dtype)91 inline int64_t OpDtypeToInt(ops::OP_DTYPE dtype) { return static_cast<int64_t>(dtype); }
92
// Wraps `node` in a call to the Python type-cast graph when `op_arg` declares cast dtypes.
//
// Returns `node` itself when `op_arg.cast_dtype_` is empty. Otherwise the cast graph is fetched with
// prim::GetPythonOps(), attached to the manager of `fg`, and a new CNode `cast(node, arg_dtype)` is created
// in `fg`, where `arg_dtype` is the integer form of `op_arg.arg_dtype_`.
// Raises (MS_EXCEPTION_IF_NULL) when `fg` is null or the fetched object is not a FuncGraph.
AnfNodePtr GetNodeAfterTypeConversion(const AnfNodePtr &node, const ops::OpInputArg &op_arg, const FuncGraphPtr &fg) {
  MS_EXCEPTION_IF_NULL(fg);
  // Nothing to convert when no source cast dtype is declared.
  const bool needs_conversion = !op_arg.cast_dtype_.empty();
  if (!needs_conversion) {
    return node;
  }
  const auto cast_python_op =
    prim::GetPythonOps(parse::PYTHON_MOD_PRIMITIVE_OP_TYPE_CAST, parse::PYTHON_MOD_PRIMITIVE_ARG_DTYPE_CAST_MODULE);
  auto cast_graph = dyn_cast<FuncGraph>(cast_python_op);
  MS_EXCEPTION_IF_NULL(cast_graph);
  cast_graph->set_manager(fg->manager());
  AnfNodePtrList cast_inputs{NewValueNode(cast_graph), node, NewValueNode(OpDtypeToInt(op_arg.arg_dtype_))};
  return fg->NewCNodeInOrder(std::move(cast_inputs));
}
106
// Wraps `node` in a call to the Python arg handler configured for `op_arg`.
//
// Returns `node` unchanged when `op_arg` has no arg handler, or when the argument is optional and its abstract
// `abs` is AbstractNone. Otherwise the handler is fetched with prim::GetPythonOps() and a new CNode
// `handler(op_name, arg_name, node)` is created in `fg`. The handler may be a Primitive or a FuncGraph; a
// FuncGraph handler is attached to the manager of `fg` first.
// Raises (MS_EXCEPTION_IF_NULL) when a required pointer is null or the handler is neither kind.
AnfNodePtr GetNodeAfterArgHandler(const AnfNodePtr &node, const std::string &op_name, const ops::OpInputArg &op_arg,
                                  const AbstractBasePtr &abs, const FuncGraphPtr &fg) {
  if (op_arg.arg_handler_.empty()) {
    return node;
  }
  if (op_arg.is_optional_) {
    // `abs` is only inspected for optional arguments; validate it before dereferencing.
    MS_EXCEPTION_IF_NULL(abs);
    if (abs->isa<AbstractNone>()) {
      return node;
    }
  }
  // A graph is required to create the handler call, same as in GetNodeAfterTypeConversion().
  MS_EXCEPTION_IF_NULL(fg);
  const auto arg_handler_func = prim::GetPythonOps(op_arg.arg_handler_, parse::PYTHON_MOD_PRIMITIVE_ARG_HANDLER_MODULE);
  MS_EXCEPTION_IF_NULL(arg_handler_func);
  MS_LOG(DEBUG) << "The arg handler function for '" << op_arg.arg_name_ << "' of Primitive[" << op_name << "] is "
                << arg_handler_func->ToString() << ".";
  if (arg_handler_func->isa<Primitive>()) {
    auto arg_handler_prim = dyn_cast<Primitive>(arg_handler_func);
    MS_EXCEPTION_IF_NULL(arg_handler_prim);
    return fg->NewCNodeInOrder(
      {NewValueNode(arg_handler_prim), NewValueNode(op_name), NewValueNode(op_arg.arg_name_), node});
  }
  auto arg_handler_fg = dyn_cast<FuncGraph>(arg_handler_func);
  MS_EXCEPTION_IF_NULL(arg_handler_fg);
  arg_handler_fg->set_manager(fg->manager());
  return fg->NewCNodeInOrder(
    {NewValueNode(arg_handler_fg), NewValueNode(op_name), NewValueNode(op_arg.arg_name_), node});
}
130
GenerateNewNodeBySignatures(const ValuePtr & func,const AbstractBasePtrList & args_abs_list,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf)131 CNodePtr DoSignatureEvaluator::GenerateNewNodeBySignatures(const ValuePtr &func,
132 const AbstractBasePtrList &args_abs_list,
133 const AnalysisEnginePtr &engine,
134 const AnfNodeConfigPtr &out_conf) {
135 if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
136 MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
137 }
138 auto out_cnode = dyn_cast<CNode>(out_conf->node());
139 MS_EXCEPTION_IF_NULL(out_cnode);
140 auto fg = out_cnode->func_graph();
141 MS_EXCEPTION_IF_NULL(fg);
142 if (out_cnode->size() == 0 || (out_cnode->size() - 1) != args_abs_list.size()) {
143 MS_LOG(EXCEPTION) << "Op: " << func->ToString() << " args size should equal to inputs size minus 1, but args size "
144 << args_abs_list.size() << ", inputs size " << out_cnode->size();
145 }
146
147 // Handle primitive signatures.
148 AnfNodePtrList args_inputs;
149 (void)std::transform(out_cnode->weak_inputs().cbegin() + 1, out_cnode->weak_inputs().cend(),
150 std::back_inserter(args_inputs), [](const AnfNodeWeakPtr &weak_node) {
151 const auto &node = weak_node.lock();
152 MS_EXCEPTION_IF_NULL(node);
153 return node;
154 });
155 auto op_inputs = prim::GetNewInputsBySignatures(fg, prim_->ToString(), func, args_abs_list, args_inputs);
156 AnfNodePtrList new_inputs{NewValueNode(func)};
157 (void)std::copy(op_inputs.begin(), op_inputs.end(), std::back_inserter(new_inputs));
158 return fg->NewCNodeInOrder(new_inputs);
159 }
160
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)161 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
162 const AnfNodeConfigPtr &out_conf) {
163 MS_EXCEPTION_IF_NULL(engine);
164 MS_EXCEPTION_IF_NULL(out_conf);
165 auto do_signature = prim_->cast_ptr<prim::DoSignaturePrimitive>();
166 MS_EXCEPTION_IF_NULL(do_signature);
167 auto &func = do_signature->function();
168 MS_EXCEPTION_IF_NULL(func);
169
170 AbstractBasePtrList args_abs_list;
171 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
172 [](const ConfigPtr &config) -> AbstractBasePtr {
173 MS_EXCEPTION_IF_NULL(config);
174 const auto &eval_result = config->ObtainEvalResult();
175 MS_EXCEPTION_IF_NULL(eval_result);
176 return eval_result->abstract();
177 });
178 if (func->isa<Primitive>()) {
179 auto do_signature_func = func->cast<PrimitivePtr>();
180 if (do_signature_func->name() == kIsInstanceOpName) {
181 // Handle for DDE.
182 for (size_t i = 0; i < args_abs_list.size(); ++i) {
183 MS_EXCEPTION_IF_NULL(args_abs_list[i]);
184 if (args_abs_list[i]->isa<abstract::AbstractSequence>()) {
185 MS_LOG(DEBUG) << "Primitive \'IsInstance\' is consuming tuple/list arguments[" << i
186 << "]: " << args_abs_list[i]->ToString();
187 SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
188 }
189 }
190 }
191 // Do undetermined infer firstly.
192 if (prims_to_skip_undetermined_infer.find(do_signature_func->name()) == prims_to_skip_undetermined_infer.end()) {
193 auto res_abstract = EvalUndeterminedArgs(args_abs_list);
194 if (res_abstract != nullptr) {
195 MS_LOG(DEBUG) << "DoSignatureEvaluator eval Undetermined for " << do_signature_func->name()
196 << ", res_abstract: " << res_abstract->ToString();
197 return res_abstract;
198 }
199 }
200 }
201
202 CNodePtr new_cnode = nullptr;
203 ScopePtr scope = out_conf->node()->scope();
204 ScopeGuard scope_guard(scope);
205 if (bound_node() != nullptr) {
206 TraceGuard trace_guard(std::make_shared<TraceDoSignature>(bound_node()->debug_info()));
207 new_cnode = GenerateNewNodeBySignatures(func, args_abs_list, engine, out_conf);
208 } else {
209 new_cnode = GenerateNewNodeBySignatures(func, args_abs_list, engine, out_conf);
210 }
211 // Update new CNode info.
212 auto out_cnode = dyn_cast<CNode>(out_conf->node());
213 MS_EXCEPTION_IF_NULL(out_cnode);
214 new_cnode->CloneCNodeInfo(out_cnode);
215
216 // Do forward with old config and new config.
217 AnfNodeConfigPtr new_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
218 return engine->ForwardConfig(out_conf, new_conf);
219 }
220
// Builds the argument abstracts used to specialize the graph handled by UnpackGraph.
//
// `args_abs_list[0]` is the function graph to unpack and is always skipped. An empty list yields an empty
// result.
//   - need_unpack == false: the remaining abstracts are returned as they are.
//   - need_unpack == true: each remaining abstract must be an AbstractTuple, whose elements are appended in
//     order, or an AbstractDictionary, whose items are appended as AbstractKeywordArg(key string, value).
//     Any other abstract raises an internal exception.
static AbstractBasePtrList GetUnpackGraphSpecArgsList(const AbstractBasePtrList &args_abs_list, bool need_unpack) {
  // `begin() + 1` below is only valid for a non-empty list.
  if (args_abs_list.empty()) {
    return {};
  }
  if (!need_unpack) {
    // arg[0] is the func graph to unpack, ignore it
    return AbstractBasePtrList(args_abs_list.begin() + 1, args_abs_list.end());
  }

  AbstractBasePtrList graph_specialize_args;
  // arg[0] is the func graph to unpack, ignore it
  for (size_t index = 1; index < args_abs_list.size(); index++) {
    const auto &arg_abs = args_abs_list[index];
    MS_EXCEPTION_IF_NULL(arg_abs);
    if (arg_abs->isa<AbstractTuple>()) {
      auto arg_tuple = arg_abs->cast_ptr<AbstractTuple>();
      MS_EXCEPTION_IF_NULL(arg_tuple);
      const auto &tuple_elems = arg_tuple->elements();
      (void)graph_specialize_args.insert(graph_specialize_args.end(), tuple_elems.cbegin(), tuple_elems.cend());
    } else if (arg_abs->isa<AbstractDictionary>()) {
      auto arg_dict = arg_abs->cast_ptr<AbstractDictionary>();
      MS_EXCEPTION_IF_NULL(arg_dict);
      const auto &dict_elems = arg_dict->elements();
      (void)std::transform(dict_elems.cbegin(), dict_elems.cend(), std::back_inserter(graph_specialize_args),
                           [](const AbstractElementPair &item) {
                             MS_EXCEPTION_IF_NULL(item.first);
                             // Dict_elems's first element represents parameter names, which should be string type.
                             return std::make_shared<AbstractKeywordArg>(
                               GetValue<std::string>(item.first->BuildValue()), item.second);
                           });
    } else {
      MS_LOG(INTERNAL_EXCEPTION) << "UnpackGraph require args should be tuple or dict, but got "
                                 << arg_abs->ToString();
    }
  }
  return graph_specialize_args;
}
254
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)255 EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
256 const AnfNodeConfigPtr &out_conf) {
257 MS_EXCEPTION_IF_NULL(engine);
258 MS_EXCEPTION_IF_NULL(out_conf);
259 MS_EXCEPTION_IF_NULL(out_conf->node());
260 if (!out_conf->node()->isa<CNode>()) {
261 MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
262 }
263 MS_EXCEPTION_IF_NULL(prim_);
264 auto unpack_graph = prim_->cast_ptr<prim::UnpackGraphPrimitive>();
265 MS_EXCEPTION_IF_NULL(unpack_graph);
266 auto out_cnode = out_conf->node()->cast_ptr<CNode>();
267 MS_EXCEPTION_IF_NULL(out_cnode);
268 if (out_cnode->empty() || (out_cnode->size() - 1) != args_conf_list.size()) {
269 MS_LOG(EXCEPTION) << "UnpackGraphPrimitive"
270 << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
271 << ", inputs size " << out_cnode->size();
272 }
273 AbstractBasePtrList args_abs_list;
274 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
275 [](const ConfigPtr &ref) -> AbstractBasePtr {
276 MS_EXCEPTION_IF_NULL(ref);
277 const auto &eval_result = ref->ObtainEvalResult();
278 MS_EXCEPTION_IF_NULL(eval_result);
279 return eval_result->abstract();
280 });
281 // Get the forward graph
282 if (args_abs_list.empty()) {
283 MS_LOG(INTERNAL_EXCEPTION) << "args_abs_list can't be empty.";
284 }
285 MS_EXCEPTION_IF_NULL(args_abs_list[0]);
286 auto fn = args_abs_list[0]->cast_ptr<AbstractFunction>();
287 if (fn == nullptr) {
288 MS_LOG(INTERNAL_EXCEPTION) << "UnpackGraphPrimitive arg0 must be AbstractFunction, but "
289 << args_abs_list[0]->ToString();
290 }
291 AbstractBasePtrList graph_specialize_args_without_sens;
292 FuncGraphAbstractClosure *real_fn = nullptr;
293 // If it's Partial closure, fetch the func graph from it.
294 const auto &partial_fn_abs = fn->cast_ptr<PartialAbstractClosure>();
295 if (partial_fn_abs != nullptr) {
296 const auto &partial_fn = partial_fn_abs->fn();
297 MS_EXCEPTION_IF_NULL(partial_fn);
298 real_fn = partial_fn->cast_ptr<FuncGraphAbstractClosure>();
299 } else {
300 real_fn = fn->cast_ptr<FuncGraphAbstractClosure>();
301 }
302 MS_EXCEPTION_IF_NULL(real_fn);
303 FuncGraphPtr forward_graph = real_fn->func_graph();
304 MS_EXCEPTION_IF_NULL(forward_graph);
305 AbstractBasePtrList graph_specialize_args =
306 GetUnpackGraphSpecArgsList(args_abs_list, unpack_graph->need_unpack_args());
307 if (unpack_graph->with_sens_in_args() && graph_specialize_args.empty()) {
308 MS_EXCEPTION(ValueError) << "Grad with sens, but the sens is not provided.";
309 }
310 // If it's Partial closure, copy the arg list in advance.
311 if (partial_fn_abs != nullptr) {
312 (void)std::copy(partial_fn_abs->args().begin(), partial_fn_abs->args().end(),
313 std::back_inserter(graph_specialize_args_without_sens));
314 }
315 (void)std::transform(graph_specialize_args.begin(),
316 graph_specialize_args.end() - (unpack_graph->with_sens_in_args() ? 1 : 0),
317 std::back_inserter(graph_specialize_args_without_sens), [](AbstractBasePtr abs) { return abs; });
318 MS_LOG(DEBUG) << "forward_graph: " << forward_graph->ToString()
319 << ", graph_specialize_args_without_sens size: " << graph_specialize_args_without_sens.size();
320 auto new_forward_graph = forward_graph->GenerateFuncGraph(graph_specialize_args_without_sens);
321 MS_EXCEPTION_IF_NULL(engine->func_graph_manager());
322 engine->func_graph_manager()->AddFuncGraph(new_forward_graph);
323 ScopePtr scope = kDefaultScope;
324 if (out_conf != nullptr) {
325 scope = out_conf->node()->scope();
326 }
327 ScopeGuard scope_guard(scope);
328 AnfNodePtr new_node = NewValueNode(new_forward_graph);
329 AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
330 return engine->ForwardConfig(out_conf, fn_conf);
331 }
332
// Recursively builds the node that applies the mixed-precision cast to `source_node`, driven by its abstract.
//
//   - AbstractTensor whose element type is Float or BFloat: `cast(source_node, target_type)` is inserted after
//     `source_node`; other tensors are returned unchanged.
//   - AbstractSequence: every element is extracted (ListGetItem for lists, TupleGetItem otherwise), processed
//     recursively, and the results are packed with MakeTuple.
//   - AbstractDictionary: every key and every extracted value is processed recursively, and the results are
//     packed with MakeDict over two MakeTuple nodes.
//   - AbstractKeywordArg: the value is extracted, processed recursively, and wrapped again with the same key.
//   - Anything else: `source_node` is returned unchanged.
// Raises (MS_EXCEPTION_IF_NULL) when `node_type` or `func_graph` is null.
AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const AbstractBasePtr &node_type,
                                    const AnfNodePtr &target_type, const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(node_type);
  MS_EXCEPTION_IF_NULL(func_graph);

  if (node_type->isa<AbstractTensor>()) {
    auto tensor_abs = node_type->cast_ptr<AbstractTensor>();
    MS_EXCEPTION_IF_NULL(tensor_abs->element());
    MS_EXCEPTION_IF_NULL(tensor_abs->element()->BuildType());
    const bool is_floating =
      tensor_abs->element()->BuildType()->isa<Float>() || tensor_abs->element()->BuildType()->isa<BFloat>();
    if (!is_floating) {
      return source_node;
    }
    auto cast_op = prim::GetPythonOps("cast", "mindspore.ops.functional");
    MS_EXCEPTION_IF_NULL(cast_op);
    return func_graph->NewCNodeAfter(source_node, {NewValueNode(cast_op), source_node, target_type});
  }

  if (node_type->isa<AbstractSequence>()) {
    auto sequence_abs = node_type->cast_ptr<AbstractSequence>();
    const bool is_list = node_type->isa<AbstractList>();
    std::vector<AnfNodePtr> packed_inputs;
    (void)packed_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    int64_t position = 0;
    for (const auto &element_abs : sequence_abs->elements()) {
      AnfNodePtr element_node = nullptr;
      if (is_list) {
        element_node =
          func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), source_node, NewValueNode(position)});
      } else {
        element_node =
          func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(position)});
      }
      (void)packed_inputs.emplace_back(MixedPrecisionCastHelper(element_node, element_abs, target_type, func_graph));
      ++position;
    }
    return func_graph->NewCNode(packed_inputs);
  }

  if (node_type->isa<AbstractDictionary>()) {
    auto dict_abs = node_type->cast_ptr<AbstractDictionary>();
    std::vector<AnfNodePtr> key_tuple_inputs;
    std::vector<AnfNodePtr> value_tuple_inputs;
    (void)key_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    (void)value_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    for (const auto &entry : dict_abs->elements()) {
      MS_EXCEPTION_IF_NULL(entry.first);
      auto key_value = entry.first->BuildValue();
      MS_EXCEPTION_IF_NULL(key_value);
      AnfNodePtr raw_key_node = NewValueNode(key_value);
      AnfNodePtr raw_value_node =
        func_graph->NewCNode({NewValueNode(prim::kPrimDictGetItem), source_node, NewValueNode(key_value)});
      AnfNodePtr cast_key_node = MixedPrecisionCastHelper(raw_key_node, entry.first, target_type, func_graph);
      AnfNodePtr cast_value_node = MixedPrecisionCastHelper(raw_value_node, entry.second, target_type, func_graph);
      (void)key_tuple_inputs.emplace_back(cast_key_node);
      (void)value_tuple_inputs.emplace_back(cast_value_node);
    }
    // Keys tuple is created before values tuple, as the braced list evaluates left to right.
    return func_graph->NewCNode({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNode(std::move(key_tuple_inputs)),
                                 func_graph->NewCNode(std::move(value_tuple_inputs))});
  }

  if (node_type->isa<AbstractKeywordArg>()) {
    auto kwarg_abs = node_type->cast_ptr<AbstractKeywordArg>();
    std::string kwarg_key = kwarg_abs->get_key();
    AnfNodePtr extracted_value =
      func_graph->NewCNode({NewValueNode(prim::kPrimExtractKeywordArg), NewValueNode(kwarg_key), source_node});
    AnfNodePtr cast_value = MixedPrecisionCastHelper(extracted_value, kwarg_abs->get_arg(), target_type, func_graph);
    return func_graph->NewCNode({NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(kwarg_key), cast_value});
  }

  return source_node;
}
397
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)398 EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
399 const AnfNodeConfigPtr &out_conf) {
400 MS_EXCEPTION_IF_NULL(engine);
401 AbstractBasePtrList args_abs_list;
402 MS_EXCEPTION_IF_NULL(out_conf);
403 if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
404 MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
405 }
406 auto out_cnode = out_conf->node()->cast<CNodePtr>();
407 MS_EXCEPTION_IF_NULL(out_cnode);
408 if (out_cnode->empty() || (out_cnode->size() - 1) != args_conf_list.size()) {
409 MS_LOG(EXCEPTION) << "MixedPrecisionCast"
410 << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
411 << ", inputs size " << out_cnode->size();
412 }
413 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
414 [](const ConfigPtr &ref) -> AbstractBasePtr {
415 MS_EXCEPTION_IF_NULL(ref);
416 const auto &eval_result = ref->ObtainEvalResult();
417 MS_EXCEPTION_IF_NULL(eval_result);
418 return eval_result->abstract();
419 });
420
421 ScopeGuard scope_guard(out_conf->node()->scope());
422 TraceGuard trace_guard(std::make_shared<TraceMixedPrecision>(out_conf->node()->debug_info()));
423
424 FuncGraphPtr func_graph = out_cnode->func_graph();
425 constexpr size_t source_node_index = 2;
426 if (out_cnode->size() <= source_node_index) {
427 MS_LOG(EXCEPTION) << "Input size: " << out_cnode->size() << " should bigger than 2.";
428 }
429
430 AnfNodePtr new_node =
431 MixedPrecisionCastHelper(out_cnode->input(source_node_index), args_abs_list[1], out_cnode->input(1), func_graph);
432 AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
433
434 if (new_node->isa<CNode>()) {
435 auto new_cnode = new_node->cast_ptr<CNode>();
436 new_cnode->CloneCNodeInfo(out_cnode);
437 }
438 return engine->ForwardConfig(out_conf, fn_conf);
439 }
440
441 namespace {
CheckTensorCondValid(const AbstractBasePtr & cond)442 void CheckTensorCondValid(const AbstractBasePtr &cond) {
443 // Tensor condition must be one element or dynamic shape.
444 auto base_shape = cond->BuildShape();
445 MS_EXCEPTION_IF_NULL(base_shape);
446 ShapeVector cond_shape = base_shape->cast<ShapePtr>()->shape();
447 if (cond_shape.empty()) {
448 return;
449 }
450 constexpr auto num_one = 1;
451 for (size_t i = 0; i < cond_shape.size(); i++) {
452 if (cond_shape[i] != num_one && cond_shape[i] != Shape::kShapeDimAny && cond_shape[i] != Shape::kShapeRankAny) {
453 MS_LOG(ERROR) << "The condition value of control flow can be a tensor with one element, "
454 << "but got tensor with shape " << base_shape->ToString();
455 MS_EXCEPTION(ValueError) << "The truth value of an array with more than one element is ambiguous.";
456 }
457 }
458 }
459 } // namespace
460
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)461 EvalResultPtr SwitchEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
462 const AnfNodeConfigPtr &out_conf) {
463 MS_EXCEPTION_IF_NULL(engine);
464 AbstractBasePtrList args_abs_list;
465 MS_EXCEPTION_IF_NULL(out_conf);
466 MS_EXCEPTION_IF_NULL(out_conf->node());
467 if (!out_conf->node()->isa<CNode>()) {
468 MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
469 }
470 auto out_cnode = out_conf->node()->cast<CNodePtr>();
471 MS_EXCEPTION_IF_NULL(out_cnode);
472 if (out_cnode->empty() || (out_cnode->size() - 1) != args_conf_list.size()) {
473 MS_LOG(EXCEPTION) << "For 'Switch',"
474 << " the args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
475 << ", inputs size " << out_cnode->size();
476 }
477
478 // Inputs: condition, true branch, false branch
479 constexpr auto switch_input_size = 3;
480 if (args_conf_list.size() != switch_input_size) {
481 MS_LOG(EXCEPTION) << "Switch evaluator requires 3 parameters, while the input size is " << args_abs_list.size()
482 << ".";
483 }
484
485 auto eval_func = [](const ConfigPtr &conf) -> AbstractBasePtr {
486 MS_EXCEPTION_IF_NULL(conf);
487 const auto &eval_result = conf->ObtainEvalResult();
488 MS_EXCEPTION_IF_NULL(eval_result);
489 auto abs = eval_result->abstract();
490 MS_EXCEPTION_IF_NULL(abs);
491 return abs;
492 };
493
494 auto cond_abstract = eval_func(args_conf_list[0]);
495 ValuePtr cond_value = cond_abstract->GetValueTrack();
496 MS_EXCEPTION_IF_NULL(cond_value);
497 // If the value of condition is ValueAny or the abstract of condition is AbstractTensor,
498 // keeps both true and false branch.
499 if (cond_value->isa<ValueAny>() || cond_abstract->isa<AbstractTensor>()) {
500 if (cond_abstract->isa<AbstractTensor>()) {
501 CheckTensorCondValid(cond_abstract);
502 }
503 auto true_branch = eval_func(args_conf_list[1]);
504 // Need record two func_graph
505 constexpr auto false_branch_index = 2;
506 auto false_branch = eval_func(args_conf_list[false_branch_index]);
507 SetVariableFlag(true_branch);
508 SetVariableFlag(false_branch);
509 auto res_abs = true_branch->Join(false_branch);
510 auto eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>());
511 return eval_result;
512 }
513
514 if (cond_value->isa<Scalar>()) {
515 AbstractBasePtr res_abs = nullptr;
516 if (cond_value->cast<ScalarPtr>()->IsOne()) {
517 const auto &true_branch = eval_func(args_conf_list[1]);
518 res_abs = true_branch;
519 } else {
520 constexpr auto false_branch_index = 2;
521 auto false_branch = eval_func(args_conf_list[false_branch_index]);
522 res_abs = false_branch;
523 }
524 auto eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>());
525 return eval_result;
526 }
527 MS_LOG(EXCEPTION) << "Not support this condition value: " << cond_abstract->GetValueTrack()->ToString();
528 }
529
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)530 EvalResultPtr SwitchLayerEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
531 const AnfNodeConfigPtr &out_conf) {
532 MS_EXCEPTION_IF_NULL(engine);
533 AbstractBasePtrList args_abs_list;
534 MS_EXCEPTION_IF_NULL(out_conf);
535 MS_EXCEPTION_IF_NULL(out_conf->node());
536 if (!out_conf->node()->isa<CNode>()) {
537 MS_LOG(INTERNAL_EXCEPTION) << "Node of out_conf should be CNode";
538 }
539 auto out_cnode = out_conf->node()->cast<CNodePtr>();
540 MS_EXCEPTION_IF_NULL(out_cnode);
541 if (out_cnode->empty() || (out_cnode->size() - 1) != args_conf_list.size()) {
542 MS_LOG(EXCEPTION) << "For 'SwitchLayer',"
543 << " the args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
544 << ", inputs size " << out_cnode->size();
545 }
546
547 // Inputs: condition, true branch, false branch
548 constexpr auto switch_input_size = 3;
549 if (args_conf_list.size() != switch_input_size) {
550 MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 3 parameters, while the input size is " << args_abs_list.size()
551 << ".";
552 }
553 auto eval_func = [](const ConfigPtr &conf) -> AbstractBasePtr {
554 MS_EXCEPTION_IF_NULL(conf);
555 const auto &eval_result = conf->ObtainEvalResult();
556 MS_EXCEPTION_IF_NULL(eval_result);
557 auto abs = eval_result->abstract();
558 MS_EXCEPTION_IF_NULL(abs);
559 return abs;
560 };
561 auto cond_abstract = eval_func(args_conf_list[0]);
562 ValuePtr cond_value = cond_abstract->GetValueTrack();
563 MS_EXCEPTION_IF_NULL(cond_value);
564 MS_LOG(EXCEPTION) << "Not support this condition value: " << cond_value->ToString();
565 }
566
567 namespace {
BuildPyObject(const ValuePtr & value_ptr)568 py::object BuildPyObject(const ValuePtr &value_ptr) {
569 if (value_ptr == nullptr) {
570 return py::none();
571 } else {
572 return ValueToPyData(value_ptr);
573 }
574 }
575
AbstractTupleValueToPython(const AbstractTuple * tuple_abs)576 py::object AbstractTupleValueToPython(const AbstractTuple *tuple_abs) {
577 MS_EXCEPTION_IF_NULL(tuple_abs);
578 if (tuple_abs->dynamic_len()) {
579 return py::none();
580 }
581 const auto &elements = tuple_abs->elements();
582 size_t len = elements.size();
583 py::tuple value_tuple(len);
584 for (size_t i = 0; i < len; ++i) {
585 value_tuple[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
586 }
587 return value_tuple;
588 }
589
AbstractTupleToPython(const AbstractBasePtr & abs_base,bool only_convert_value)590 py::dict AbstractTupleToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
591 auto arg_tuple = dyn_cast_ptr<AbstractTuple>(abs_base);
592 MS_EXCEPTION_IF_NULL(arg_tuple);
593 auto dic = py::dict();
594 if (only_convert_value) {
595 dic[ATTR_VALUE] = AbstractTupleValueToPython(arg_tuple);
596 return dic;
597 }
598 if (arg_tuple->dynamic_len()) {
599 dic[ATTR_VALUE] = py::none();
600 dic[ATTR_SHAPE] = ShapeVector{abstract::Shape::kShapeDimAny};
601 dic[ATTR_DTYPE] = arg_tuple->BuildType();
602 return dic;
603 }
604 size_t len = arg_tuple->size();
605 py::tuple shape_tuple(len);
606 py::tuple dtype_tuple(len);
607 py::tuple value_tuple(len);
608 std::vector<py::dict> res;
609
610 for (size_t i = 0; i < len; i++) {
611 py::dict out = ConvertAbstractToPython(arg_tuple->elements()[i]);
612 res.push_back(out);
613 shape_tuple[i] = out[ATTR_SHAPE];
614 dtype_tuple[i] = out[ATTR_DTYPE];
615 value_tuple[i] = out[ATTR_VALUE];
616 }
617 dic[ATTR_SHAPE] = shape_tuple;
618 dic[ATTR_DTYPE] = dtype_tuple;
619 dic[ATTR_VALUE] = value_tuple;
620
621 return dic;
622 }
623
AbstractDictionaryToPython(const AbstractBasePtr & abs_base)624 py::dict AbstractDictionaryToPython(const AbstractBasePtr &abs_base) {
625 auto arg_dict = dyn_cast_ptr<AbstractDictionary>(abs_base);
626 MS_EXCEPTION_IF_NULL(arg_dict);
627
628 size_t len = arg_dict->size();
629 const auto &arg_dict_elements = arg_dict->elements();
630 py::list shape_list(len);
631 py::list dtype_list(len);
632 py::dict value_dict = py::dict();
633
634 for (size_t i = 0; i < len; ++i) {
635 auto cur_attr = arg_dict_elements[i];
636 auto cur_key = cur_attr.first;
637 auto cur_value = cur_attr.second;
638
639 py::dict cur_value_out = ConvertAbstractToPython(cur_value);
640 shape_list[i] = cur_value_out[ATTR_SHAPE];
641 dtype_list[i] = cur_value_out[ATTR_DTYPE];
642 MS_EXCEPTION_IF_NULL(cur_key);
643 value_dict[ValueToPyData(cur_key->BuildValue())] = cur_value_out[ATTR_VALUE];
644 }
645
646 py::dict dic = py::dict();
647 dic[ATTR_SHAPE] = shape_list;
648 dic[ATTR_DTYPE] = dtype_list;
649 MS_EXCEPTION_IF_NULL(arg_dict->BuildValue());
650 dic[ATTR_VALUE] = value_dict;
651 return dic;
652 }
653
AbstractKWArgsToPython(const AbstractBasePtr & abs_base)654 py::object AbstractKWArgsToPython(const AbstractBasePtr &abs_base) {
655 MS_EXCEPTION_IF_NULL(abs_base);
656 auto abs_keyword_arg = abs_base->cast_ptr<abstract::AbstractKeywordArg>();
657 MS_EXCEPTION_IF_NULL(abs_keyword_arg);
658 auto args_abs = abs_keyword_arg->get_arg();
659 auto args_obj = BuildPyObject(args_abs->BuildValue());
660 // if the args is none but the type is not none means the input is a variable.
661 if (!args_abs->isa<AbstractNone>() && py::isinstance<py::none>(args_obj)) {
662 return py::none();
663 }
664 return BuildPyObject(abs_base->BuildValue());
665 }
666
AbstractListValueToPython(const AbstractList * list_abs)667 py::object AbstractListValueToPython(const AbstractList *list_abs) {
668 MS_EXCEPTION_IF_NULL(list_abs);
669 if (list_abs->dynamic_len()) {
670 return py::none();
671 }
672 const auto &elements = list_abs->elements();
673 size_t len = elements.size();
674 py::list value_list(len);
675 for (size_t i = 0; i < len; ++i) {
676 value_list[i] = ConvertAbstractToPython(elements[i], true)[ATTR_VALUE];
677 }
678 return value_list;
679 }
680
AbstractListToPython(const AbstractBasePtr & abs_base,bool only_convert_value)681 py::dict AbstractListToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
682 auto arg_list = dyn_cast_ptr<AbstractList>(abs_base);
683 MS_EXCEPTION_IF_NULL(arg_list);
684 auto dic = py::dict();
685 if (only_convert_value) {
686 dic[ATTR_VALUE] = AbstractListValueToPython(arg_list);
687 return dic;
688 }
689 if (arg_list->dynamic_len()) {
690 auto elem_out = ConvertAbstractToPython(arg_list->dynamic_len_element_abs());
691 dic[ATTR_VALUE] = py::none();
692 dic[ATTR_SHAPE] = elem_out[ATTR_SHAPE];
693 dic[ATTR_DTYPE] = elem_out[ATTR_DTYPE];
694 return dic;
695 }
696 size_t len = arg_list->size();
697 py::list shape_list(len);
698 py::list dtype_list(len);
699 py::list value_list(len);
700 std::vector<py::dict> res;
701
702 for (size_t i = 0; i < len; i++) {
703 py::dict out = ConvertAbstractToPython(arg_list->elements()[i]);
704 res.push_back(out);
705 shape_list[i] = out[ATTR_SHAPE];
706 dtype_list[i] = out[ATTR_DTYPE];
707 value_list[i] = out[ATTR_VALUE];
708 }
709
710 dic[ATTR_SHAPE] = shape_list;
711 dic[ATTR_DTYPE] = dtype_list;
712 dic[ATTR_VALUE] = value_list;
713 return dic;
714 }
715
ConvertAbstractTensorToPython(const AbstractBasePtr & abs_base,bool only_convert_value,py::dict * dic)716 void ConvertAbstractTensorToPython(const AbstractBasePtr &abs_base, bool only_convert_value, py::dict *dic) {
717 auto arg_tensor = dyn_cast_ptr<AbstractTensor>(abs_base);
718 MS_EXCEPTION_IF_NULL(dic);
719 MS_EXCEPTION_IF_NULL(arg_tensor);
720 if (only_convert_value) {
721 (*dic)[ATTR_VALUE] = BuildPyObject(arg_tensor->BuildValue());
722 return;
723 }
724 MS_EXCEPTION_IF_NULL(arg_tensor->shape());
725 (*dic)[ATTR_SHAPE] = arg_tensor->shape()->shape();
726
727 (*dic)[ATTR_DTYPE] = arg_tensor->BuildType();
728 (*dic)[ATTR_VALUE] = BuildPyObject(arg_tensor->BuildValue());
729 }
730 namespace {
GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr & prim_abs)731 py::object GetPyObjForPrimitiveAbstract(const PrimitiveAbstractClosurePtr &prim_abs) {
732 MS_EXCEPTION_IF_NULL(prim_abs);
733 auto prim = prim_abs->BuildValue();
734 if (prim == nullptr) {
735 return py::none();
736 }
737 if (prim->isa<prim::DoSignaturePrimitive>()) {
738 auto do_sig_prim = prim->cast_ptr<prim::DoSignaturePrimitive>();
739 auto value = do_sig_prim->function();
740 MS_EXCEPTION_IF_NULL(value);
741 if (!value->isa<PrimitivePy>()) {
742 return py::none();
743 }
744 auto prim_py = value->cast_ptr<PrimitivePy>();
745 return prim_py->GetPyObj();
746 }
747 if (prim->isa<PrimitivePy>()) {
748 auto prim_py = prim->cast_ptr<PrimitivePy>();
749 return prim_py->GetPyObj();
750 }
751 return py::none();
752 }
753 } // namespace
754
ConvertAbstractFunctionToPython(const AbstractBasePtr & abs_base,py::dict * dic)755 void ConvertAbstractFunctionToPython(const AbstractBasePtr &abs_base, py::dict *dic) {
756 MS_EXCEPTION_IF_NULL(dic);
757 MS_EXCEPTION_IF_NULL(abs_base);
758 (*dic)[ATTR_SHAPE] = py::none();
759 (*dic)[ATTR_DTYPE] = abs_base->BuildType();
760 (*dic)[ATTR_VALUE] = py::none();
761 if (abs_base->isa<PartialAbstractClosure>()) {
762 auto partial_abs = abs_base->cast<PartialAbstractClosurePtr>();
763 AbstractBasePtrList args = partial_abs->args();
764 if (!args.empty()) {
765 auto value = args[0]->BuildValue();
766 MS_EXCEPTION_IF_NULL(value);
767 auto value_obj = value->cast_ptr<parse::ClassType>();
768 if (value_obj != nullptr) {
769 (*dic)[ATTR_DTYPE] = std::make_shared<TypeType>();
770 (*dic)[ATTR_VALUE] = value_obj->obj();
771 }
772 }
773 }
774 if (abs_base->isa<PrimitiveAbstractClosure>()) {
775 (*dic)[ATTR_VALUE] = GetPyObjForPrimitiveAbstract(abs_base->cast<PrimitiveAbstractClosurePtr>());
776 }
777 }
778
CheckType(const TypePtr & expected_type,const TypePtr & x)779 bool CheckType(const TypePtr &expected_type, const TypePtr &x) {
780 // As x and predicate both are mindspore type statically, here we only to judge whether
781 // x is predicate or is a subclass of predicate.
782 return IsIdentidityOrSubclass(x, expected_type);
783 }
784
785 // Join all types in args_type_list;
TypeJoin(const TypePtrList & args_type_list)786 TypePtr TypeJoin(const TypePtrList &args_type_list) {
787 if (args_type_list.empty()) {
788 MS_LOG(INTERNAL_EXCEPTION) << "args_type_list is empty";
789 }
790
791 TypePtr type_tmp = args_type_list[0];
792 for (std::size_t i = 1; i < args_type_list.size(); i++) {
793 type_tmp = abstract::TypeJoin(type_tmp, args_type_list[i]);
794 }
795 return type_tmp;
796 }
797
// Verifies that every type in args_type_list satisfies predicate (see CheckType) and
// returns the join of all of them.
//
// Raises if predicate or any visited entry is null, and raises an internal exception
// naming the first entry that does not match predicate.
TypePtr CheckTypeList(const TypePtr &predicate, const TypePtrList &args_type_list) {
  MS_EXCEPTION_IF_NULL(predicate);
  const auto mismatch =
    std::find_if(args_type_list.cbegin(), args_type_list.cend(), [&predicate](const TypePtr &arg_type) {
      MS_EXCEPTION_IF_NULL(arg_type);
      return !CheckType(predicate, arg_type);
    });
  if (mismatch != args_type_list.cend()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The expected is " << predicate->ToString() << ", not " << (*mismatch)->ToString();
  }
  return TypeJoin(args_type_list);
}
808 } // namespace
809
UnknownAbstract(const AbstractBasePtr & abs_base)810 void UnknownAbstract(const AbstractBasePtr &abs_base) {
811 auto value = abs_base->BuildValue();
812 MS_EXCEPTION_IF_NULL(value);
813 if ((*value == *kValueAny)) {
814 auto value_desc = abs_base->value_desc();
815 MS_EXCEPTION(TypeError) << "Unsupported parameter " << (value_desc.empty() ? "type" : value_desc)
816 << " for python primitive." << abs_base->ToString();
817 }
818 MS_EXCEPTION(TypeError) << "Unsupported parameter type for python primitive, the parameter value is "
819 << value->ToString();
820 }
821
ConvertAbstractToPython(const AbstractBasePtr & abs_base,bool only_convert_value)822 py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base, bool only_convert_value) {
823 MS_EXCEPTION_IF_NULL(abs_base);
824 auto dic = py::dict();
825 if (abs_base->isa<AbstractTensor>()) {
826 ConvertAbstractTensorToPython(abs_base, only_convert_value, &dic);
827 } else if (abs_base->isa<AbstractScalar>() || abs_base->isa<AbstractType>()) {
828 ShapeVector shape;
829 dic[ATTR_SHAPE] = shape;
830 dic[ATTR_DTYPE] = abs_base->BuildType();
831 dic[ATTR_VALUE] = BuildPyObject(abs_base->BuildValue());
832 } else if (abs_base->isa<AbstractTuple>()) {
833 return AbstractTupleToPython(abs_base, only_convert_value);
834 } else if (abs_base->isa<AbstractList>()) {
835 return AbstractListToPython(abs_base, only_convert_value);
836 } else if (abs_base->isa<AbstractDictionary>()) {
837 return AbstractDictionaryToPython(abs_base);
838 } else if (abs_base->isa<AbstractSlice>()) {
839 auto arg_slice = dyn_cast_ptr<AbstractSlice>(abs_base);
840 ShapeVector shape;
841 dic[ATTR_SHAPE] = shape;
842 dic[ATTR_DTYPE] = arg_slice->BuildType();
843 dic[ATTR_VALUE] = BuildPyObject(arg_slice->BuildValue());
844 } else if (abs_base->isa<AbstractRowTensor>()) {
845 auto arg = dyn_cast_ptr<AbstractRowTensor>(abs_base);
846 MS_EXCEPTION_IF_NULL(arg->shape());
847 dic[ATTR_SHAPE] = arg->shape()->shape();
848 dic[ATTR_DTYPE] = arg->BuildType();
849 dic[ATTR_VALUE] = BuildPyObject(arg->BuildValue());
850 } else if (abs_base->isa<AbstractCOOTensor>()) {
851 auto arg = dyn_cast_ptr<AbstractCOOTensor>(abs_base);
852 MS_EXCEPTION_IF_NULL(arg->shape());
853 AbstractBasePtrList sparse_shape = arg->shape()->elements();
854 ShapeVector sparse_shape_vector;
855 (void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector),
856 [](const AbstractBasePtr &e) -> int64_t {
857 MS_EXCEPTION_IF_NULL(e);
858 MS_EXCEPTION_IF_NULL(e->cast_ptr<AbstractScalar>());
859 ValuePtr value = e->cast_ptr<AbstractScalar>()->BuildValue();
860 return GetValue<int64_t>(value);
861 });
862 dic[ATTR_SHAPE] = sparse_shape_vector;
863 dic[ATTR_DTYPE] = arg->BuildType();
864 dic[ATTR_VALUE] = BuildPyObject(arg->BuildValue());
865 } else if (abs_base->isa<AbstractCSRTensor>()) {
866 auto arg = dyn_cast_ptr<AbstractCSRTensor>(abs_base);
867 MS_EXCEPTION_IF_NULL(arg->shape());
868 AbstractBasePtrList sparse_shape = arg->shape()->elements();
869 ShapeVector sparse_shape_vector;
870 (void)std::transform(sparse_shape.begin(), sparse_shape.end(), std::back_inserter(sparse_shape_vector),
871 [](const AbstractBasePtr &e) -> int64_t {
872 MS_EXCEPTION_IF_NULL(e);
873 MS_EXCEPTION_IF_NULL(e->cast_ptr<AbstractScalar>());
874 ValuePtr value = e->cast_ptr<AbstractScalar>()->BuildValue();
875 return GetValue<int64_t>(value);
876 });
877 dic[ATTR_SHAPE] = sparse_shape_vector;
878 dic[ATTR_DTYPE] = arg->BuildType();
879 dic[ATTR_VALUE] = BuildPyObject(arg->BuildValue());
880 } else if (abs_base->isa<AbstractEllipsis>()) {
881 dic[ATTR_SHAPE] = py::none();
882 dic[ATTR_DTYPE] = py::ellipsis();
883 dic[ATTR_VALUE] = py::ellipsis();
884 } else if (abs_base->isa<AbstractNone>()) {
885 dic[ATTR_SHAPE] = py::none();
886 dic[ATTR_DTYPE] = py::none();
887 dic[ATTR_VALUE] = py::none();
888 } else if (abs_base->isa<AbstractFunction>()) {
889 ConvertAbstractFunctionToPython(abs_base, &dic);
890 } else if (abs_base->isa<AbstractClass>()) {
891 auto arg_class = dyn_cast_ptr<AbstractClass>(abs_base);
892 ShapeVector shape;
893 dic[ATTR_SHAPE] = shape;
894 dic[ATTR_DTYPE] = arg_class->BuildType();
895 dic[ATTR_VALUE] = BuildPyObject(arg_class->BuildValue());
896 } else if (abs_base->isa<AbstractUndetermined>()) {
897 auto arg = dyn_cast_ptr<AbstractUndetermined>(abs_base);
898 dic[ATTR_SHAPE] = py::none();
899 dic[ATTR_DTYPE] = arg->BuildType();
900 dic[ATTR_VALUE] = py::none();
901 } else if (abs_base->isa<AbstractMonad>()) {
902 dic[ATTR_SHAPE] = py::none();
903 dic[ATTR_DTYPE] = abs_base->BuildType();
904 dic[ATTR_VALUE] = py::none();
905 } else if (abs_base->isa<AbstractKeywordArg>()) {
906 dic[ATTR_SHAPE] = py::none();
907 dic[ATTR_DTYPE] = abs_base->BuildType();
908 dic[ATTR_VALUE] = AbstractKWArgsToPython(abs_base);
909 } else {
910 UnknownAbstract(abs_base);
911 }
912 return dic;
913 }
914
915 namespace {
CheckCustomPrimOutputInferResult(const PrimitivePtr & prim,const AbstractBasePtr & res_spec)916 void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
917 MS_EXCEPTION_IF_NULL(prim);
918 MS_EXCEPTION_IF_NULL(res_spec);
919 const string kOutputNum = "output_num";
920 if (prim->IsCustomPrim()) {
921 // Raise error if output_num is not match the infer result.
922 auto output_num_value = prim->GetAttr(kOutputNum);
923 if (output_num_value == nullptr) {
924 MS_LOG(DEBUG) << "The output num may no need to check";
925 return;
926 }
927 int64_t output_num = GetValue<int64_t>(output_num_value);
928 if (res_spec->isa<AbstractTensor>() && output_num != 1) {
929 MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
930 << "]'s attribute[output_num]: " << output_num << ", not matches the infer result "
931 << res_spec->ToString();
932 } else if (res_spec->isa<AbstractTuple>() &&
933 (res_spec->cast_ptr<AbstractTuple>()->size() != LongToSize(output_num))) {
934 MS_LOG(EXCEPTION) << "Custom operator primitive[" << prim->ToString()
935 << "]'s attribute[output_num]: " << output_num << ", not matches the infer result "
936 << res_spec->ToString();
937 }
938 }
939 }
940
IsMonadType(const py::object & type_obj)941 static bool IsMonadType(const py::object &type_obj) {
942 if (py::isinstance<Type>(type_obj)) {
943 auto type = type_obj.cast<Type *>();
944 return type->isa<MonadType>();
945 }
946 return false;
947 }
948
ToMonadAbstract(const py::object & type_obj)949 AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
950 if (py::isinstance<Type>(type_obj)) {
951 auto type = type_obj.cast<Type *>();
952 if (!type->isa<MonadType>()) {
953 MS_LOG(INTERNAL_EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
954 }
955 return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
956 }
957 MS_LOG(INTERNAL_EXCEPTION) << "Not a type object: " << py::str(type_obj);
958 }
959
// Extracts the index-th entry of a sequence-shaped infer result.
//
// output is cast to a dict whose ATTR_SHAPE and ATTR_DTYPE entries are cast to tuples;
// the returned dict holds the index-th dtype and shape and a None value.
py::object GetPyAbsItemOfTupleOut(const py::object &output, const size_t index) {
  auto out_dict = output.cast<py::dict>();
  const py::tuple shapes = out_dict[ATTR_SHAPE].cast<py::tuple>();
  const py::tuple dtypes = out_dict[ATTR_DTYPE].cast<py::tuple>();
  py::dict item;
  item[ATTR_DTYPE] = dtypes[index];
  item[ATTR_SHAPE] = shapes[index];
  item[ATTR_VALUE] = py::none();
  return item;
}
972
MakePyInferRes2AbstractTensor(const py::object & shape_obj,const py::object & type_obj)973 AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj) {
974 auto res_vec = shape_obj.cast<ShapeVector>();
975 auto res_dtype = type_obj.cast<TypePtr>();
976
977 auto res_shape = std::make_shared<abstract::Shape>(res_vec);
978 AbstractBasePtr tensor = MakeAbstractTensor(res_shape, res_dtype);
979 return tensor;
980 }
981
MakePyInferRes2Abstract(const py::object & output)982 AbstractBasePtr MakePyInferRes2Abstract(const py::object &output) {
983 auto out_dict = output.cast<py::dict>();
984 auto type_obj = out_dict[ATTR_DTYPE];
985 auto shape_obj = out_dict[ATTR_SHAPE];
986 if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
987 auto res_vec = shape_obj.cast<ShapeVector>();
988 auto res_dtype = type_obj.cast<TypePtr>();
989 MS_EXCEPTION_IF_NULL(res_dtype);
990 // if the size of shape list is empty, return an scalar abstract
991 if (res_vec.empty() && (!res_dtype->isa<TensorType>())) {
992 abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kValueAny, res_dtype);
993 return abs_scalar;
994 }
995 return MakePyInferRes2AbstractTensor(shape_obj, type_obj);
996 } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
997 auto typeid_tuple = type_obj.cast<py::tuple>();
998 AbstractBasePtrList ptr_list;
999 for (size_t it = 0; it < typeid_tuple.size(); ++it) {
1000 auto output_it = GetPyAbsItemOfTupleOut(output, it);
1001 auto tensor_it = MakePyInferRes2Abstract(output_it);
1002 ptr_list.push_back(tensor_it);
1003 }
1004 auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
1005 return tuple;
1006 } else if (py::isinstance<py::list>(shape_obj) && py::isinstance<py::list>(type_obj)) {
1007 auto typeid_list = type_obj.cast<py::list>();
1008 AbstractBasePtrList ptr_list;
1009 for (size_t it = 0; it < typeid_list.size(); ++it) {
1010 auto output_it = GetPyAbsItemOfTupleOut(output, it);
1011 auto tensor_it = MakePyInferRes2Abstract(output_it);
1012 ptr_list.push_back(tensor_it);
1013 }
1014 auto list = std::make_shared<abstract::AbstractList>(ptr_list);
1015 return list;
1016 } else if (shape_obj.is_none() && type_obj.is_none()) {
1017 // AbstractNone indicates there is no output for this CNode node.
1018 auto abstract_none = std::make_shared<abstract::AbstractNone>();
1019 return abstract_none;
1020 } else if (IsMonadType(type_obj)) {
1021 // Return monad abstract if it is monad type.
1022 return ToMonadAbstract(type_obj);
1023 } else {
1024 MS_LOG(INTERNAL_EXCEPTION) << "Python evaluator return invalid shape or type. " << py::str(type_obj);
1025 }
1026 }
1027 } // namespace
PreparePyInputs(const AbstractBasePtrList & args)1028 py::tuple PreparePyInputs(const AbstractBasePtrList &args) {
1029 // The monad parameter is defined at the end of the parameter and needs to be ignored
1030 std::size_t args_size = args.size() - GetAbstractMonadNum(args);
1031 py::tuple py_args(args_size);
1032 for (size_t i = 0; i < args_size; i++) {
1033 py_args[i] = ConvertAbstractToPython(args[i]);
1034 }
1035 return py_args;
1036 }
1037
PyInferRes2Abstract(const PrimitivePyPtr & prim_py,const py::dict & output)1038 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
1039 // Convert to AbstractValue based on type and shape
1040 if (output[ATTR_VALUE].is_none()) {
1041 return MakePyInferRes2Abstract(output);
1042 }
1043
1044 // Convert pyobject to Value, then to AbstractValue
1045 auto out_dtype = output[ATTR_DTYPE];
1046 TypePtr dtype = py::isinstance<Type>(out_dtype) ? out_dtype.cast<TypePtr>() : nullptr;
1047 ValuePtr converted_ret = nullptr;
1048 bool converted = parse::ConvertData(output[ATTR_VALUE], &converted_ret, false, dtype);
1049 if (!converted) {
1050 MS_LOG(INTERNAL_EXCEPTION) << "Convert data failed";
1051 }
1052 auto res_spec = FromValue(converted_ret);
1053 MS_EXCEPTION_IF_NULL(res_spec);
1054 if (res_spec->isa<AbstractTensor>()) {
1055 // Replace to tensor constant node in specialize
1056 auto res_tensor = res_spec->cast<AbstractTensorPtr>();
1057 res_tensor->set_value(converted_ret);
1058 }
1059 CheckCustomPrimOutputInferResult(prim_py, res_spec);
1060 return res_spec;
1061 }
1062
RunPyInferValue(const AnalysisEnginePtr &,const AbstractBasePtr & abs_base,const AbstractBasePtrList & args)1063 EvalResultPtr StandardPrimEvaluator::RunPyInferValue(const AnalysisEnginePtr &, const AbstractBasePtr &abs_base,
1064 const AbstractBasePtrList &args) {
1065 auto prim_py = dyn_cast<PrimitivePy>(prim_);
1066 if (prim_py == nullptr) {
1067 MS_LOG(INTERNAL_EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
1068 }
1069 // Call checking method 'infer_value' for python primitive
1070 MS_LOG(DEBUG) << "Begin input args checking for: " << prim_py->ToString();
1071 auto py_args = PreparePyInputs(args);
1072 py::tuple py_vals(py_args.size());
1073 MS_EXCEPTION_IF_NULL(prim_);
1074 auto added_attrs = prim_->evaluate_added_attrs();
1075 for (size_t i = 0; i < py_args.size(); ++i) {
1076 py_vals[i] = py_args[i][ATTR_VALUE];
1077 }
1078 py::object py_ret = prim_py->RunInferValue(py_vals);
1079 if (py::isinstance<py::none>(py_ret)) {
1080 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1081 }
1082 // Convert pyobject to Value, then to AbstractValue
1083 ValuePtr converted_ret = nullptr;
1084 MS_EXCEPTION_IF_NULL(abs_base);
1085 TypePtr dtype = abs_base->BuildType();
1086 bool converted = parse::ConvertData(py_ret, &converted_ret, false, dtype);
1087 if (!converted) {
1088 MS_LOG(INTERNAL_EXCEPTION) << "Convert data failed";
1089 }
1090 auto res_spec = FromValue(converted_ret);
1091 MS_EXCEPTION_IF_NULL(res_spec);
1092 if (res_spec->isa<AbstractTensor>()) {
1093 // Replace to tensor constant node in specialize
1094 auto res_tensor = res_spec->cast_ptr<AbstractTensor>();
1095 res_tensor->set_value(converted_ret);
1096 }
1097 return std::make_shared<EvalResult>(res_spec, std::make_shared<AttrValueMap>(added_attrs));
1098 }
1099
1100 // Apply EvalResult from cached result for a given primitive.
ApplyCacheEvalResult(const PrimitivePtr & prim,const EvalResultPtr & result)1101 static inline EvalResultPtr ApplyCacheEvalResult(const PrimitivePtr &prim, const EvalResultPtr &result) {
1102 auto &attrs = result->attribute();
1103 if (attrs != nullptr) {
1104 prim->set_evaluate_added_attrs(*attrs);
1105 }
1106 return std::make_shared<EvalResult>(result->abstract()->Clone(), attrs);
1107 }
1108
EvalPyCheckPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)1109 EvalResultPtr StandardPrimEvaluator::EvalPyCheckPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
1110 // Try to get infer result from evaluator cache.
1111 auto eval_result = evaluator_cache_mgr_->GetValue(args);
1112 if (eval_result != nullptr) {
1113 // Evaluator cache hit.
1114 return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
1115 }
1116 // In pynative mode (engine == nullptr), it is difficult to set added_attrs to
1117 // python object by C++ code, so we disable global eval cache in pynative mode.
1118 const bool enable_global_cache = (engine != nullptr);
1119 if (enable_global_cache) {
1120 // Try to get infer result from global primitive evaluate cache.
1121 eval_result = eval_cache_->Get(prim_, args);
1122 if (eval_result != nullptr) {
1123 // Global primitive evaluate cache hit.
1124 evaluator_cache_mgr_->SetValue(args, eval_result);
1125 return ApplyCacheEvalResult(prim_, eval_result);
1126 }
1127 }
1128 // PrimitivePy is expected for EvalPyCheckPrim.
1129 auto prim_py = dyn_cast<PrimitivePy>(prim_);
1130 if (prim_py == nullptr) {
1131 MS_LOG(INTERNAL_EXCEPTION) << "The primitive with type 'kPrimTypePyCheck' should be a python primitive.";
1132 }
1133 // We should copy attributes before running check and infer,
1134 // since they may be changed during check and infer.
1135 auto input_attrs = prim_py->attrs();
1136 prim_py->BeginRecordAddAttr();
1137 auto py_args = PreparePyInputs(args);
1138 // Call checking method '__check__' for subclass of 'PrimitiveWithCheck'.
1139 prim_py->RunCheck(py_args);
1140 auto abs = eval_impl_.InferShapeAndType(nullptr, prim_py, args);
1141 MS_EXCEPTION_IF_NULL(abs);
1142 prim_py->EndRecordAddAttr();
1143 auto &added_attrs = prim_py->evaluate_added_attrs();
1144 eval_result = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>(added_attrs));
1145 if (py::hasattr(prim_py->GetPyObj(), PY_PRIM_METHOD_INFER_VALUE)) {
1146 // Call 'infer_value()' method if it is existed, for constant propagation.
1147 eval_result = RunPyInferValue(engine, eval_result->abstract(), args);
1148 }
1149 // Save infer result to caches (evaluator cache and global cache).
1150 if (enable_global_cache) {
1151 eval_cache_->Put(prim_py, std::move(input_attrs), args, eval_result);
1152 }
1153 evaluator_cache_mgr_->SetValue(args, eval_result);
1154 return eval_result;
1155 }
1156
1157 namespace {
CheckSequenceArgumentForCppPrimitive(const PrimitivePtr & prim,const AbstractBasePtrList & args)1158 void CheckSequenceArgumentForCppPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
1159 // To check tuple/list operations with a white list of Python primitive.
1160 MS_EXCEPTION_IF_NULL(prim);
1161 auto iter = prims_transparent_pass_sequence.find(prim->name());
1162 if (iter == prims_transparent_pass_sequence.end()) {
1163 // The primitive use all elements of each argument.
1164 for (size_t i = 0; i < args.size(); ++i) {
1165 MS_EXCEPTION_IF_NULL(args[i]);
1166 if (args[i]->isa<abstract::AbstractSequence>()) {
1167 MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i
1168 << "]: " << args[i]->ToString();
1169 SetSequenceElementsUseFlagsRecursively(args[i], true);
1170 }
1171 }
1172 return;
1173 }
1174
1175 // It's transparent pass primitive or using partial elements primitive.
1176 auto index_list = iter->second;
1177 if (index_list.empty()) {
1178 MS_LOG(INTERNAL_EXCEPTION) << "The primitive list should not be empty for " << prim->name();
1179 }
1180 // Ignore all arguments, no need checking if AbstractSequence.
1181 if (index_list[0] == -1) {
1182 return;
1183 }
1184 // Check the specific arguments index.
1185 for (size_t i = 0; i < args.size(); ++i) {
1186 MS_EXCEPTION_IF_NULL(args[i]);
1187 if (!args[i]->isa<abstract::AbstractSequence>()) {
1188 continue;
1189 }
1190 if (std::find(index_list.begin(), index_list.end(), i) == index_list.end()) {
1191 // For current tuple/list argument, it's not a primitive of total transparent pass or partial element use.
1192 MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming specific tuple/list arguments[" << i
1193 << "]: " << args[i]->ToString();
1194 SetSequenceElementsUseFlagsRecursively(args[i], true);
1195 }
1196 }
1197 }
1198
CheckSequenceArgumentForPythonPrimitive(const PrimitivePtr & prim,const AbstractBasePtrList & args)1199 void CheckSequenceArgumentForPythonPrimitive(const PrimitivePtr &prim, const AbstractBasePtrList &args) {
1200 MS_EXCEPTION_IF_NULL(prim);
1201 // Consider all primitive implemented python infer() real use the tuple/list arguments.
1202 for (size_t i = 0; i < args.size(); ++i) {
1203 MS_EXCEPTION_IF_NULL(args[i]);
1204 if (args[i]->isa<abstract::AbstractSequence>()) {
1205 MS_EXCEPTION_IF_NULL(args[i]);
1206 MS_LOG(DEBUG) << "Primitive \'" << prim->name() << "\' is consuming tuple/list arguments[" << i
1207 << "]: " << args[i]->ToString();
1208 SetSequenceElementsUseFlagsRecursively(args[i], true);
1209 }
1210 }
1211 }
1212
ValidateArgOptional(const AbstractBasePtr & abs_arg,const ops::OpInputArg & input_arg)1213 bool ValidateArgOptional(const AbstractBasePtr &abs_arg, const ops::OpInputArg &input_arg) {
1214 if (!input_arg.is_optional_) {
1215 return false;
1216 }
1217
1218 auto abs_type = abs_arg->BuildType();
1219 MS_EXCEPTION_IF_NULL(abs_type);
1220 return abs_type->isa<TypeNone>();
1221 }
1222 } // namespace
1223
PrimitiveFunctionEvaluator(const PrimitivePtr & prim_func)1224 PrimitiveFunctionEvaluator::PrimitiveFunctionEvaluator(const PrimitivePtr &prim_func)
1225 : TrivialPrimEvaluator("PrimitiveFunctionEvaluator"), prim_func_(prim_func) {
1226 frontend_func_impl_ = mindspore::ops::GetOpFrontendFuncImplPtr(prim_func->name());
1227 op_def_ = mindspore::ops::GetOpDef(prim_func->name());
1228 }
1229
HasAbstractUndetermined(const AbstractBasePtr & abs)1230 bool HasAbstractUndetermined(const AbstractBasePtr &abs) {
1231 if (abs->isa<AbstractSequence>()) {
1232 auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
1233 return std::any_of(abs_seq->elements().cbegin(), abs_seq->elements().cend(), HasAbstractUndetermined);
1234 }
1235 return abs->IsSameTypeId(AbstractUndetermined::kTypeId);
1236 }
1237
CheckArgsSizeAndType(const AbstractBasePtrList & abs_args)1238 void PrimitiveFunctionEvaluator::CheckArgsSizeAndType(const AbstractBasePtrList &abs_args) {
1239 auto op_args = op_def_->args_;
1240 // Ignore monad.
1241 AbstractBasePtrList real_abs_args;
1242 (void)std::copy_if(abs_args.cbegin(), abs_args.cend(), std::back_inserter(real_abs_args),
1243 [](const AbstractBasePtr &abs) {
1244 MS_EXCEPTION_IF_NULL(abs);
1245 return !abs->isa<abstract::AbstractMonad>();
1246 });
1247 // Check inputs number.
1248 if (op_args.size() != real_abs_args.size()) {
1249 MS_EXCEPTION(TypeError) << "For Operator[" << op_def_->name_ << "], the inputs number should be " << op_args.size()
1250 << " but got " << real_abs_args.size() << ".";
1251 }
1252
1253 // Check inputs type.
1254 for (size_t i = 0; i < op_args.size(); i++) {
1255 if (HasAbstractUndetermined(real_abs_args[i])) {
1256 continue;
1257 }
1258 if (!ValidateArgOptional(real_abs_args[i], op_args[i]) &&
1259 !ops::ValidateArgsType(real_abs_args[i], op_args[i].arg_dtype_)) {
1260 std::vector<std::string> op_type_list;
1261 for (const auto &op_abs : real_abs_args) {
1262 (void)op_type_list.emplace_back(op_abs->BuildType()->ToString());
1263 }
1264 MS_INTERNAL_EXCEPTION(TypeError)
1265 << "For Operator[" << op_def_->name_ << "], " << op_args[i].arg_name_ << "'s type '"
1266 << real_abs_args[i]->BuildType()->ToString() << "' does not match expected type '"
1267 << ops::EnumToString(op_args[i].arg_dtype_)
1268 << "'.\nThe reason may be: lack of definition of type cast, or incorrect type when creating the node.";
1269 }
1270 }
1271 }
1272
CheckAndInfer(const AbstractBasePtrList & args)1273 AbstractBasePtr PrimitiveFunctionEvaluator::CheckAndInfer(const AbstractBasePtrList &args) {
1274 if (op_def_ != nullptr) {
1275 (void)op_def_->func_impl_.CheckValidation(prim_func_, args);
1276 if (frontend_func_impl_ != nullptr) {
1277 auto infer_result = frontend_func_impl_->InferAbstract(prim_func_, args);
1278 if (infer_result != nullptr) {
1279 return infer_result;
1280 }
1281 }
1282
1283 auto type = op_def_->func_impl_.InferType(prim_func_, args);
1284 auto shape = op_def_->func_impl_.InferShape(prim_func_, args);
1285 return MakeAbstract(shape, type);
1286 }
1287 MS_LOG(INTERNAL_EXCEPTION) << "Find infer function failed, primitive: " << prim_func_->ToString();
1288 }
1289
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)1290 EvalResultPtr PrimitiveFunctionEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
1291 MS_EXCEPTION_IF_NULL(prim_func_);
1292 CheckArgsSizeAndType(args);
1293 // To check tuple/list operations with a white list of Python primitive.
1294 CheckSequenceArgumentForCppPrimitive(prim_func_, args);
1295
1296 bool need_infer_value = std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
1297 MS_EXCEPTION_IF_NULL(abs);
1298 auto value = abs->BuildValue();
1299 return (value != nullptr && !value->isa<Monad>() && !value->isa<FuncGraph>());
1300 });
1301
1302 AbstractBasePtr abs_base = nullptr;
1303 prim_func_->BeginRecordAddAttr();
1304 if (need_infer_value && frontend_func_impl_ != nullptr) {
1305 auto value = frontend_func_impl_->InferValue(prim_func_, args);
1306 if (value != nullptr && !value->ContainsValueAny()) {
1307 abs_base = value->ToAbstract();
1308 prim_func_->EndRecordAddAttr();
1309 auto added_attrs = prim_func_->evaluate_added_attrs();
1310 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1311 }
1312 }
1313 abs_base = CheckAndInfer(args);
1314 MS_EXCEPTION_IF_NULL(abs_base);
1315 prim_func_->EndRecordAddAttr();
1316 const auto &added_attrs = prim_func_->evaluate_added_attrs();
1317 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1318 }
1319
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)1320 EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
1321 // To check tuple/list operations with a white list of Python primitive.
1322 CheckSequenceArgumentForCppPrimitive(prim_, args);
1323 MS_EXCEPTION_IF_NULL(prim_);
1324 if (prims_to_skip_undetermined_infer.find(prim_->name()) == prims_to_skip_undetermined_infer.end()) {
1325 auto res_abstract = EvalUndeterminedArgs(args);
1326 if (res_abstract != nullptr) {
1327 MS_LOG(DEBUG) << "StandardPrimEvaluator eval Undetermined";
1328 return res_abstract;
1329 }
1330 }
1331 if (prim_->prim_type() == PrimType::kPrimTypePyCheck) {
1332 return EvalPyCheckPrim(engine, args);
1333 }
1334 bool need_infer_value = std::all_of(args.begin(), args.end(), [](const AbstractBasePtr &abs) -> bool {
1335 MS_EXCEPTION_IF_NULL(abs);
1336 auto value = abs->BuildValue();
1337 return (value != nullptr && !value->ContainsValueAny() && !value->isa<None>() && !value->isa<Monad>() &&
1338 !value->isa<FuncGraph>());
1339 });
1340
1341 AbstractBasePtr abs_base = nullptr;
1342 ValuePtr value = nullptr;
1343 prim_->BeginRecordAddAttr();
1344 if (need_infer_value && eval_impl_.IsImplInferValue()) {
1345 value = eval_impl_.InferValue(prim_, args);
1346 if (value != nullptr) {
1347 abs_base = value->ToAbstract();
1348 prim_->EndRecordAddAttr();
1349 auto added_attrs = prim_->evaluate_added_attrs();
1350 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1351 }
1352 }
1353 abs_base = eval_impl_.InferShapeAndType(nullptr, prim_, args);
1354 MS_EXCEPTION_IF_NULL(abs_base);
1355 prim_->EndRecordAddAttr();
1356 const auto &added_attrs = prim_->evaluate_added_attrs();
1357 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
1358 }
1359
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args)1360 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args) {
1361 // Consider all primitive implemented python infer() real use the tuple/list arguments.
1362 CheckSequenceArgumentForPythonPrimitive(prim_py_, args);
1363
1364 // Ensure input arguments are evaluated.
1365 auto res_abstract = EvalUndeterminedArgs(args);
1366 if (res_abstract != nullptr) {
1367 MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
1368 return res_abstract;
1369 }
1370 MS_EXCEPTION_IF_NULL(prim_py_);
1371 auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
1372 if (!forbid_reuse) {
1373 // Try to get infer result from evaluator cache.
1374 EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args);
1375 if (eval_result != nullptr) {
1376 MS_EXCEPTION_IF_NULL(eval_result->abstract());
1377 return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
1378 }
1379 }
1380 // In pynative mode (engine == nullptr), it is difficult to set added_attrs to
1381 // python object by C++ code, so we disable global eval cache in pynative mode.
1382 const bool enable_global_cache = (engine != nullptr && !forbid_reuse);
1383 if (enable_global_cache) {
1384 // Try to get infer result from global primitive eval cache.
1385 EvalResultPtr eval_result = eval_cache_->Get(prim_py_, args);
1386 if (eval_result != nullptr) {
1387 // Global cache hit.
1388 evaluator_cache_mgr_->SetValue(args, eval_result);
1389 return ApplyCacheEvalResult(prim_py_, eval_result);
1390 }
1391 }
1392 // Cache miss, run infer. We should copy attributes before
1393 // running infer, since they may be changed during infer.
1394 auto input_attrs = prim_py_->attrs();
1395 auto py_args = PreparePyInputs(args);
1396 prim_py_->BeginRecordAddAttr();
1397 py::dict output = prim_py_->RunInfer(py_args);
1398 prim_py_->EndRecordAddAttr();
1399 const auto &added_attrs = prim_py_->evaluate_added_attrs();
1400 MS_LOG(DEBUG) << "Output type is " << py::str(output);
1401 auto res_abs = PyInferRes2Abstract(prim_py_, output);
1402 MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
1403 EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
1404 // Save result to global primitive eval cache.
1405 if (enable_global_cache) {
1406 eval_cache_->Put(prim_py_, std::move(input_attrs), args, eval_result);
1407 }
1408 evaluator_cache_mgr_->SetValue(args, eval_result);
1409 return eval_result;
1410 }
1411
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args)1412 EvalResultPtr UniformPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
1413 auto res_abstract = EvalUndeterminedArgs(args);
1414 if (res_abstract != nullptr) {
1415 MS_LOG(DEBUG) << "UniformPrimEvaluator eval Undetermined";
1416 return res_abstract;
1417 }
1418 // if func_desc_.retval type is super class of parameter type, then make the retval type as parameter type.
1419 if (nargs_ != args.size()) {
1420 MS_LOG(INTERNAL_EXCEPTION) << "UniformPrimEvaluator expect " << nargs_ << " args, but got " << args.size()
1421 << " inputs";
1422 }
1423 TypePtr res_value_type = return_value_type_;
1424 ValuePtrList value_list;
1425 for (const auto &arg : args) {
1426 // Check if all arguments are scalar type.
1427 MS_EXCEPTION_IF_NULL(arg);
1428 if (arg->isa<AbstractScalar>()) {
1429 auto arg_scalar = dyn_cast_ptr<AbstractScalar>(arg);
1430 const auto &arg_value = arg_scalar->GetValueTrack();
1431 value_list.push_back(arg_value);
1432 } else {
1433 // Raise TypeError Expected Scalar.
1434 MS_LOG(INTERNAL_EXCEPTION) << "Expect scalar arguments for uniform primitives.";
1435 }
1436 }
1437 for (const auto &item : type_map_) {
1438 TypePtrList selections;
1439 (void)std::transform(item.second.begin(), item.second.end(), std::back_inserter(selections),
1440 [&args](size_t arg_idx) -> TypePtr {
1441 if (arg_idx >= args.size()) {
1442 MS_LOG(EXCEPTION) << "Index: " << arg_idx << " out of range: " << args.size();
1443 }
1444 MS_EXCEPTION_IF_NULL(args[arg_idx]);
1445 return args[arg_idx]->GetTypeTrack();
1446 });
1447 TypePtr res = CheckTypeList(item.first, selections);
1448 MS_EXCEPTION_IF_NULL(return_value_type_);
1449 MS_EXCEPTION_IF_NULL(item.first);
1450 if (*return_value_type_ == *(item.first)) {
1451 res_value_type = res;
1452 }
1453 }
1454
1455 ValuePtr evaluated_value = RunImpl(value_list);
1456 MS_EXCEPTION_IF_NULL(evaluated_value);
1457 if (!(*evaluated_value == *kValueAny)) {
1458 res_value_type = evaluated_value->type();
1459 }
1460 // for comparison primitives , return type shall have be specified to be bool.
1461 if (specify_out_type_ != nullptr) {
1462 res_value_type = specify_out_type_;
1463 }
1464
1465 AbstractScalarPtr abs_base = std::make_shared<AbstractScalar>(evaluated_value, res_value_type);
1466 return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>());
1467 }
1468
RunImpl(const ValuePtrList & args) const1469 ValuePtr UniformPrimEvaluator::RunImpl(const ValuePtrList &args) const {
1470 if (!eval_value_) {
1471 return kValueAny;
1472 } else {
1473 if (std::any_of(args.begin(), args.end(), [](const ValuePtr &arg) {
1474 MS_EXCEPTION_IF_NULL(arg);
1475 return arg->ContainsValueAny();
1476 })) {
1477 return kValueAny;
1478 }
1479 return impl_(args);
1480 }
1481 }
1482
1483 // Primitive implementation
1484 // static function start
1485 namespace {
InitStandardPrimEvaluator(PrimitivePtr primitive,const StandardPrimitiveImplReg eval_impl)1486 EvaluatorPtr InitStandardPrimEvaluator(PrimitivePtr primitive, const StandardPrimitiveImplReg eval_impl) {
1487 EvaluatorPtr prim_evaluator = std::make_shared<StandardPrimEvaluator>(primitive, eval_impl);
1488 return prim_evaluator;
1489 }
1490
InitUniformPrimEvaluator(const PrimitivePtr & primitive,PrimitiveImpl prim_impl,bool eval_value,const TypePtr & specify_out_type)1491 EvaluatorPtr InitUniformPrimEvaluator(const PrimitivePtr &primitive, PrimitiveImpl prim_impl, bool eval_value,
1492 const TypePtr &specify_out_type) {
1493 FunctionPtr func = nullptr;
1494 (void)prim::PrimToFunction::GetInstance().GetFunction(primitive, &func);
1495 MS_EXCEPTION_IF_NULL(func);
1496
1497 EvaluatorPtr uniform_primitive_evaluator =
1498 std::make_shared<UniformPrimEvaluator>(func, prim_impl, eval_value, specify_out_type);
1499 return uniform_primitive_evaluator;
1500 }
1501
AddToManager(const AnalysisEnginePtr & engine,const FuncGraphPtr func_graph)1502 inline void AddToManager(const AnalysisEnginePtr &engine, const FuncGraphPtr func_graph) {
1503 MS_EXCEPTION_IF_NULL(engine);
1504 FuncGraphManagerPtr manager = engine->func_graph_manager();
1505 MS_EXCEPTION_IF_NULL(manager);
1506 manager->AddFuncGraph(func_graph);
1507 }
1508
// Kind of member requested from an object.
enum class REQUIRE_TYPE {
  ATTR,    // An attribute.
  METHOD,  // A method.
};
1510
IsPyExecuteData(const AbstractBasePtr & data_abstract)1511 bool IsPyExecuteData(const AbstractBasePtr &data_abstract) {
1512 MS_EXCEPTION_IF_NULL(data_abstract);
1513 return data_abstract->isa<abstract::AbstractAny>();
1514 }
1515
CheckObjAttrValid(const TypePtr & data_type,const std::string & item_name,const AbstractBasePtr & data_args)1516 void CheckObjAttrValid(const TypePtr &data_type, const std::string &item_name, const AbstractBasePtr &data_args) {
1517 MS_EXCEPTION_IF_NULL(data_type);
1518 MS_EXCEPTION_IF_NULL(data_args);
1519 // Check if the obj's attr is invalid or decoratored by @jit_forbidden_register
1520 std::string data_type_str = TypeIdLabel(NormalizeTypeId(data_type->type_id()));
1521 if (data_args->isa<AbstractRefTensor>()) {
1522 data_type_str = "Parameter";
1523 } else if (data_args->isa<AbstractNamedTuple>()) {
1524 data_type_str = "NamedTuple";
1525 }
1526 py::module mod1 = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
1527 py::object obj_define = python_adapter::CallPyModFn(mod1, parse::PYTHON_MOD_GET_OBJ_DEFINED, data_type_str);
1528 if (py::isinstance<py::none>(obj_define)) {
1529 return;
1530 }
1531 py::module mod2 = python_adapter::GetPyModule(parse::PYTHON_MOD_MODULE);
1532 auto is_jit_forbidden_method =
1533 python_adapter::CallPyModFn(mod2, parse::PYTHON_MOD_IS_INVALID_METHOD, obj_define, data_type_str, item_name);
1534 if (py::cast<bool>(is_jit_forbidden_method) || data_args->isa<AbstractRefTensor>()) {
1535 MS_LOG(EXCEPTION) << "Failed to compile in GRAPH_MODE because the '" << data_type_str << "' object's method '"
1536 << item_name << "' is not supported in 'construct' or function with @jit decorator. "
1537 << "Try to use the '" << data_type_str << "." << item_name << "' externally "
1538 << "such as initialized in the method '__init__' before assigning"
1539 << ".\nFor more details, please refer to "
1540 << "https://www.mindspore.cn/docs/zh-CN/master/design/dynamic_graph_and_static_graph.html \n";
1541 }
1542 }
1543
// Record the real type and shape of `value_abs` on the getattr node when it is a tensor or scalar abstract,
// and mark the node as adapter when the abstract is an adapter tensor. Returns getattr_node in every case.
AnfNodePtr SetTypeForGetAttr(const AnfNodePtr &getattr_node, const AbstractBasePtr &value_abs) {
  if (value_abs == nullptr) {
    return getattr_node;
  }
  const bool is_tensor_or_scalar =
    value_abs->isa<abstract::AbstractTensor>() || value_abs->isa<abstract::AbstractScalar>();
  if (!is_tensor_or_scalar) {
    return getattr_node;
  }
  // Carry the abstract of the value being set over to the getattr node.
  auto real_type = value_abs->BuildType();
  auto real_shape = value_abs->BuildShape();
  fallback::SetRealType<AnfNode, Type>(getattr_node, real_type);
  fallback::SetRealShape<AnfNode, abstract::BaseShape>(getattr_node, real_shape);
  auto tensor_abs = value_abs->cast_ptr<abstract::AbstractTensor>();
  if (tensor_abs != nullptr && tensor_abs->is_adapter()) {
    getattr_node->set_user_data<bool>(fallback::kIsAdapter, std::make_shared<bool>(true));
  }
  return getattr_node;
}
1561
InterpretGetAttrNode(const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)1562 EvalResultPtr InterpretGetAttrNode(const AbstractBasePtrList &args_abs_list, const AnfNodeConfigPtr &out_conf) {
1563 MS_EXCEPTION_IF_NULL(out_conf);
1564 auto out_node = out_conf->node();
1565 MS_EXCEPTION_IF_NULL(out_node);
1566 const auto cnode = dyn_cast<CNode>(out_node);
1567 MS_EXCEPTION_IF_NULL(cnode);
1568 auto fg = cnode->func_graph();
1569
1570 auto data_args = args_abs_list[0];
1571 MS_EXCEPTION_IF_NULL(data_args);
1572 // Not check if the data is from PyExecute CNode.
1573 // Do not check the validity of the attribute in the variable scenario.
1574 if (!IsPyExecuteData(data_args) && !raiseutils::HasVariableCondition(fg)) {
1575 TypePtr data_type = data_args->BuildType();
1576 MS_EXCEPTION_IF_NULL(data_type);
1577 auto item_args = args_abs_list[1];
1578 MS_EXCEPTION_IF_NULL(item_args);
1579 ValuePtr item_value = item_args->BuildValue();
1580 auto item_str = item_value->cast_ptr<StringImm>();
1581 MS_EXCEPTION_IF_NULL(item_str);
1582 std::string item_name = item_str->value();
1583 CheckObjAttrValid(data_type, item_name, data_args);
1584 }
1585
1586 constexpr auto debug_recursive_level = 2;
1587 const auto &debug_info = trace::GetSourceCodeDebugInfo(out_node->debug_info());
1588 const auto &location = debug_info->location();
1589 if (location == nullptr) {
1590 MS_LOG(WARNING) << "Location info is null, node: " << out_node->DebugString(debug_recursive_level);
1591 return nullptr;
1592 }
1593 const auto expr = location->expr_src();
1594 if (expr.empty()) {
1595 MS_LOG(WARNING) << "Location's expr is empty, node: " << out_node->DebugString(debug_recursive_level);
1596 return nullptr;
1597 }
1598
1599 constexpr auto item_index = 1;
1600 auto item_arg = args_abs_list.at(item_index);
1601 MS_EXCEPTION_IF_NULL(item_arg);
1602 auto attr_name = GetValue<string>(item_arg->BuildValue());
1603 AnfNodePtr getattr_node;
1604 auto obj_change = cnode->user_data<bool>(fallback::kObjectAttrChange);
1605 if (obj_change != nullptr && *obj_change) {
1606 // The object is changed by setattr node, directly convert it to PyExecute node.
1607 getattr_node = fallback::ConvertCNodeToPyExecuteForPrim(cnode, "getattr");
1608 constexpr auto args_size = 3;
1609 if (args_abs_list.size() == args_size) { // Has setattr node as input.
1610 auto getattr_cnode = getattr_node->cast<CNodePtr>();
1611 MS_EXCEPTION_IF_NULL(getattr_cnode);
1612 getattr_cnode->add_input(cnode->input(args_size));
1613 constexpr auto value_index = 2;
1614 getattr_node = SetTypeForGetAttr(getattr_cnode, args_abs_list[value_index]);
1615 }
1616 } else {
1617 getattr_node = fallback::ConvertGetAttrNodeToPyInterpret(fg, cnode, attr_name);
1618 }
1619 MS_LOG(DEBUG) << "Convert: " << cnode->DebugString() << " -> " << getattr_node->DebugString();
1620 auto eng = out_conf->engine();
1621 MS_EXCEPTION_IF_NULL(eng);
1622 auto fn_conf = eng->MakeConfig(getattr_node, out_conf->context(), out_conf->func_graph());
1623 return eng->ForwardConfig(out_conf, fn_conf);
1624 }
1625
InterpretSetAttrNode(const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)1626 EvalResultPtr InterpretSetAttrNode(const AbstractBasePtrList &args_abs_list, const AnfNodeConfigPtr &out_conf) {
1627 MS_EXCEPTION_IF_NULL(out_conf);
1628 auto out_node = out_conf->node();
1629 MS_EXCEPTION_IF_NULL(out_node);
1630 const auto cnode = dyn_cast<CNode>(out_node);
1631 MS_EXCEPTION_IF_NULL(cnode);
1632 auto fg = cnode->func_graph();
1633 MS_EXCEPTION_IF_NULL(fg);
1634 auto owner_abs = args_abs_list[0];
1635 MS_EXCEPTION_IF_NULL(owner_abs);
1636 if (owner_abs->isa<abstract::AbstractRefTensor>()) {
1637 MS_EXCEPTION(ValueError) << "Do not support to set attribute for a parameter.";
1638 }
1639 auto owner_value = owner_abs->BuildValue();
1640 auto owner_node = cnode->input(1);
1641 constexpr auto debug_recursive_level = 2;
1642 MS_EXCEPTION_IF_NULL(owner_value);
1643 MS_LOG(DEBUG) << "node: " << out_conf->node()->DebugString(debug_recursive_level)
1644 << ", owner_value: " << owner_value->ToString();
1645 if (owner_value->isa<parse::InterpretedObject>()) {
1646 const auto &interpreted_value = dyn_cast<parse::InterpretedObject>(owner_value);
1647 const auto &key = interpreted_value->name();
1648 owner_node = fallback::ConvertPyObjectToPyExecute(fg, key, interpreted_value->obj(), owner_node, true);
1649 }
1650
1651 ValuePtr attr_str_value = args_abs_list[1]->BuildValue();
1652 MS_EXCEPTION_IF_NULL(attr_str_value);
1653 if (!attr_str_value->isa<StringImm>()) {
1654 MS_LOG(EXCEPTION) << "Expect a string, but got: " << attr_str_value->ToString();
1655 }
1656 auto attr_str = attr_str_value->cast<StringImmPtr>();
1657 MS_EXCEPTION_IF_NULL(attr_str);
1658
1659 constexpr auto internal_setattr_owner_str = "__internal_setattr_owner__";
1660 constexpr auto internal_setattr_value_str = "__internal_setattr_value__";
1661 std::stringstream script_buffer;
1662 script_buffer << "__import__('mindspore').common._utils._jit_fallback_set_attr(" << internal_setattr_owner_str << ", "
1663 << attr_str->value() << ", " << internal_setattr_value_str << ")";
1664 MS_LOG(DEBUG) << "script: " << script_buffer.str();
1665 const auto script_setattr_str = std::make_shared<StringImm>(script_buffer.str());
1666
1667 std::vector<ValuePtr> key_list;
1668 (void)key_list.emplace_back(std::make_shared<StringImm>(internal_setattr_owner_str));
1669 (void)key_list.emplace_back(attr_str);
1670 (void)key_list.emplace_back(std::make_shared<StringImm>(internal_setattr_value_str));
1671 const auto key_tuple = std::make_shared<ValueTuple>(key_list);
1672
1673 std::vector<AnfNodePtr> value_list{NewValueNode(prim::kPrimMakeTuple)};
1674 (void)value_list.emplace_back(owner_node);
1675 (void)value_list.emplace_back(NewValueNode(attr_str));
1676 constexpr auto value_node_index = 3;
1677 (void)value_list.emplace_back(cnode->input(value_node_index));
1678 const auto value_tuple_node = fg->NewCNode(value_list);
1679
1680 const auto setattr_node =
1681 fallback::CreatePyExecuteCNode(cnode, NewValueNode(script_setattr_str), NewValueNode(key_tuple), value_tuple_node);
1682 MS_LOG(DEBUG) << "setattr_node: " << setattr_node->DebugString(debug_recursive_level);
1683
1684 // Save abstract for getattr.
1685 constexpr auto value_abs_index = 2;
1686 auto value_abs = args_abs_list[value_abs_index];
1687 if (value_abs != nullptr &&
1688 (value_abs->isa<abstract::AbstractTensor>() || value_abs->isa<abstract::AbstractScalar>())) {
1689 auto type = value_abs->BuildType();
1690 auto shape = value_abs->BuildShape();
1691 fallback::SetRealType<AnfNode, Type>(setattr_node, type);
1692 fallback::SetRealShape<AnfNode, abstract::BaseShape>(setattr_node, shape);
1693 auto abs_tensor = value_abs->cast_ptr<abstract::AbstractTensor>();
1694 if (abs_tensor != nullptr && abs_tensor->is_adapter()) {
1695 setattr_node->set_user_data<bool>(fallback::kIsAdapter, std::make_shared<bool>(true));
1696 }
1697 }
1698
1699 auto eng = out_conf->engine();
1700 MS_EXCEPTION_IF_NULL(eng);
1701 auto fn_conf = eng->MakeConfig(setattr_node, out_conf->context(), out_conf->func_graph());
1702 return eng->ForwardConfig(out_conf, fn_conf);
1703 }
1704
StaticGetterInferred(const ValuePtr & value,const ConfigPtr & data_conf,const AnfNodeConfigPtr & old_conf,REQUIRE_TYPE require_type=REQUIRE_TYPE::METHOD)1705 EvalResultPtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &data_conf, const AnfNodeConfigPtr &old_conf,
1706 REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD) {
1707 MS_EXCEPTION_IF_NULL(old_conf);
1708 AbstractBasePtr abstract = ToAbstract(value, AnalysisContext::DummyContext(), old_conf);
1709 // Create new cnode
1710 std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
1711 auto func_graph_func = dyn_cast_ptr<abstract::FuncGraphAbstractClosure>(abstract);
1712 if (func_graph_func != nullptr) {
1713 FuncGraphPtr fg = func_graph_func->func_graph();
1714 input.push_back(NewValueNode(fg));
1715 } else {
1716 auto prim_func = dyn_cast_ptr<abstract::PrimitiveAbstractClosure>(abstract);
1717 MS_EXCEPTION_IF_NULL(prim_func);
1718 PrimitivePtr prim = prim_func->prim();
1719 input.push_back(NewValueNode(prim));
1720 }
1721
1722 auto conf = dyn_cast_ptr<abstract::AnfNodeConfig>(data_conf);
1723 MS_EXCEPTION_IF_NULL(conf);
1724 input.push_back(conf->node());
1725 MS_EXCEPTION_IF_NULL(old_conf);
1726 MS_EXCEPTION_IF_NULL(old_conf->node());
1727 FuncGraphPtr func_graph = old_conf->node()->func_graph();
1728 MS_EXCEPTION_IF_NULL(func_graph);
1729 CNodePtr new_cnode = func_graph->NewCNode(input);
1730 if (require_type == REQUIRE_TYPE::ATTR) {
1731 new_cnode = func_graph->NewCNode({new_cnode});
1732 }
1733 AnalysisEnginePtr eng = old_conf->engine();
1734 AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, old_conf->context(), old_conf->func_graph());
1735 return eng->ForwardConfig(old_conf, fn_conf);
1736 }
1737
SetSideEffectFlag(const PrimitivePtr & prim,const AnfNodeConfigPtr & out_conf)1738 void SetSideEffectFlag(const PrimitivePtr &prim, const AnfNodeConfigPtr &out_conf) {
1739 if (prim == nullptr) {
1740 return;
1741 }
1742 auto effect_info = GetPrimEffectInfo(prim);
1743 if (effect_info.memory || effect_info.io) {
1744 const auto &cnode = dyn_cast<CNode>(out_conf->node());
1745 MS_EXCEPTION_IF_NULL(cnode);
1746 MS_EXCEPTION_IF_NULL(out_conf->func_graph());
1747 MS_LOG(DEBUG) << "Found side-effect, cnode: " << cnode->DebugString()
1748 << ", func_graph: " << out_conf->func_graph()->ToString();
1749 cnode->set_has_side_effect_node(true);
1750 out_conf->func_graph()->set_has_side_effect_node(true);
1751 }
1752 }
1753
SetOriginObject(const AnfNodePtr & node,const AnfNodeConfigPtr & out_conf)1754 void SetOriginObject(const AnfNodePtr &node, const AnfNodeConfigPtr &out_conf) {
1755 if (!node->isa<ValueNode>()) {
1756 return;
1757 }
1758 auto vnode = node->cast<ValueNodePtr>();
1759 if (vnode->value()->has_user_data("origin_object")) {
1760 auto origin_object = vnode->value()->user_data<py::object>("origin_object");
1761 out_conf->node()->set_user_data<py::object>("origin_object", origin_object);
1762 }
1763 }
1764
SetSparseBpropFlag(const PrimitivePtr & prim,const AnfNodeConfigPtr & out_conf)1765 void SetSparseBpropFlag(const PrimitivePtr &prim, const AnfNodeConfigPtr &out_conf) {
1766 if (GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE)) {
1767 out_conf->func_graph()->set_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP, true);
1768 EnvSetSparseResultMgr::GetInstance().Set(true);
1769 }
1770 }
1771
GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList & args_abs_list,const ValuePtr & data_value,const AnfNodeConfigPtr & out_conf,const std::string & data)1772 EvalResultPtr GetEvaluatedValueForNameSpaceString(const AbstractBasePtrList &args_abs_list, const ValuePtr &data_value,
1773 const AnfNodeConfigPtr &out_conf, const std::string &data) {
1774 constexpr size_t item_index = 1;
1775 auto item_args = args_abs_list[item_index];
1776 MS_EXCEPTION_IF_NULL(item_args);
1777 ValuePtr item_value = item_args->BuildValue();
1778 MS_EXCEPTION_IF_NULL(data_value);
1779 MS_EXCEPTION_IF_NULL(item_value);
1780 if (item_value->isa<StringImm>()) {
1781 auto string_value = item_value->cast_ptr<StringImm>();
1782 MS_EXCEPTION_IF_NULL(string_value);
1783 item_value = std::make_shared<parse::Symbol>(string_value->value());
1784 }
1785 if (!item_value->isa<parse::Symbol>()) {
1786 MS_LOG(INTERNAL_EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
1787 }
1788
1789 // item_name to func addr from obj_map
1790 auto symbol = item_value->cast<parse::SymbolPtr>();
1791 auto name_space = data_value->cast<parse::NameSpacePtr>();
1792 constexpr auto tensors_queue_attr = "__is_tensors_queue__";
1793 constexpr auto pop_attr = "pop";
1794 if (name_space != nullptr && py::hasattr(name_space->namespace_obj(), tensors_queue_attr) &&
1795 symbol->symbol() == pop_attr) {
1796 constexpr auto graph_pop_attr = "__graph_pop__";
1797 symbol = std::make_shared<parse::Symbol>(graph_pop_attr);
1798 }
1799 MS_EXCEPTION_IF_NULL(out_conf);
1800 auto out_node = out_conf->node();
1801 MS_EXCEPTION_IF_NULL(out_node);
1802 FuncGraphPtr func_graph = out_node->func_graph();
1803 MS_EXCEPTION_IF_NULL(func_graph);
1804 auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
1805 if (new_node == nullptr) {
1806 MS_LOG(INTERNAL_EXCEPTION) << "Resolve node failed";
1807 }
1808
1809 auto prim = GetPrimitiveWithoutDoSignature(new_node);
1810 SetSparseBpropFlag(prim, out_conf);
1811 SetSideEffectFlag(prim, out_conf);
1812 SetOriginObject(new_node, out_conf);
1813
1814 if (IsValueNode<TypeNull>(new_node)) {
1815 // Do not find the attribute.
1816 constexpr auto max_args_len = 3;
1817 bool has_default = (args_abs_list.size() == max_args_len);
1818 if (!has_default) {
1819 MS_EXCEPTION(AttributeError) << data << " object has no attribute " << symbol->symbol();
1820 }
1821 auto out_cnode = out_node->cast_ptr<CNode>();
1822 MS_EXCEPTION_IF_NULL(out_cnode);
1823 constexpr auto default_index = 3;
1824 auto default_node = out_cnode->input(default_index);
1825 auto eng = out_conf->engine();
1826 MS_EXCEPTION_IF_NULL(eng);
1827 auto fn_conf = eng->MakeConfig(default_node, out_conf->context(), out_conf->func_graph());
1828 return eng->ForwardConfig(out_conf, fn_conf);
1829 }
1830
1831 auto new_node_to_fg = GetValueNode<FuncGraphPtr>(new_node);
1832 if (new_node_to_fg != nullptr) {
1833 bool has_recompute_scope = (out_node->scope() != nullptr &&
1834 out_node->scope()->name().compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0);
1835 if (has_recompute_scope) {
1836 parse::UpdateRecomputeScope(new_node_to_fg);
1837 } else if (MsContext::GetInstance()->get_param<int>(MS_CTX_DEBUG_LEVEL) == kLevelDebug) {
1838 UpdateDebugInfo(new_node_to_fg, out_node->scope(), out_node->debug_info());
1839 }
1840 }
1841
1842 AnalysisEnginePtr eng = out_conf->engine();
1843 MS_EXCEPTION_IF_NULL(eng);
1844 AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
1845 return eng->ForwardConfig(out_conf, fn_conf);
1846 }
1847
GenerateFuncGraphForOverriddenMethod(AnfNodePtr node,const ValuePtr & item_value,const AnfNodeConfigPtr & out_conf)1848 EvalResultPtr GenerateFuncGraphForOverriddenMethod(AnfNodePtr node, const ValuePtr &item_value,
1849 const AnfNodeConfigPtr &out_conf) {
1850 const auto &item_str = item_value->cast_ptr<StringImm>();
1851 FuncGraphPtr inner_fg = nullptr;
1852 py::object overridden_method = py::none();
1853 py::object value_obj = py::none();
1854 if (item_str != nullptr) {
1855 const std::string &item_name = item_str->value();
1856 if (node->has_user_data(item_name)) {
1857 value_obj = *node->user_data<py::object>(item_name);
1858 overridden_method = value_obj.attr("__class__").attr(item_name.c_str());
1859 }
1860 }
1861 bool is_getattr = node->has_user_data("__getattr__");
1862 if (is_getattr) {
1863 value_obj = *node->user_data<py::object>("__getattr__");
1864 try {
1865 overridden_method = value_obj.attr("__class__").attr("__getattr__");
1866 } catch (const std::exception &e) {
1867 MS_LOG(DEBUG) << value_obj << " has no attribute getattr.";
1868 }
1869 }
1870 if (py::isinstance<py::none>(overridden_method) || py::isinstance<py::none>(value_obj)) {
1871 return nullptr;
1872 }
1873 {
1874 MS_LOG_TRY_CATCH_SCOPE;
1875 inner_fg = parse::ParsePythonCode(overridden_method);
1876 }
1877 MS_EXCEPTION_IF_NULL(out_conf);
1878 auto eng = out_conf->engine();
1879 MS_EXCEPTION_IF_NULL(eng);
1880 auto cnode = node->cast<CNodePtr>();
1881 MS_EXCEPTION_IF_NULL(cnode);
1882 FuncGraphPtr func_graph = node->func_graph();
1883 MS_EXCEPTION_IF_NULL(func_graph);
1884 const auto &interpreted_obj = std::make_shared<parse::InterpretedObject>(value_obj);
1885 const auto &value_node = NewValueNode(interpreted_obj);
1886 if (inner_fg == nullptr) {
1887 std::vector<AnfNodePtr> new_inputs;
1888 for (size_t i = 0; i < cnode->size(); i++) {
1889 if (i == 1) {
1890 new_inputs.push_back(value_node);
1891 } else {
1892 new_inputs.push_back(cnode->input(i));
1893 }
1894 }
1895 CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
1896 auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
1897 return eng->ForwardConfig(out_conf, fn_conf);
1898 }
1899 AddToManager(eng, inner_fg);
1900 if (is_getattr) {
1901 std::vector<AnfNodePtr> new_inputs = {NewValueNode(inner_fg)};
1902 for (size_t i = 0; i < cnode->size(); i++) {
1903 if (i > 0) {
1904 new_inputs.push_back(cnode->input(i));
1905 }
1906 }
1907 CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
1908 auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
1909 return eng->ForwardConfig(out_conf, fn_conf);
1910 }
1911 std::vector<AnfNodePtr> input = {NewValueNode(prim::kPrimPartial)};
1912 input.push_back(NewValueNode(inner_fg));
1913 input.push_back(value_node);
1914 CNodePtr new_cnode = func_graph->NewCNode(input);
1915 auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
1916 return eng->ForwardConfig(out_conf, fn_conf);
1917 }
1918
GetEvaluatedValueForNameSpace(const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf,const bool check_override=false)1919 EvalResultPtr GetEvaluatedValueForNameSpace(const AbstractBasePtrList &args_abs_list, const AnfNodeConfigPtr &out_conf,
1920 const bool check_override = false) {
1921 // args_abs_list: same as StaticGetter
1922 constexpr size_t args_min_size = 2;
1923 if (args_abs_list.size() < args_min_size) {
1924 MS_LOG(INTERNAL_EXCEPTION) << "Size of args_abs_list is less than 2";
1925 }
1926 MS_EXCEPTION_IF_NULL(out_conf);
1927 // An external type.
1928 constexpr auto data_index = 0;
1929 constexpr auto item_index = 1;
1930 auto data = args_abs_list[data_index];
1931 auto item = args_abs_list[item_index];
1932 MS_EXCEPTION_IF_NULL(data);
1933 MS_EXCEPTION_IF_NULL(item);
1934 MS_EXCEPTION_IF_NULL(out_conf->node());
1935 auto data_value = data->BuildValue();
1936 MS_EXCEPTION_IF_NULL(data_value);
1937 auto data_type = data->BuildType();
1938 MS_EXCEPTION_IF_NULL(data_type);
1939 auto item_value = item->BuildValue();
1940 std::string data_id_str = TypeIdToString(data_type->type_id());
1941 if (check_override) {
1942 auto inner_fg_res = GenerateFuncGraphForOverriddenMethod(out_conf->node(), item_value, out_conf);
1943 if (inner_fg_res != nullptr) return inner_fg_res;
1944 }
1945 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
1946 if (data_value->isa<parse::ClassType>()) {
1947 auto class_val = dyn_cast_ptr<parse::ClassType>(data_value);
1948 auto class_obj = class_val->obj();
1949 py::object ns_obj = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, class_obj);
1950 data_value = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
1951 data_id_str = class_val->name();
1952 }
1953 if (data_value->isa<parse::MsClassObject>()) {
1954 auto class_val = dyn_cast_ptr<parse::MsClassObject>(data_value);
1955 auto class_obj = class_val->obj();
1956 py::object ns_obj = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, class_obj);
1957 data_value = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
1958 data_id_str = class_val->name();
1959 }
1960 if (!data_value->isa<parse::NameSpace>()) {
1961 MS_EXCEPTION_IF_NULL(item_value);
1962 MS_LOG(DEBUG) << "Evaluate " << data_value->ToString() << " attribute: " << item_value->ToString()
1963 << ".\nnode: " << out_conf->node()->DebugString() << "\n"
1964 << trace::GetDebugInfoStr(out_conf->node()->debug_info());
1965 auto res = InterpretGetAttrNode(args_abs_list, out_conf);
1966 if (res == nullptr) {
1967 MS_EXCEPTION(AttributeError) << data_value->ToString() << " object has no attribute: " << item_value->ToString();
1968 }
1969 return res;
1970 }
1971 return GetEvaluatedValueForNameSpaceString(args_abs_list, data_value, out_conf, data_id_str);
1972 }
1973
GetEvaluatedValueForPrimitiveAttr(const AbstractBasePtrList & args_abs_list,const AbstractFunctionPtr & data_args)1974 EvalResultPtr GetEvaluatedValueForPrimitiveAttr(const AbstractBasePtrList &args_abs_list,
1975 const AbstractFunctionPtr &data_args) {
1976 MS_EXCEPTION_IF_NULL(data_args);
1977 if (!data_args->isa<PrimitiveAbstractClosure>()) {
1978 return nullptr;
1979 }
1980 auto prim_abs = dyn_cast_ptr<PrimitiveAbstractClosure>(data_args);
1981 const auto &prim = prim_abs->prim();
1982 MS_EXCEPTION_IF_NULL(prim);
1983 constexpr auto item_index = 1;
1984 auto item_arg = args_abs_list.at(item_index);
1985 MS_EXCEPTION_IF_NULL(item_arg);
1986 auto attr_name = GetValue<string>(item_arg->BuildValue());
1987 auto value = prim->GetAttr(attr_name);
1988 if (value == nullptr) {
1989 MS_LOG(INFO) << "The Primitive: " << prim->ToString() << " has not attr " << attr_name;
1990 MS_LOG(INFO) << "PrimAttr: " << prim->GetAttrsText();
1991 return nullptr;
1992 }
1993 return std::make_shared<EvalResult>(value->ToAbstract(), nullptr);
1994 }
1995
GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEnginePtr & engine,const AbstractBasePtr & data_args,const AbstractBasePtr & item_args,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)1996 EvalResultPtr GetEvaluatedValueForAdapterTensorAttrOrMethod(const AnalysisEnginePtr &engine,
1997 const AbstractBasePtr &data_args,
1998 const AbstractBasePtr &item_args,
1999 const ConfigPtr &data_conf,
2000 const AnfNodeConfigPtr &out_conf) {
2001 MS_EXCEPTION_IF_NULL(data_args);
2002 MS_EXCEPTION_IF_NULL(item_args);
2003 // Check whether it is AdapterTensor or AdapterParameter.
2004 auto abs = data_args->cast_ptr<abstract::AbstractTensor>();
2005 if (abs == nullptr || !abs->is_adapter()) {
2006 return nullptr;
2007 }
2008
2009 // Get the name of attr/method.
2010 ValuePtr item_value = item_args->BuildValue();
2011 MS_EXCEPTION_IF_NULL(item_value);
2012 if (!item_value->isa<StringImm>()) {
2013 MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
2014 }
2015 std::string item_name = item_value->cast_ptr<StringImm>()->value();
2016
2017 constexpr size_t attr_index = 0;
2018 constexpr size_t flag_index = 1;
2019 constexpr size_t info_required_size = 2;
2020 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
2021 py::tuple attr_info = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_ADAPTER_TENSOR_ATTR, py::str(item_name));
2022 if (attr_info.size() != info_required_size) {
2023 MS_EXCEPTION(NameError) << "attr info size should be 2, but got " << attr_info.size();
2024 }
2025 // If func is none, it means there is no such attr or method.
2026 py::object func = attr_info[attr_index];
2027 if (py::isinstance<py::none>(func)) {
2028 return nullptr;
2029 }
2030 ValuePtr converted_value = nullptr;
2031 bool success = parse::ConvertData(func, &converted_value);
2032 if (!success || converted_value == nullptr || !converted_value->isa<FuncGraph>()) {
2033 return nullptr;
2034 }
2035 AddToManager(engine, converted_value->cast<FuncGraphPtr>());
2036
2037 // Check whether it is an attribute or a method.
2038 bool is_attr = attr_info[flag_index].cast<bool>();
2039 REQUIRE_TYPE require_type = is_attr ? REQUIRE_TYPE::ATTR : REQUIRE_TYPE::METHOD;
2040 return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
2041 }
2042
GetOriginObj(const AnfNodePtr & node)2043 py::object GetOriginObj(const AnfNodePtr &node) {
2044 MS_EXCEPTION_IF_NULL(node);
2045 py::object origin_obj;
2046 if (node->has_user_data("origin_object")) {
2047 return *node->user_data<py::object>("origin_object");
2048 }
2049 if (!node->isa<ValueNode>()) {
2050 return origin_obj;
2051 }
2052 auto vnode = node->cast<ValueNodePtr>();
2053 if (vnode->value()->has_user_data("origin_object")) {
2054 return *vnode->value()->user_data<py::object>("origin_object");
2055 }
2056 return origin_obj;
2057 }
2058
GetEvaluatedValueForAttrOrMethodNotInMap(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf,const std::string & item_name,const TypePtr & data_type)2059 EvalResultPtr GetEvaluatedValueForAttrOrMethodNotInMap(const AnalysisEnginePtr &engine,
2060 const AbstractBasePtrList &args_abs_list,
2061 const AnfNodeConfigPtr &out_conf, const std::string &item_name,
2062 const TypePtr &data_type) {
2063 constexpr auto max_args_len = 3;
2064 bool has_default = (args_abs_list.size() == max_args_len);
2065 auto out_node = out_conf->node();
2066 auto out_cnode = out_node->cast_ptr<CNode>();
2067 MS_EXCEPTION_IF_NULL(out_cnode);
2068 auto eng = out_conf->engine();
2069 MS_EXCEPTION_IF_NULL(eng);
2070 if (has_default) {
2071 constexpr auto default_index = 3;
2072 auto default_node = out_cnode->input(default_index);
2073 auto fn_conf = eng->MakeConfig(default_node, out_conf->context(), out_conf->func_graph());
2074 return eng->ForwardConfig(out_conf, fn_conf);
2075 }
2076
2077 py::object value_obj = GetOriginObj(out_cnode->input(1));
2078 if (value_obj.ptr() != nullptr) {
2079 std::vector<AnfNodePtr> new_inputs;
2080 std::string data_type_str = TypeIdLabel(NormalizeTypeId(data_type->type_id()));
2081 py::module mod1 = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
2082 py::object obj_define = python_adapter::CallPyModFn(mod1, parse::PYTHON_MOD_GET_OBJ_DEFINED, data_type_str);
2083 py::object check_res =
2084 python_adapter::CallPyModFn(mod1, parse::PYTHON_MOD_CHECK_IS_SUBCLASS, value_obj, obj_define);
2085 if (py::cast<bool>(check_res)) {
2086 for (size_t i = 0; i < out_cnode->size(); i++) {
2087 if (i == 1) {
2088 const auto &interpreted_obj = std::make_shared<parse::InterpretedObject>(value_obj);
2089 const auto &value_node = NewValueNode(interpreted_obj);
2090 new_inputs.push_back(value_node);
2091 } else {
2092 new_inputs.push_back(out_cnode->input(i));
2093 }
2094 }
2095 CNodePtr new_cnode = out_conf->func_graph()->NewCNode(new_inputs);
2096 auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
2097 return eng->ForwardConfig(out_conf, fn_conf);
2098 }
2099 }
2100 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
2101 if (!allow_fallback_runtime) {
2102 MS_EXCEPTION(AttributeError) << "In JIT strict mode, cannot get attributes " << item_name << " or the "
2103 << data_type->ToString() << " object has no attribute: " << item_name
2104 << "'. You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' "
2105 << "to enable the JIT lax mode to support the current syntax.\n\n"
2106 << trace::GetDebugInfoStr(out_conf->node()->debug_info());
2107 }
2108
2109 constexpr auto recursive_level = 3;
2110 MS_LOG(DEBUG) << "Evaluate " << data_type->ToString() << " attribute: " << item_name
2111 << ".\nnode: " << out_conf->node()->DebugString(recursive_level) << "\n"
2112 << trace::GetDebugInfoStr(out_conf->node()->debug_info());
2113 auto res = InterpretGetAttrNode(args_abs_list, out_conf);
2114 if (res == nullptr) {
2115 MS_EXCEPTION(AttributeError) << data_type->ToString() << " object has no attribute: " << item_name;
2116 }
2117 return res;
2118 }
2119
GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)2120 EvalResultPtr GetEvaluatedValueForBuiltinTypeAttrOrMethod(const AnalysisEnginePtr &engine,
2121 const AbstractBasePtrList &args_abs_list,
2122 const ConfigPtr &data_conf,
2123 const AnfNodeConfigPtr &out_conf) {
2124 constexpr size_t data_index = 0;
2125 constexpr size_t item_index = 1;
2126 auto data_args = args_abs_list[data_index];
2127 auto item_args = args_abs_list[item_index];
2128 MS_EXCEPTION_IF_NULL(data_args);
2129 MS_EXCEPTION_IF_NULL(item_args);
2130 ValuePtr item_value = item_args->BuildValue();
2131 MS_EXCEPTION_IF_NULL(item_value);
2132 TypePtr data_type = data_args->BuildType();
2133 MS_EXCEPTION_IF_NULL(data_type);
2134 // Handle NameTuple: getattr(XX, item_value) -> ValueNode().
2135 if (data_args->isa<AbstractNamedTuple>()) {
2136 auto named_tuple = data_args->cast<AbstractNamedTuplePtr>();
2137 const auto &keys = named_tuple->key();
2138 for (size_t it = 0; it < keys.size(); ++it) {
2139 auto key_value = keys[it]->BuildValue();
2140 MS_EXCEPTION_IF_NULL(key_value);
2141 if (*item_value == *key_value) {
2142 auto getattr_node = NewValueNode(named_tuple->elements()[it]->BuildValue());
2143 auto eng = out_conf->engine();
2144 MS_EXCEPTION_IF_NULL(eng);
2145 auto fn_conf = eng->MakeConfig(getattr_node, out_conf->context(), out_conf->func_graph());
2146 return eng->ForwardConfig(out_conf, fn_conf);
2147 }
2148 }
2149 }
2150
2151 // The method maybe a Primitive or Composite
2152 if (!item_value->isa<StringImm>()) {
2153 MS_LOG(EXCEPTION) << "Expect a string, but got: " << item_value->ToString();
2154 }
2155 auto item_str = item_value->cast_ptr<StringImm>();
2156 MS_EXCEPTION_IF_NULL(item_str);
2157 std::string item_name = item_str->value();
2158 REQUIRE_TYPE require_type = REQUIRE_TYPE::METHOD;
2159 Any require = pipeline::Resource::GetMethodPtr(data_type->type_id(), item_name);
2160 MS_EXCEPTION_IF_NULL(out_conf->node());
2161 if (require.empty()) {
2162 require = pipeline::Resource::GetAttrPtr(data_type->type_id(), item_name);
2163 if (require.empty()) {
2164 return GetEvaluatedValueForAttrOrMethodNotInMap(engine, args_abs_list, out_conf, item_name, data_type);
2165 }
2166 require_type = REQUIRE_TYPE::ATTR;
2167 }
2168
2169 ValuePtr converted_value = nullptr;
2170 if (require.is<std::string>()) {
2171 // composite registered in standard_method_map go to this branch
2172 converted_value = prim::GetPythonOps(require.cast<std::string>());
2173 MS_EXCEPTION_IF_NULL(converted_value);
2174
2175 auto converted_fg = converted_value->cast<FuncGraphPtr>();
2176 if (converted_fg != nullptr) {
2177 bool has_recompute_scope =
2178 (out_conf->node()->scope() != nullptr &&
2179 out_conf->node()->scope()->name().compare(0, strlen(kAttrRecompute), kAttrRecompute) == 0);
2180 if (has_recompute_scope) {
2181 parse::UpdateRecomputeScope(converted_fg);
2182 } else if (MsContext::GetInstance()->get_param<int>(MS_CTX_DEBUG_LEVEL) == kLevelDebug) {
2183 UpdateDebugInfo(converted_fg, out_conf->node()->scope(), out_conf->node()->debug_info());
2184 }
2185 }
2186
2187 if (!converted_value->isa<Primitive>()) {
2188 AddToManager(engine, converted_value->cast<FuncGraphPtr>());
2189 }
2190 } else if (require.is<PrimitivePtr>()) {
2191 converted_value = require.cast<PrimitivePtr>();
2192 } else {
2193 MS_LOG(EXCEPTION) << "Expect to get string or PrimitivePtr from attr or method map, but got " << require.ToString();
2194 }
2195 return StaticGetterInferred(converted_value, data_conf, out_conf, require_type);
2196 }
2197
TransPropertyToFunc(const AnfNodeConfigPtr & out_conf,py::object property_net_obj,std::string item_name)2198 EvalResultPtr TransPropertyToFunc(const AnfNodeConfigPtr &out_conf, py::object property_net_obj,
2199 std::string item_name) {
2200 py::object property_func = py::none();
2201 try {
2202 property_func = property_net_obj.attr("__class__").attr(py::str(item_name));
2203 } catch (const std::exception &e) {
2204 MS_LOG(ERROR) << property_net_obj << " has no attribute " << item_name;
2205 }
2206 py::object property_func_fget = property_func.attr(py::str("fget"));
2207 auto inner_fg = parse::ParsePythonCode(property_func_fget);
2208 auto eng = out_conf->engine();
2209 MS_EXCEPTION_IF_NULL(eng);
2210 AddToManager(eng, inner_fg);
2211 auto node = out_conf->node();
2212 auto cnode = node->cast<CNodePtr>();
2213 MS_EXCEPTION_IF_NULL(cnode);
2214 FuncGraphPtr func_graph = node->func_graph();
2215 MS_EXCEPTION_IF_NULL(func_graph);
2216 std::vector<AnfNodePtr> new_inputs = {NewValueNode(inner_fg)};
2217 new_inputs.push_back(cnode->input(1));
2218 CNodePtr new_cnode = func_graph->NewCNode(new_inputs);
2219 MS_LOG(DEBUG) << "new_cnode:" << new_cnode->DebugString();
2220 auto fn_conf = eng->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
2221 return eng->ForwardConfig(out_conf, fn_conf);
2222 }
2223
GetClassAttrFromPyObject(const py::object & cls_obj,const std::string & cls_name,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)2224 EvalResultPtr GetClassAttrFromPyObject(const py::object &cls_obj, const std::string &cls_name,
2225 const AbstractBasePtrList &args_abs_list, const AnfNodeConfigPtr &out_conf) {
2226 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
2227 constexpr auto item_index = 1;
2228 auto item_arg = args_abs_list.at(item_index);
2229 MS_EXCEPTION_IF_NULL(item_arg);
2230 auto attr_name = GetValue<string>(item_arg->BuildValue());
2231 bool is_property =
2232 (python_adapter::CallPyModFn(mod, parse::PYTHON_PARSE_CHECK_ATTR_IS_PROPERTY, cls_obj, attr_name)).cast<bool>();
2233 if (is_property) {
2234 ValuePtr item_value = item_arg->BuildValue();
2235 MS_EXCEPTION_IF_NULL(item_value);
2236 const auto &item_str = item_value->cast_ptr<StringImm>();
2237 const std::string &item_name = item_str->value();
2238 return TransPropertyToFunc(out_conf, cls_obj, item_name);
2239 }
2240 py::object ns_obj = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, cls_obj);
2241 auto ns = std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, ns_obj);
2242 return GetEvaluatedValueForNameSpaceString(args_abs_list, ns, out_conf, cls_name);
2243 }
2244
GetFuncAbstractAttr(const AbstractFunctionPtr & data_args,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)2245 EvalResultPtr GetFuncAbstractAttr(const AbstractFunctionPtr &data_args, const AbstractBasePtrList &args_abs_list,
2246 const AnfNodeConfigPtr &out_conf) {
2247 if (data_args == nullptr) {
2248 return nullptr;
2249 }
2250 // Get attribute or method of PartialAbstractClosure, the object could be nn.Cell/ms_class object.
2251 auto data_partial = dyn_cast_ptr<PartialAbstractClosure>(data_args);
2252 if (data_partial != nullptr) {
2253 const auto &partial_args = data_partial->args();
2254 auto prim_abs = dyn_cast_ptr<PrimitiveAbstractClosure>(data_partial->fn());
2255 if (prim_abs != nullptr && !partial_args.empty()) {
2256 MS_EXCEPTION_IF_NULL(prim_abs->prim());
2257 const auto &prim_name = prim_abs->prim()->name();
2258 if (prim_name == prim::kPrimCreateInstance->name()) {
2259 constexpr size_t class_index = 0;
2260 MS_EXCEPTION_IF_NULL(partial_args[class_index]);
2261 auto class_val = partial_args[class_index]->BuildValue();
2262 MS_EXCEPTION_IF_NULL(class_val);
2263 auto wrapper = dyn_cast_ptr<parse::PyObjectWrapper>(class_val);
2264 MS_EXCEPTION_IF_NULL(wrapper);
2265 return GetClassAttrFromPyObject(wrapper->obj(), wrapper->name(), args_abs_list, out_conf);
2266 }
2267 }
2268 return nullptr;
2269 }
2270 // Get attribute or method of FuncGraphAbstractClosure, the object could be nn.Cell/ms_class object.
2271 const auto &cls_obj = fallback::GetPyObjForFuncGraphAbstractClosure(data_args);
2272 if (py::isinstance<Cell>(cls_obj) || py::hasattr(cls_obj, PYTHON_MS_CLASS)) {
2273 return GetClassAttrFromPyObject(cls_obj, py::str(cls_obj), args_abs_list, out_conf);
2274 }
2275 return GetEvaluatedValueForPrimitiveAttr(args_abs_list, data_args);
2276 }
2277
CheckHasOverriddenMethod(AnfNodePtr node,ValuePtr item_value)2278 bool CheckHasOverriddenMethod(AnfNodePtr node, ValuePtr item_value) {
2279 const auto &item_str = item_value->cast_ptr<StringImm>();
2280 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
2281 if (item_str != nullptr) {
2282 const std::string &item_name = item_str->value();
2283 if (node->has_user_data(item_name)) {
2284 auto value_obj = *node->user_data<py::object>(item_name);
2285 py::bool_ check = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CHECK_ATTRS, value_obj, item_name);
2286 return py::cast<bool>(check);
2287 }
2288 }
2289 if (node->has_user_data("__getattr__")) {
2290 auto value_obj = *node->user_data<py::object>("__getattr__");
2291 py::bool_ check = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CHECK_ATTRS, value_obj, "__getattr__");
2292 return py::cast<bool>(check);
2293 }
2294 return false;
2295 }
2296
StaticGetter(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr & data_conf,const AnfNodeConfigPtr & out_conf)2297 EvalResultPtr StaticGetter(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
2298 const ConfigPtr &data_conf, const AnfNodeConfigPtr &out_conf) {
2299 // Inputs: namespace and its static function; or class and its member function
2300 constexpr size_t data_index = 0;
2301 constexpr size_t item_index = 1;
2302 auto data_args = args_abs_list[data_index];
2303 auto item_args = args_abs_list[item_index];
2304 MS_EXCEPTION_IF_NULL(data_args);
2305 MS_EXCEPTION_IF_NULL(item_args);
2306 MS_EXCEPTION_IF_NULL(out_conf);
2307 MS_EXCEPTION_IF_NULL(out_conf->node());
2308 constexpr auto recursive_level = 2;
2309 MS_LOG(DEBUG) << "StaticGetter, data: " << data_args->ToString() << ", item: " << item_args->ToString()
2310 << ", node: " << out_conf->node()->DebugString(recursive_level);
2311 ScopePtr scope = out_conf->node()->scope();
2312 ScopeGuard scope_guard(scope);
2313 ValuePtr item_value = item_args->BuildValue();
2314 MS_EXCEPTION_IF_NULL(item_value);
2315 if (item_value->ContainsValueAny()) {
2316 MS_LOG(INTERNAL_EXCEPTION) << "The value of the attribute could not be inferred: " << item_value->ToString();
2317 }
2318
2319 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
2320 constexpr auto max_args_size = 3;
2321 if (!allow_fallback_runtime && args_abs_list.size() == max_args_size) {
2322 constexpr size_t default_index = 2;
2323 auto default_args = args_abs_list[default_index];
2324 MS_EXCEPTION_IF_NULL(default_args);
2325 if (default_args->isa<abstract::AbstractScalar>()) {
2326 ValuePtr default_value = default_args->BuildValue();
2327 MS_EXCEPTION_IF_NULL(default_value);
2328 if (default_value->isa<parse::InterpretedObject>()) {
2329 auto obj = ValueToPyData(default_value);
2330 auto type_str = python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, parse::PYTHON_PARSE_GET_TYPE, obj);
2331 MS_EXCEPTION(TypeError) << "For 'getattr', the third input 'default' can not be " << py::str(type_str)
2332 << " object " << py::str(obj);
2333 }
2334 }
2335 }
2336
2337 auto res = GetFuncAbstractAttr(data_args->cast<AbstractFunctionPtr>(), args_abs_list, out_conf);
2338 if (res != nullptr) {
2339 return res;
2340 }
2341
2342 // Get attribute or method of AdapterTensor object.
2343 res = GetEvaluatedValueForAdapterTensorAttrOrMethod(engine, data_args, item_args, data_conf, out_conf);
2344 if (res != nullptr) {
2345 return res;
2346 }
2347 // Try to search method map, if not found, the data_type should be External type.
2348 TypePtr data_type = data_args->BuildType();
2349 MS_EXCEPTION_IF_NULL(data_type);
2350 // Check if attr is a overridden method.
2351 bool check_override = CheckHasOverriddenMethod(out_conf->node(), item_value);
2352 // Not check if the data is from PyExecute CNode, since its Tensor output is pseud.
2353 if (!IsPyExecuteData(data_args) && pipeline::Resource::IsTypeInBuiltInMap(data_type->type_id()) && !check_override) {
2354 return GetEvaluatedValueForBuiltinTypeAttrOrMethod(engine, args_abs_list, data_conf, out_conf);
2355 }
2356 return GetEvaluatedValueForNameSpace(args_abs_list, out_conf, check_override);
2357 }
2358
// Get the type annotated for `node` through a JIT annotation comment.
// The annotation may refer to a type by variable name; the lambda below resolves such a name to a TypePtr using the
// local variables carried in `args_abs_list`, and is handed to fallback::GetJitAnnotationTypeFromComment.
// The lambda returns nullptr (after an INFO log) whenever the name cannot be resolved to a non-None type.
// NOTE(review): args_abs_list[1] and args_abs_list[2] are indexed without a size check — callers are presumably
// expected to pass PyInterpret/PyExecute argument lists; confirm.
TypePtr GetAnnotationType(const AnfNodePtr &node, const AbstractBasePtrList &args_abs_list) {
  MS_EXCEPTION_IF_NULL(node);
  fallback::FormatedVariableTypeFunc func = [&node, &args_abs_list](const std::string &type_var_str) -> TypePtr {
    // For PyInterpret, the args[1] is global dict, and the args[2] is local dict.
    // For PyExecute, the args[1] is local dict keys, and the args[2] is local dict values.
    ValuePtr type_value = nullptr;
    const auto &keys_tuple_abs = args_abs_list[1];
    MS_EXCEPTION_IF_NULL(keys_tuple_abs);
    const auto &keys_tuple = keys_tuple_abs->BuildValue();
    const auto &keys = dyn_cast<ValueSequence>(keys_tuple);
    // The two node kinds are told apart by whether args[1] builds to a value sequence.
    bool is_py_execute = (keys != nullptr);
    if (is_py_execute) {  // PyExecute.
      // Find the position of the requested variable name among the keys; `i` is reused below to index the values.
      bool found = false;
      size_t i = 0;
      for (; i < keys->value().size(); ++i) {
        const auto &key = dyn_cast<StringImm>(keys->value()[i]);
        MS_EXCEPTION_IF_NULL(key);
        if (key->value() == type_var_str) {
          found = true;
          break;
        }
      }

      if (!found) {
        MS_LOG(INFO) << "Not valid PyExecute CNode. node: " << node->DebugString() << ", keys: " << keys->ToString()
                     << ", not found " << type_var_str;
        return nullptr;
      }
      constexpr auto values_index = 2;
      const auto &values_tuple_abs = dyn_cast<AbstractSequence>(args_abs_list[values_index]);
      MS_EXCEPTION_IF_NULL(values_tuple_abs);
      // NOTE(review): assumes the values tuple has at least as many elements as the keys tuple — not checked here.
      const auto &type_value_abs = values_tuple_abs->elements()[i];
      if (type_value_abs == nullptr) {
        MS_LOG(INFO) << "Not valid PyExecute CNode. node: " << node->DebugString() << ", key: " << type_var_str
                     << ", values_tuple_abs: " << values_tuple_abs->ToString();
        return nullptr;
      }
      // If the abstract carries a real type but no real shape, use that real type; otherwise use its built value.
      bool only_has_real_type = !fallback::HasRealShape(type_value_abs) && fallback::HasRealType(type_value_abs);
      type_value =
        only_has_real_type ? fallback::GetRealType<AbstractBase, Type>(type_value_abs) : type_value_abs->BuildValue();
    } else {  // PyInterpret
      constexpr auto local_dict_index = 2;
      const auto &local_dict_abs = args_abs_list[local_dict_index];
      const auto &dict = dyn_cast<AbstractDictionary>(local_dict_abs);
      if (dict == nullptr || dict->elements().empty()) {
        // A null args[2] raises here; a non-dictionary or empty dictionary is only logged.
        MS_EXCEPTION_IF_NULL(local_dict_abs);
        MS_LOG(INFO) << "Not valid PyInterpret CNode. node: " << node->DebugString() << ", key: " << type_var_str
                     << ", local_dict_abs: " << local_dict_abs->ToString();
        return nullptr;
      }
      // Take the value of the first string key equal to the requested name; non-string keys are skipped.
      for (const auto &element : dict->elements()) {
        MS_EXCEPTION_IF_NULL(element.first);
        const auto &key = element.first->BuildValue();
        if (key == nullptr || !key->isa<StringImm>()) {
          continue;
        }
        if (key->cast<StringImmPtr>()->value() == type_var_str) {
          MS_EXCEPTION_IF_NULL(element.second);
          type_value = element.second->BuildValue();
          break;
        }
      }
    }

    if (type_value == nullptr) {
      MS_LOG(INFO) << "Not valid " << (is_py_execute ? "PyExecute" : "PyInterpret")
                   << " CNode. node: " << node->DebugString() << ", key: " << type_var_str << ", type value is null.";
      return nullptr;
    }
    // Convert the resolved value to a Python object and cast that object to a TypePtr, unless it is None.
    const auto &py_type = BuildPyObject(type_value);
    MS_LOG(DEBUG) << "type_value: " << type_value->ToString() << ", py_type: " << py_type;
    if (!py::isinstance<py::none>(py_type)) {
      return py::cast<TypePtr>(py_type);
    }
    MS_LOG(INFO) << "Not valid " << (is_py_execute ? "PyExecute" : "PyInterpret")
                 << " CNode. node: " << node->DebugString() << ", key: " << type_var_str << ", type value is None.";
    return nullptr;
  };
  const auto &type = fallback::GetJitAnnotationTypeFromComment(node, func);
  return type;
}
2440
GetLocalArgsUniqueDtype(const AnfNodePtr & node,const AbstractBasePtrList & args_abs_list)2441 TypePtr GetLocalArgsUniqueDtype(const AnfNodePtr &node, const AbstractBasePtrList &args_abs_list) {
2442 // If force to use ANY.
2443 static const auto force_any = (common::GetCompileConfig("FALLBACK_FORCE_ANY") == "1");
2444 if (force_any) {
2445 return nullptr;
2446 }
2447
2448 TypePtr res = nullptr;
2449 // Check the abstract, return true if continue, otherwise return false.
2450 auto unique_dtype_check = [&node, &res](const AbstractBasePtr &element_value_abs) -> bool {
2451 MS_EXCEPTION_IF_NULL(element_value_abs);
2452 if (!element_value_abs->isa<abstract::AbstractTensor>()) {
2453 return true;
2454 }
2455 // Fetch the dtype from element_value_abs of tensor.
2456 auto element_abs_tensor = element_value_abs->cast_ptr<abstract::AbstractTensor>();
2457 MS_EXCEPTION_IF_NULL(element_abs_tensor);
2458 MS_EXCEPTION_IF_NULL(element_abs_tensor->element());
2459 const auto dtype = element_abs_tensor->element()->BuildType();
2460 MS_EXCEPTION_IF_NULL(dtype);
2461 // Check default dtype if it's AbstractAny(AbstractTensor)
2462 if (element_value_abs->isa<abstract::AbstractAny>() &&
2463 !element_value_abs->cast_ptr<abstract::AbstractAny>()->supposed_tensor_dtype()) {
2464 return true;
2465 }
2466 if (res == nullptr) {
2467 MS_EXCEPTION_IF_NULL(node);
2468 MS_LOG(INFO) << "Tensor dtype found, set as unique dtype: " << dtype->ToString()
2469 << ", node: " << node->DebugString() << "\n\n"
2470 << trace::GetDebugInfoStr(node->debug_info());
2471 res = dtype;
2472 return true;
2473 }
2474 if (res != dtype) {
2475 MS_EXCEPTION_IF_NULL(node);
2476 MS_LOG(INFO) << "More than one tensor dtype found, not set unique dtype. node: " << node->DebugString() << "\n\n"
2477 << trace::GetDebugInfoStr(node->debug_info());
2478 return false;
2479 }
2480 return true;
2481 };
2482 constexpr auto values_index = 2;
2483 if (args_abs_list.size() <= values_index) {
2484 return nullptr;
2485 }
2486 const auto &values_tuple_abs = dyn_cast<AbstractSequence>(args_abs_list[values_index]);
2487 bool is_py_execute = (values_tuple_abs != nullptr);
2488 if (is_py_execute) { // PyExecute CNode.
2489 const auto &elements_abs = values_tuple_abs->elements();
2490 for (const auto &element_abs : elements_abs) {
2491 if (!unique_dtype_check(element_abs)) {
2492 return nullptr;
2493 }
2494 }
2495 } else { // PyInterpret CNode.
2496 const auto &local_dict_abs = dyn_cast<AbstractDictionary>(args_abs_list[values_index]);
2497 MS_EXCEPTION_IF_NULL(local_dict_abs);
2498 const auto &elements_abs = local_dict_abs->elements();
2499 for (const auto &element_abs_pair : elements_abs) {
2500 const auto &element_value_abs = element_abs_pair.second;
2501 if (!unique_dtype_check(element_value_abs)) {
2502 return nullptr;
2503 }
2504 }
2505 }
2506
2507 if (res != nullptr) {
2508 MS_LOG(INFO) << "Apply unique dtype: " << res->ToString() << " to node: " << node->DebugString() << "\n\n"
2509 << trace::GetDebugInfoStr(node->debug_info());
2510 }
2511 return res;
2512 }
2513
AddLabelsToPrimitiveFunction(const PrimitivePtr & prim_func)2514 void AddLabelsToPrimitiveFunction(const PrimitivePtr &prim_func) {
2515 auto prim_name = prim_func->name();
2516 py::module mod = py::module::import(parse::PYTHON_MOD_PRIMITIVE_OP_CREATE_INSTANCE_HELPER_MODULE);
2517 if (!py::hasattr(mod, parse::PYTHON_MOD_PRIMITIVE_OP_LABELS_DICT)) {
2518 MS_LOG(INTERNAL_EXCEPTION) << "Can not found " << parse::PYTHON_MOD_PRIMITIVE_OP_LABELS_DICT << " in "
2519 << parse::PYTHON_MOD_PRIMITIVE_OP_CREATE_INSTANCE_HELPER_MODULE << ".";
2520 }
2521 py::dict op_labels = mod.attr(parse::PYTHON_MOD_PRIMITIVE_OP_LABELS_DICT);
2522 if (!op_labels.contains(py::str(prim_name))) {
2523 return;
2524 }
2525 py::dict labels = op_labels[py::str(prim_name)];
2526 for (const auto &p : labels) {
2527 auto attr_name = py::cast<std::string>(p.first);
2528 auto attr_obj = py::reinterpret_borrow<py::object>(p.second);
2529 ValuePtr converted_ret = nullptr;
2530 bool converted = parse::ConvertData(attr_obj, &converted_ret);
2531 if (!converted) {
2532 MS_LOG(INTERNAL_EXCEPTION) << "Call 'add_attr' to add attribute to primitive failed,"
2533 << " convert python obj to MindSpore obj failed; primitive name: " << prim_name
2534 << ", attribute name:" << attr_name << ", attribute value:" << py::str(attr_obj)
2535 << ", attribute type:"
2536 << py::cast<std::string>(attr_obj.attr("__class__").attr("__name__"));
2537 }
2538 MS_LOG(DEBUG) << "Add attr {" << attr_name << ": " << converted_ret->ToString() << "} to " << prim_name;
2539 (void)prim_func->AddAttr(attr_name, converted_ret);
2540 }
2541 }
2542
// Build the full argument node list for operator `op_name`.
// Monad-related inputs (abstract monads, UpdateState CNodes, UMonad/IOMonad value nodes) are dropped from
// `args_list`; the remaining arguments are then padded with the operator's declared default values for the trailing
// parameters in `op_args` that were not supplied. Padding stops at the first parameter without a default.
// Raises TypeError when the resulting count does not match `op_args.size()`. `check_init` only selects the wording
// of that error ("init arguments" vs "inputs").
std::vector<AnfNodePtr> GeneratePrimitiveDefaultArgs(const std::string &op_name,
                                                     const std::vector<AnfNodePtr> &args_list,
                                                     const std::vector<ops::OpInputArg> &op_args, bool check_init) {
  std::vector<AnfNodePtr> nodes;
  for (const auto &input : args_list) {
    if (HasAbstractMonad(input) || (IsPrimitiveCNode(input, prim::kPrimUpdateState) || IsValueNode<UMonad>(input) ||
                                    IsValueNode<IOMonad>(input))) {
      continue;
    }
    (void)nodes.emplace_back(input);
  }
  // Count only the arguments that were kept. Using the unfiltered size would skip (or misalign) the defaults
  // whenever a monad input was dropped above.
  const size_t real_args_size = nodes.size();
  for (size_t i = real_args_size; i < op_args.size(); i++) {
    auto default_arg = parse::GetArgDefaultValue(op_name, op_args[i].arg_name_);
    if (default_arg == nullptr) {
      break;
    }
    MS_LOG(DEBUG) << "Get the default value of '" << op_args[i].arg_name_ << "' attribute of Primitive[" << op_name
                  << "], which is " << default_arg->ToString() << ".";
    (void)nodes.emplace_back(NewValueNode(default_arg));
  }
  if (nodes.size() != op_args.size()) {
    std::string args_type_str = check_init ? "init arguments" : "inputs";
    MS_EXCEPTION(TypeError) << "For Operator[" << op_name << "], the number of " << args_type_str
                            << " (including default arguments) should be " << op_args.size()
                            << ", but the actual number of inputs is not satisfied, which is " << real_args_size << ".";
  }
  return nodes;
}
2574
// Check each argument abstract against its operator parameter definition and insert a type conversion where one of
// the parameter's cast dtypes applies.
// For argument i:
//   - a keyword argument raises TypeError;
//   - it is accepted as is when it satisfies the optional check, matches the declared dtype, or contains a
//     sequence of any type;
//   - otherwise, if it matches one of the cast dtypes, (*nodes)[i] is replaced by the converted node;
//   - otherwise the function returns false, unless an undetermined abstract has been seen at this or an earlier
//     position.
// Returns true when every argument was accepted.
bool ValidateAndConvertArgsType(const std::string &op_name, const std::vector<ops::OpInputArg> &op_args,
                                const AbstractBasePtrList &abs_list, const FuncGraphPtr &fg,
                                std::vector<AnfNodePtr> *nodes) {
  bool seen_undetermined = false;
  const size_t arg_num = op_args.size();
  for (size_t idx = 0; idx < arg_num; ++idx) {
    const auto &param_def = op_args[idx];
    const auto &arg_abs = abs_list[idx];
    if (arg_abs->isa<abstract::AbstractKeywordArg>()) {
      MS_EXCEPTION(TypeError) << "For Primitive[" << op_name
                              << "], only positional arguments as inputs are supported, but got "
                              << arg_abs->ToString();
    }
    seen_undetermined = HasAbstractUndetermined(arg_abs) || seen_undetermined;

    const bool accepted_directly =
      ValidateArgOptional(arg_abs, param_def) || ops::ValidateArgsType(arg_abs, param_def.arg_dtype_);
    if (accepted_directly || fallback::ContainsSequenceAnyType(arg_abs)) {
      continue;
    }

    // Not directly acceptable: see whether any of the declared cast source dtypes fits.
    const auto &cast_dtypes = param_def.cast_dtype_;
    const bool convertible = std::any_of(cast_dtypes.cbegin(), cast_dtypes.cend(), [&arg_abs](const auto &cast_dtype) {
      return ops::ValidateArgsType(arg_abs, cast_dtype);
    });
    if (convertible) {
      (*nodes)[idx] = GetNodeAfterTypeConversion((*nodes)[idx], param_def, fg);
      continue;
    }
    if (!seen_undetermined) {
      return false;
    }
  }
  return true;
}
2611
BuilidArgsTypeString(const AbstractBasePtr & arg_abs)2612 std::string BuilidArgsTypeString(const AbstractBasePtr &arg_abs) {
2613 auto arg_type = arg_abs->BuildType();
2614 MS_EXCEPTION_IF_NULL(arg_type);
2615 if (arg_type->isa<Bool>()) {
2616 return "bool";
2617 }
2618 if (arg_type->isa<Int>() || arg_type->isa<UInt>()) {
2619 return "int";
2620 }
2621 if (arg_type->isa<Float>() || arg_type->isa<BFloat>()) {
2622 return "float";
2623 }
2624 if (arg_type->isa<String>()) {
2625 return "string";
2626 }
2627 if (arg_type->isa<TypeNone>()) {
2628 return "None";
2629 }
2630 if (arg_type->isa<TensorType>()) {
2631 return "Tensor";
2632 }
2633 if (arg_type->isa<Tuple>() || arg_type->isa<List>()) {
2634 auto seq_abs = arg_abs->cast_ptr<abstract::AbstractSequence>();
2635 MS_EXCEPTION_IF_NULL(seq_abs);
2636 std::string seq_type = arg_type->isa<Tuple>() ? "tuple" : "list";
2637 if (seq_abs->dynamic_len()) {
2638 return seq_type;
2639 }
2640 std::stringstream ss;
2641 ss << seq_type << "<";
2642 for (size_t i = 0; i < seq_abs->size(); i++) {
2643 if (i == 0) {
2644 ss << BuilidArgsTypeString(seq_abs->elements()[i]);
2645 } else {
2646 ss << ", " << BuilidArgsTypeString(seq_abs->elements()[i]);
2647 }
2648 }
2649 ss << ">";
2650 return ss.str();
2651 }
2652 return arg_type->ToString();
2653 }
2654
CheckAndConvertPrimitiveArgs(const PrimitivePtr & prim,const FuncGraphPtr & graph,const std::pair<std::vector<AnfNodePtr>,std::vector<AnfNodePtr>> & args_pair,const std::function<AbstractBasePtr (const AnfNodePtr &)> & eval_func,bool is_preprocessed)2655 CNodePtr CheckAndConvertPrimitiveArgs(const PrimitivePtr &prim, const FuncGraphPtr &graph,
2656 const std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> &args_pair,
2657 const std::function<AbstractBasePtr(const AnfNodePtr &)> &eval_func,
2658 bool is_preprocessed) {
2659 auto init_args_list = args_pair.first;
2660 auto call_args_list = args_pair.second;
2661 auto prim_name = prim->name();
2662 auto op_def = mindspore::ops::GetOpDef(prim_name);
2663 MS_EXCEPTION_IF_NULL(op_def);
2664 MS_EXCEPTION_IF_NULL(graph);
2665 // Check args size.
2666 std::vector<ops::OpInputArg> op_call_args;
2667 std::vector<ops::OpInputArg> op_init_args;
2668 auto op_args = op_def->args_;
2669 for (const auto &op_arg : op_args) {
2670 if (op_arg.as_init_arg_) {
2671 (void)op_init_args.emplace_back(op_arg);
2672 } else {
2673 (void)op_call_args.emplace_back(op_arg);
2674 }
2675 }
2676
2677 MS_LOG(DEBUG) << "For Primitive[" << prim_name << "], the number of init args is expected to be "
2678 << op_init_args.size() << ", and the number of call args is expected to be " << op_call_args.size();
2679 // Generate primitive default args.
2680 MS_LOG(DEBUG) << "For Primitive[ " << prim_name << "], before processing default args, the number of init args is "
2681 << init_args_list.size() << " and the number of call args is " << call_args_list.size();
2682 auto call_nodes = GeneratePrimitiveDefaultArgs(prim_name, call_args_list, op_call_args, false);
2683 auto init_nodes = GeneratePrimitiveDefaultArgs(prim_name, init_args_list, op_init_args, true);
2684 MS_LOG(DEBUG) << "For Primitive[ " << prim_name << "], after processing default args, the number of init args is "
2685 << init_args_list.size() << " and the number of call args is " << call_args_list.size();
2686 // If it is not preprocessed, signatures and need to be processed.
2687 if (!is_preprocessed) {
2688 // Process signatures.
2689 MS_LOG(DEBUG) << "Process signatures for Primitive[" << prim_name << "].";
2690 AbstractBasePtrList call_abs_list;
2691 (void)std::transform(call_nodes.cbegin(), call_nodes.cend(), std::back_inserter(call_abs_list), eval_func);
2692 call_nodes = prim::GetNewInputsBySignatures(graph, prim_name, prim, call_abs_list, call_nodes);
2693 // Process arg_handler.
2694 for (size_t i = 0; i < op_init_args.size(); i++) {
2695 auto abs_node = eval_func(init_nodes[i]);
2696 init_nodes[i] = GetNodeAfterArgHandler(init_nodes[i], prim_name, op_init_args[i], abs_node, graph);
2697 }
2698 }
2699 for (size_t i = 0; i < op_call_args.size(); i++) {
2700 auto abs_node = eval_func(call_nodes[i]);
2701 call_nodes[i] = GetNodeAfterArgHandler(call_nodes[i], prim_name, op_call_args[i], abs_node, graph);
2702 }
2703
2704 // Check args type and do type conversion.
2705 AbstractBasePtrList call_abs_list;
2706 AbstractBasePtrList init_abs_list;
2707 (void)std::transform(call_nodes.cbegin(), call_nodes.cend(), std::back_inserter(call_abs_list), eval_func);
2708 (void)std::transform(init_nodes.cbegin(), init_nodes.cend(), std::back_inserter(init_abs_list), eval_func);
2709 MS_LOG(DEBUG) << "For Primitive[" << prim_name << "], the number of init args is " << init_nodes.size()
2710 << " and the number of call args is " << call_nodes.size();
2711 if (!ValidateAndConvertArgsType(prim_name, op_call_args, call_abs_list, graph, &call_nodes) ||
2712 !ValidateAndConvertArgsType(prim_name, op_init_args, init_abs_list, graph, &init_nodes)) {
2713 std::vector<std::string> op_type_list;
2714 (void)std::transform(call_abs_list.cbegin(), call_abs_list.cend(), std::back_inserter(op_type_list),
2715 [](const AbstractBasePtr &op_abs) { return BuilidArgsTypeString(op_abs); });
2716 (void)std::transform(init_abs_list.cbegin(), init_abs_list.cend(), std::back_inserter(op_type_list),
2717 [](const AbstractBasePtr &op_abs) { return BuilidArgsTypeString(op_abs); });
2718 MS_EXCEPTION(TypeError) << ops::BuildOpErrorMsg(op_def, op_type_list);
2719 }
2720
2721 // Create New node.
2722 AnfNodePtrList input_nodes{NewValueNode(prim)};
2723 (void)std::copy(call_nodes.cbegin(), call_nodes.cend(), std::back_inserter(input_nodes));
2724 (void)std::copy(init_nodes.cbegin(), init_nodes.cend(), std::back_inserter(input_nodes));
2725 auto new_cnode = graph->NewCNodeInOrder(input_nodes);
2726 return new_cnode;
2727 }
2728
CheckAndConvertPrimitiveArgs(const PrimitivePtr & prim,const std::pair<std::vector<AnfNodePtr>,std::vector<AnfNodePtr>> & args_pair,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf,bool is_preprocessed)2729 AnfNodePtr CheckAndConvertPrimitiveArgs(const PrimitivePtr &prim,
2730 const std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> &args_pair,
2731 const AnalysisEnginePtr &engine, const AnfNodeConfigPtr &out_conf,
2732 bool is_preprocessed) {
2733 auto graph = out_conf->node()->func_graph();
2734 MS_EXCEPTION_IF_NULL(graph);
2735
2736 auto eval_func = [&engine, &out_conf](const AnfNodePtr &node) {
2737 AnfNodeConfigPtr config = engine->MakeConfig(node, out_conf->context(), out_conf->func_graph());
2738 MS_EXCEPTION_IF_NULL(config);
2739 const auto &eval_result = config->ObtainEvalResult();
2740 MS_EXCEPTION_IF_NULL(eval_result);
2741 return eval_result->abstract();
2742 };
2743
2744 auto new_cnode = CheckAndConvertPrimitiveArgs(prim, graph, args_pair, eval_func, is_preprocessed);
2745 MS_LOG(INFO) << "Convert primitive args: " << prim->name() << ". node: " << out_conf->node()->DebugString()
2746 << ", new_node: " << new_cnode->DebugString();
2747 return new_cnode;
2748 }
2749
ConvertArgsToInputs(const PrimitivePtr & prim,const AnfNodeWeakPtrList & inputs,const FuncGraphPtr & fg,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf)2750 AnfNodePtr ConvertArgsToInputs(const PrimitivePtr &prim, const AnfNodeWeakPtrList &inputs, const FuncGraphPtr &fg,
2751 const AnalysisEnginePtr &engine, const AnfNodeConfigPtr &out_conf) {
2752 // Append Primitive arguments to the inputs.
2753 auto prim_py = prim->cast<PrimitivePyPtr>();
2754 MS_EXCEPTION_IF_NULL(prim_py);
2755 auto op_def = mindspore::ops::GetOpDef(prim->name());
2756 MS_EXCEPTION_IF_NULL(op_def);
2757 // Get init args.
2758 const AnfNodePtrList &prim_init_arg_nodes = GetPrimitiveInitArgs(prim_py, op_def);
2759
2760 // Get call args.
2761 AnfNodePtrList prim_call_arg_nodes;
2762 (void)std::transform(inputs.cbegin() + 1, inputs.cend(), std::back_inserter(prim_call_arg_nodes),
2763 [](const AnfNodeWeakPtr &weak_node) {
2764 const auto &node = weak_node.lock();
2765 MS_EXCEPTION_IF_NULL(node);
2766 return node;
2767 });
2768 // Create new node.
2769 auto new_prim = std::make_shared<Primitive>(*prim);
2770 auto args_pair = std::make_pair(prim_init_arg_nodes, prim_call_arg_nodes);
2771 return CheckAndConvertPrimitiveArgs(new_prim, args_pair, engine, out_conf, true);
2772 }
2773 } // namespace
2774
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)2775 EvalResultPtr PrimitiveArgsToInputsEvaluator::EvalPrim(const AnalysisEnginePtr &engine,
2776 const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
2777 const AnfNodeConfigPtr &out_conf) {
2778 // Convert primitive args to inputs.
2779 MS_EXCEPTION_IF_NULL(out_conf);
2780 auto cnode = out_conf->node()->cast<CNodePtr>();
2781 MS_EXCEPTION_IF_NULL(cnode);
2782 auto fg = cnode->func_graph();
2783 MS_EXCEPTION_IF_NULL(fg);
2784
2785 constexpr size_t index_op = 0;
2786 constexpr size_t index_data = 1;
2787 auto op_node = cnode->input(index_op);
2788 AnfNodePtr new_node = nullptr;
2789 parse::SymbolPtr symbol_node = nullptr;
2790 if (op_node->isa<CNode>()) {
2791 auto inner_op_node = op_node->cast<CNodePtr>()->input(index_op);
2792 if (IsPrimitiveCNode(inner_op_node, prim::kPrimResolve)) {
2793 auto resolve_node = inner_op_node->cast<CNodePtr>();
2794 constexpr size_t index_symbol = 2;
2795 symbol_node = GetValueNode<parse::SymbolPtr>(resolve_node->input(index_symbol));
2796 }
2797 }
2798 if (IsPrimitiveCNode(op_node, prim::kPrimPartial)) {
2799 // The input may be a Partial node, such as {{prim::kPrimPartial, prim::kPrimRank, x}} -> {prim::kPrimRank, x}.
2800 AnfNodeWeakPtrList partial_inputs;
2801 auto op_cnode = op_node->cast<CNodePtr>();
2802 (void)std::copy(op_cnode->weak_inputs().begin() + index_data, op_cnode->weak_inputs().end(),
2803 std::back_inserter(partial_inputs));
2804 (void)std::copy(cnode->weak_inputs().begin() + index_data, cnode->weak_inputs().end(),
2805 std::back_inserter(partial_inputs));
2806 new_node = ConvertArgsToInputs(prim_, partial_inputs, fg, engine, out_conf);
2807 } else if (IsPrimitiveCNode(op_node, prim::kPrimGetAttr) ||
2808 IsPrimitiveCNodeWithoutDoSignature(op_node, prim::kPrimGetAttr) ||
2809 (symbol_node != nullptr && symbol_node->symbol() == "getattr")) {
2810 // The input may be a GetAttr node, such as x.abs(): {{prim::kPrimGetAttr, x, abs}} -> {prim::kPrimAbs, x}
2811 auto op_cnode = op_node->cast<CNodePtr>();
2812 AnfNodeWeakPtrList getattr_inputs;
2813 auto new_prim = std::make_shared<Primitive>(prim_->name());
2814 auto new_prim_node = NewValueNode(new_prim);
2815 (void)getattr_inputs.emplace_back(new_prim_node);
2816 (void)getattr_inputs.emplace_back(op_cnode->input(index_data));
2817 (void)std::copy(cnode->weak_inputs().begin() + index_data, cnode->weak_inputs().end(),
2818 std::back_inserter(getattr_inputs));
2819 new_node = ConvertArgsToInputs(prim_, getattr_inputs, fg, engine, out_conf);
2820 } else {
2821 constexpr int recursive_level = 2;
2822 new_node = ConvertArgsToInputs(prim_, cnode->weak_inputs(), fg, engine, out_conf);
2823 MS_LOG(DEBUG) << "Convert args to inputs for Operator[" << prim_->name()
2824 << "], node: " << cnode->DebugString(recursive_level);
2825 }
2826
2827 new_node->set_debug_info(cnode->debug_info());
2828 auto new_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
2829 MS_LOG(INFO) << "Convert primitive args to inputs: " << prim_->ToString() << ". node: " << cnode->DebugString()
2830 << ", new node: " << new_node->DebugString();
2831 return engine->ForwardConfig(out_conf, new_conf);
2832 }
2833
2834 namespace {
ConvertWeakNode(const AnfNodeWeakPtr & weak_node)2835 AnfNodePtr ConvertWeakNode(const AnfNodeWeakPtr &weak_node) {
2836 const auto &node = weak_node.lock();
2837 MS_EXCEPTION_IF_NULL(node);
2838 return node;
2839 }
2840 } // namespace
2841
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)2842 EvalResultPtr DoTransPrimitiveFunctionEvaluator::EvalPrim(const AnalysisEnginePtr &engine,
2843 const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
2844 const AnfNodeConfigPtr &out_conf) {
2845 // For PrimitiveFunction generated by CreateInstance, its args, labels, signatures and
2846 // implicit conversion need to be processed.
2847 auto do_trans_prim_func = prim_->cast<prim::DoTransPrimitiveFunctionPtr>();
2848 MS_EXCEPTION_IF_NULL(do_trans_prim_func);
2849 auto prim_func = do_trans_prim_func->function();
2850 MS_EXCEPTION_IF_NULL(prim_func);
2851 auto cnode = out_conf->node()->cast<CNodePtr>();
2852 MS_EXCEPTION_IF_NULL(cnode);
2853 auto fg = cnode->func_graph();
2854 MS_EXCEPTION_IF_NULL(fg);
2855
2856 auto prim_name = prim_func->name();
2857 auto op_def = mindspore::ops::GetOpDef(prim_name);
2858 if (op_def == nullptr) {
2859 MS_LOG(INTERNAL_EXCEPTION) << "DoTransPrimitiveFunction only supports Primitive with OpDef, but got " << prim_name
2860 << ".";
2861 }
2862 if (cnode->size() != args_abs_list.size() + 1) {
2863 MS_LOG(INTERNAL_EXCEPTION) << "For Operator[" << prim_name << "], the number of cnode inputs should be "
2864 << args_abs_list.size() + 1 << ", but got " << cnode->size()
2865 << ".\nnode: " << cnode->DebugString();
2866 }
2867 // Handle primitive labels.
2868 AddLabelsToPrimitiveFunction(prim_func);
2869 // Handle primitive signatures.
2870 auto arg_signatures = op_def->signatures_;
2871 prim_func->set_signatures(arg_signatures);
2872 prim_func->set_has_signature(!arg_signatures.empty());
2873 // Get init args size.
2874 size_t init_args_size = 0;
2875 if (do_trans_prim_func->has_given_init_size()) {
2876 // Might need to handle default arguments.
2877 init_args_size = do_trans_prim_func->given_init_size();
2878 } else {
2879 // All call args and init args should have been provided.
2880 size_t op_args_size = op_def->args_.size();
2881 if (op_args_size != args_abs_list.size()) {
2882 MS_EXCEPTION(TypeError) << "For Operator['" << prim_name
2883 << "]', the number of inputs and init args (including default arguments) should be "
2884 << op_args_size << ", but got " << args_abs_list.size() << ". ";
2885 }
2886 for (size_t i = 0; i < op_args_size; i++) {
2887 if (op_def->args_[i].as_init_arg_) {
2888 ++init_args_size;
2889 }
2890 }
2891 }
2892
2893 // Get init args and call args.
2894 AnfNodePtrList prim_init_arg_nodes;
2895 (void)std::transform(cnode->weak_inputs().cbegin() + cnode->size() - init_args_size, cnode->weak_inputs().cend(),
2896 std::back_inserter(prim_init_arg_nodes), ConvertWeakNode);
2897 AnfNodePtrList prim_call_arg_nodes;
2898 (void)std::transform(cnode->weak_inputs().cbegin() + 1, cnode->weak_inputs().cend() - init_args_size,
2899 std::back_inserter(prim_call_arg_nodes), ConvertWeakNode);
2900
2901 auto args_pair = std::make_pair(prim_init_arg_nodes, prim_call_arg_nodes);
2902 auto new_cnode = CheckAndConvertPrimitiveArgs(prim_func, args_pair, engine, out_conf, false);
2903 auto new_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
2904 MS_LOG(INFO) << "Prim: " << prim_func->name() << ", " << cnode->DebugString() << ", " << new_cnode->DebugString();
2905 return engine->ForwardConfig(out_conf, new_conf);
2906 }
2907
GetInitArgsFromUnpackCall(const prim::DoTransPrimitiveFunctionPtr & do_trans_prim,const CNodePtr & unpack_call_cnode,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf)2908 AnfNodePtrList GetInitArgsFromUnpackCall(const prim::DoTransPrimitiveFunctionPtr &do_trans_prim,
2909 const CNodePtr &unpack_call_cnode, const AnalysisEnginePtr &engine,
2910 const AnfNodeConfigPtr &out_conf) {
2911 auto prim = do_trans_prim->function();
2912 auto op_def = mindspore::ops::GetOpDef(prim->name());
2913 MS_EXCEPTION_IF_NULL(op_def);
2914
2915 AnfNodePtrList new_inputs;
2916 std::map<std::string, AnfNodePtr> key_map;
2917 auto fg = out_conf->node()->func_graph();
2918 constexpr size_t inputs_start_index = 2;
2919 for (size_t index = inputs_start_index; index < unpack_call_cnode->size(); index++) {
2920 auto input = unpack_call_cnode->input(index);
2921 AnfNodeConfigPtr config = engine->MakeConfig(input, out_conf->context(), out_conf->func_graph());
2922 MS_EXCEPTION_IF_NULL(config);
2923 const auto &eval_result = config->ObtainEvalResult();
2924 MS_EXCEPTION_IF_NULL(eval_result);
2925 auto input_abs = eval_result->abstract();
2926 if (input_abs->isa<AbstractDictionary>()) {
2927 auto dict_elems = input_abs->cast<AbstractDictionaryPtr>()->elements();
2928 for (const auto &elem : dict_elems) {
2929 auto key = GetValue<std::string>(elem.first->BuildValue());
2930 auto elem_value = fg->NewCNode({NewValueNode(prim::kPrimDictGetItem), input, NewValueNode(key)});
2931 key_map[key] = elem_value;
2932 }
2933 } else if (input_abs->isa<AbstractTuple>()) {
2934 auto arg_tuple = input_abs->cast<AbstractTuplePtr>();
2935 for (size_t i = 0; i < arg_tuple->size(); ++i) {
2936 MS_LOG(DEBUG) << "Get args for Primitive[" << prim->name() << "]: " << input->DebugString() << ", i: " << i;
2937 (void)new_inputs.emplace_back(
2938 fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(SizeToLong(i))}));
2939 }
2940 } else if (input_abs->isa<AbstractList>()) {
2941 auto arg_list = input_abs->cast<AbstractListPtr>();
2942 for (size_t i = 0; i < arg_list->size(); ++i) {
2943 MS_LOG(DEBUG) << "Get args for Primitive[" << prim->name() << "]: " << input->DebugString() << ", i: " << i;
2944 (void)new_inputs.emplace_back(
2945 fg->NewCNode({NewValueNode(prim::kPrimListGetItem), input, NewValueNode(SizeToLong(i))}));
2946 }
2947 } else {
2948 MS_LOG(INTERNAL_EXCEPTION) << "The arguments of UnpackCall operator should be tuple, list or dict, but got "
2949 << input_abs->ToString();
2950 }
2951 }
2952
2953 // Handle variable arguments.
2954 auto op_args = op_def->args_;
2955 auto inputs_size = new_inputs.size();
2956 size_t index = 0;
2957 size_t init_args_num = 0;
2958 for (const auto &op_arg : op_args) {
2959 if (!(op_arg.as_init_arg_)) {
2960 continue;
2961 }
2962 init_args_num++;
2963 if (index < inputs_size) {
2964 index++;
2965 continue;
2966 }
2967 auto arg_name = op_arg.arg_name_;
2968 auto iter = key_map.find(arg_name);
2969 if (iter != key_map.end()) {
2970 MS_LOG(DEBUG) << "Get args for Primitive[" << prim->name() << "]: " << iter->second->DebugString();
2971 (void)new_inputs.emplace_back(iter->second);
2972 (void)key_map.erase(arg_name);
2973 } else {
2974 auto default_value = parse::GetArgDefaultValue(prim->name(), arg_name);
2975 if (default_value == nullptr) {
2976 MS_EXCEPTION(TypeError) << "For Operator[" << prim->name() << "], there is no matching input for argument '"
2977 << arg_name << "'.";
2978 }
2979 MS_LOG(DEBUG) << "Get args for Primitive[" << prim->name() << "]: " << default_value->ToString();
2980 (void)new_inputs.emplace_back(NewValueNode(default_value));
2981 }
2982 }
2983 if (init_args_num < new_inputs.size()) {
2984 MS_EXCEPTION(TypeError) << "For Operator[" << prim->name() << "], the number of init arguments should be "
2985 << init_args_num << ", but got " << new_inputs.size() << ".";
2986 }
2987 if (!key_map.empty()) {
2988 std::stringstream ss;
2989 ss << "For Operator[" << prim->name() << "], there are unmatched arguments: ";
2990 for (const auto &elem : key_map) {
2991 ss << elem.first << " ";
2992 }
2993 ss << ".";
2994 MS_EXCEPTION(TypeError) << ss.str();
2995 }
2996 do_trans_prim->set_given_init_size(new_inputs.size());
2997 return new_inputs;
2998 }
2999
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3000 EvalResultPtr PartialToEndEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
3001 const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3002 // Convert Partial{Prim, a, b}(x, y) to {Prim, x, y, a, b}.
3003 auto prim = primal_func_->BuildValue();
3004 MS_EXCEPTION_IF_NULL(prim);
3005 AnfNodePtrList new_inputs{NewValueNode(prim)};
3006 auto do_trans_prim = prim->cast<prim::DoTransPrimitiveFunctionPtr>();
3007 MS_EXCEPTION_IF_NULL(do_trans_prim);
3008 // Add inputs: x, y.
3009 MS_EXCEPTION_IF_NULL(out_conf);
3010 auto cnode = out_conf->node()->cast<CNodePtr>();
3011 MS_EXCEPTION_IF_NULL(cnode);
3012 for (size_t i = 1; i < cnode->size(); i++) {
3013 (void)new_inputs.emplace_back(cnode->input(i));
3014 }
3015 // Add args: a, b.
3016 constexpr size_t op_index = 0;
3017 auto partial_node = cnode->input(op_index);
3018 MS_EXCEPTION_IF_NULL(partial_node);
3019 auto partial_cnode = partial_node->cast<CNodePtr>();
3020 if (partial_cnode == nullptr) {
3021 MS_EXCEPTION(TypeError) << "For Primitive[" << prim->ToString()
3022 << "], only positional arguments as inputs are supported, but got "
3023 << partial_node->DebugString() << ".";
3024 }
3025 if (IsValueNode<prim::UnpackCall>(partial_cnode->input(op_index))) {
3026 auto unpack_call_args = GetInitArgsFromUnpackCall(do_trans_prim, partial_cnode, engine, out_conf);
3027 (void)std::copy(unpack_call_args.cbegin(), unpack_call_args.cend(), std::back_inserter(new_inputs));
3028 } else {
3029 (void)std::transform(partial_cnode->weak_inputs().cbegin() + 1, partial_cnode->weak_inputs().cend(),
3030 std::back_inserter(new_inputs), [](const auto &weak_node) {
3031 const auto &node = weak_node.lock();
3032 MS_EXCEPTION_IF_NULL(node);
3033 return node;
3034 });
3035 }
3036
3037 auto fg = cnode->func_graph();
3038 MS_EXCEPTION_IF_NULL(fg);
3039 auto new_cnode = fg->NewCNodeInOrder(new_inputs);
3040 auto new_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
3041 constexpr auto recursive_level = 2;
3042 MS_LOG(INFO) << "For Primitive[" << prim->ToString() << "], convert partial node "
3043 << cnode->DebugString(recursive_level) << " to new cnode " << new_cnode->DebugString(recursive_level);
3044 return engine->ForwardConfig(out_conf, new_conf);
3045 }
3046
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3047 EvalResultPtr ConstexprEvaluator::EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
3048 const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3049 // Consider all primitive implemented python infer() real use the tuple/list arguments.
3050 CheckSequenceArgumentForPythonPrimitive(prim_py_, args_abs_list);
3051 MS_EXCEPTION_IF_NULL(prim_py_);
3052 auto py_args = PreparePyInputs(args_abs_list);
3053 prim_py_->BeginRecordAddAttr();
3054 py::dict output = prim_py_->RunInfer(py_args);
3055 prim_py_->EndRecordAddAttr();
3056 if (output.contains("fn")) {
3057 // The inputs contain variable, the constexpr will run as graph.
3058 py::tuple values = output["fn"];
3059 if (values.empty()) {
3060 MS_LOG(EXCEPTION) << "Can not get origin function from constexpr.";
3061 }
3062 auto inner_val = parse::ParsePythonCode(values[0]);
3063 MS_EXCEPTION_IF_NULL(inner_val);
3064 auto inner_fg = dyn_cast<FuncGraph>(inner_val);
3065 MS_EXCEPTION_IF_NULL(inner_fg);
3066 MS_EXCEPTION_IF_NULL(out_conf);
3067 auto cur_graph = out_conf->func_graph();
3068 MS_EXCEPTION_IF_NULL(cur_graph);
3069 auto mng = cur_graph->manager();
3070 MS_EXCEPTION_IF_NULL(mng);
3071 inner_fg->set_manager(mng);
3072 auto out_node = out_conf->node();
3073 MS_EXCEPTION_IF_NULL(out_node);
3074 auto out_cnode = dyn_cast<CNode>(out_node);
3075 MS_EXCEPTION_IF_NULL(out_cnode);
3076 FuncGraphPtr func_graph = out_node->func_graph();
3077 MS_EXCEPTION_IF_NULL(func_graph);
3078 std::vector<AnfNodePtr> new_cnode_inputs = {NewValueNode(inner_fg)};
3079 const auto &out_cnode_inputs = out_cnode->weak_inputs();
3080 (void)std::transform(out_cnode_inputs.cbegin() + 1, out_cnode_inputs.cend(), std::back_inserter(new_cnode_inputs),
3081 [](const auto &weak_node) {
3082 const auto &node = weak_node.lock();
3083 MS_EXCEPTION_IF_NULL(node);
3084 return node;
3085 });
3086 auto new_node = func_graph->NewCNodeInOrder(new_cnode_inputs);
3087 AnalysisEnginePtr eng = out_conf->engine();
3088 MS_EXCEPTION_IF_NULL(eng);
3089 AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
3090 return eng->ForwardConfig(out_conf, fn_conf);
3091 }
3092 // If all inputs are constant value, use python prim evaluator.
3093 // Ensure input arguments are evaluated.
3094 auto res_abstract = EvalUndeterminedArgs(args_abs_list);
3095 if (res_abstract != nullptr) {
3096 MS_LOG(DEBUG) << "PythonPrimEvaluator eval Undetermined";
3097 return res_abstract;
3098 }
3099 auto forbid_reuse = prim_py_->HasAttr(GRAPH_FLAG_FORBID_REUSE_RESULT);
3100 if (!forbid_reuse) {
3101 // Try to get infer result from evaluator cache.
3102 EvalResultPtr eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
3103 if (eval_result != nullptr) {
3104 MS_EXCEPTION_IF_NULL(eval_result->abstract());
3105 return std::make_shared<EvalResult>(eval_result->abstract()->Clone(), eval_result->attribute());
3106 }
3107 }
3108 const auto &added_attrs = prim_py_->evaluate_added_attrs();
3109 MS_LOG(DEBUG) << "Output type is " << py::str(output);
3110 auto res_abs = PyInferRes2Abstract(prim_py_, output);
3111 MS_EXCEPTION_IF_NULL(res_abs);
3112 MS_LOG(DEBUG) << "Python InferTensor result abstract: " << res_abs->ToString();
3113 EvalResultPtr eval_result = std::make_shared<EvalResult>(res_abs, std::make_shared<AttrValueMap>(added_attrs));
3114 evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
3115 return eval_result;
3116 }
3117
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3118 EvalResultPtr MakeTupleEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
3119 const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3120 std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
3121 auto abs = std::make_shared<AbstractTuple>(args_abs_list, sequence_nodes);
3122 if (out_conf != nullptr) { // 'out_conf' maybe nullptr in PyNative mode.
3123 if (args_abs_list.empty()) {
3124 MS_EXCEPTION_IF_NULL(out_conf->node());
3125 MS_LOG(INFO) << "For MakeTuple, the inputs should not be empty. node: " << out_conf->node()->DebugString();
3126 }
3127 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
3128 if (enable_eliminate_unused_element) {
3129 auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
3130 if (flags == nullptr) {
3131 SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_abs_list.size()));
3132 }
3133 bool has_any = fallback::ContainsSequenceAnyType(abs);
3134 if (has_any) {
3135 SetSequenceElementsUseFlagsRecursively(abs, true);
3136 }
3137 (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(out_conf->node()));
3138 }
3139 }
3140 auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
3141 evaluator_cache_mgr_->SetValue(args_abs_list, res);
3142 // pass the need_unpack tag from the AnfNode to the abstract
3143 if (out_conf != nullptr) {
3144 auto node = out_conf->node();
3145 constexpr auto need_unpack_str = "need_unpack";
3146 auto need_unpack = node->user_data<bool>(need_unpack_str);
3147 if (need_unpack != nullptr && *need_unpack) {
3148 abs->SetData<bool>(need_unpack_str, std::make_shared<bool>(true));
3149 }
3150 }
3151 return res;
3152 }
3153
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3154 EvalResultPtr MakeListEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
3155 const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3156 std::shared_ptr<AnfNodeWeakPtrList> sequence_nodes = std::make_shared<AnfNodeWeakPtrList>();
3157 auto abs = std::make_shared<AbstractList>(args_abs_list, sequence_nodes);
3158 if (out_conf != nullptr) { // 'out_conf' maybe nullptr in PyNative mode.
3159 if (args_abs_list.empty()) {
3160 MS_LOG(INFO) << "For MakeList, the inputs should not be empty. node: " << out_conf->node()->DebugString();
3161 }
3162 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
3163 if (enable_eliminate_unused_element) {
3164 auto flags = GetSequenceNodeElementsUseFlags(out_conf->node());
3165 if (flags == nullptr) {
3166 SetSequenceNodeElementsUseFlags(out_conf->node(), std::make_shared<std::vector<bool>>(args_abs_list.size()));
3167 }
3168
3169 (void)sequence_nodes->emplace_back(AnfNodeWeakPtr(out_conf->node()));
3170 bool has_any = fallback::ContainsSequenceAnyType(abs);
3171 if (has_any) {
3172 SetSequenceElementsUseFlagsRecursively(abs, true);
3173 }
3174 }
3175 }
3176 MS_LOG(DEBUG) << "Generate python object for new value node.";
3177 if (fallback::EnableFallbackListDictInplace()) {
3178 py::object py_list_obj = fallback::GeneratePyObj(abs);
3179 fallback::AttachPyObjToAbs(abs, py_list_obj, true);
3180 }
3181 auto res = std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
3182 evaluator_cache_mgr_->SetValue(args_abs_list, res);
3183 return res;
3184 }
3185
CreateRealAbstract(const TypePtr & preset_type,const BaseShapePtr & shape,const AnfNodePtr & node,const AbstractBasePtrList & args_abs_list)3186 AbstractBasePtr CreateRealAbstract(const TypePtr &preset_type, const BaseShapePtr &shape, const AnfNodePtr &node,
3187 const AbstractBasePtrList &args_abs_list) {
3188 AbstractBasePtr res = nullptr;
3189 if (preset_type->isa<Scalar>()) {
3190 res = std::make_shared<AbstractScalar>(preset_type);
3191 } else if (preset_type->isa<List>() || preset_type->isa<Tuple>()) {
3192 res = fallback::GenerateAbstractSequence(shape, preset_type, true);
3193 } else if (preset_type->isa<TensorType>() && !preset_type->isa<AnyType>()) {
3194 auto tensor_type = preset_type->cast_ptr<TensorType>();
3195 MS_EXCEPTION_IF_NULL(tensor_type);
3196 auto element = std::make_shared<abstract::AbstractScalar>(kValueAny, tensor_type->element());
3197 res = std::make_shared<abstract::AbstractTensor>(element, shape);
3198 auto abs_tensor = res->cast_ptr<abstract::AbstractTensor>();
3199 if (node->has_user_data(fallback::kIsAdapter)) {
3200 abs_tensor->set_is_adapter(true);
3201 }
3202 } else {
3203 const auto any_abstract = std::make_shared<AbstractAny>();
3204 // If no annotation dtype, try to use unique tensor dtype.
3205 auto dtype = GetLocalArgsUniqueDtype(node, args_abs_list);
3206 if (dtype != nullptr) {
3207 MS_EXCEPTION_IF_NULL(any_abstract->element());
3208 any_abstract->element()->set_type(dtype);
3209 any_abstract->set_supposed_tensor_dtype(true);
3210 }
3211 res = any_abstract;
3212 }
3213 return res;
3214 }
3215
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3216 EvalResultPtr PyExecuteEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list,
3217 const ConfigPtr &, const AnfNodeConfigPtr &out_conf) {
3218 MS_EXCEPTION_IF_NULL(out_conf);
3219 if (args_abs_list.empty()) {
3220 MS_LOG(INTERNAL_EXCEPTION) << "'args_abs_list' should not be empty";
3221 }
3222
3223 // Handle for DDE.
3224 for (size_t i = 0; i < args_abs_list.size(); ++i) {
3225 MS_EXCEPTION_IF_NULL(args_abs_list[i]);
3226 if (args_abs_list[i]->isa<abstract::AbstractSequence>()) {
3227 MS_LOG(DEBUG) << "Primitive \'PyExecute\' is consuming tuple/list arguments[" << i
3228 << "]: " << args_abs_list[i]->ToString();
3229 SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
3230 }
3231 }
3232
3233 auto node = out_conf->node();
3234 MS_EXCEPTION_IF_NULL(node);
3235 MS_LOG(DEBUG) << "The current pyexecute node: " << node->DebugString();
3236 // Get the type parameter.
3237 MS_EXCEPTION_IF_NULL(args_abs_list[0]);
3238 ValuePtr script_value_track = args_abs_list[0]->GetValueTrack();
3239 MS_EXCEPTION_IF_NULL(script_value_track);
3240 auto script_obj = dyn_cast_ptr<StringImm>(script_value_track);
3241 if (script_obj == nullptr) {
3242 MS_LOG(INTERNAL_EXCEPTION) << "Cast value failed, not PyObjectWrapper: " << script_value_track->ToString() << ".";
3243 }
3244
3245 // Make global and local parameters.
3246 const std::string &script = script_obj->value();
3247 // Call python script string.
3248 MS_LOG(DEBUG) << "Call script: " << script << ", args: " << args_abs_list;
3249 // Make abstract by type and shape.
3250 AbstractBasePtr res = nullptr;
3251 // Support Tensor annotation type. Add list and tuple here later.
3252 TypePtr dtype = nullptr;
3253 TypePtr type = GetAnnotationType(node, args_abs_list);
3254 if (type != nullptr && type->isa<TensorType>()) {
3255 dtype = type->cast<TensorTypePtr>()->element();
3256 }
3257 // Create output abstract.
3258 if (dtype != nullptr) {
3259 res = std::make_shared<AbstractTensor>(dtype, std::make_shared<Shape>(ShapeVector({Shape::kShapeRankAny})));
3260 } else if (fallback::HasRealType(node) && fallback::HasRealShape(node)) {
3261 const auto &preset_type = fallback::GetRealType<AnfNode, Type>(node);
3262 MS_LOG(DEBUG) << "preset_type: " << preset_type->ToString();
3263 const auto &shape = fallback::GetRealShape<AnfNode, BaseShape>(node);
3264 MS_LOG(DEBUG) << "shape: " << shape->ToString();
3265 res = CreateRealAbstract(preset_type, shape, node, args_abs_list);
3266 } else if (fallback::HasRealType(node) && fallback::GetRealType<AnfNode, Type>(node)->isa<NegligibleType>()) {
3267 res = std::make_shared<AbstractNegligible>();
3268 } else {
3269 const auto any_abstract = std::make_shared<AbstractAny>();
3270 // If no annotation dtype, try to use unique tensor dtype.
3271 dtype = GetLocalArgsUniqueDtype(node, args_abs_list);
3272 if (dtype != nullptr) {
3273 MS_EXCEPTION_IF_NULL(any_abstract->element());
3274 any_abstract->element()->set_type(dtype);
3275 any_abstract->set_supposed_tensor_dtype(true);
3276 }
3277 res = any_abstract;
3278 }
3279
3280 // Set input real type and shape for caller.
3281 if (fallback::HasRealType(node)) {
3282 const auto &real_type = fallback::GetRealType<AnfNode, Type>(node);
3283 fallback::SetRealType<AbstractBase, Type>(res, real_type);
3284 }
3285 if (fallback::HasRealShape(node)) {
3286 const auto &real_shape = fallback::GetRealShape<AnfNode, BaseShape>(node);
3287 fallback::SetRealShape<AbstractBase, BaseShape>(res, real_shape);
3288 }
3289 if (res->isa<AbstractTensor>() && node->has_user_data(fallback::kAdapterTensor) &&
3290 *node->user_data<bool>(fallback::kAdapterTensor)) {
3291 auto res_tensor = res->cast<AbstractTensorPtr>();
3292 res_tensor->set_is_adapter(true);
3293 }
3294 auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3295 evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3296 return infer_result;
3297 }
3298
3299 namespace {
3300 class PyInterpretEvaluator : public TransitionPrimEvaluator {
3301 public:
PyInterpretEvaluator()3302 PyInterpretEvaluator() : TransitionPrimEvaluator("PyInterpretEvaluator") {}
3303 ~PyInterpretEvaluator() override = default;
3304 MS_DECLARE_PARENT(PyInterpretEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3305 EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
3306 const AnfNodeConfigPtr &out_conf) override {
3307 if (args_abs_list.empty()) {
3308 MS_LOG(INTERNAL_EXCEPTION) << "'args_abs_list' should not be empty";
3309 }
3310 auto node = out_conf->node();
3311 MS_EXCEPTION_IF_NULL(node);
3312 MS_LOG(DEBUG) << "The current interpret node: " << node->DebugString();
3313
3314 // If the interpret node contains FuncGraph node input, need to convert the Graph node to Interpreted object.
3315 AnfNodePtr converted_interpret_node = ConvertPyInterpretNode(node, args_abs_list);
3316 if (converted_interpret_node != nullptr) {
3317 AnalysisEnginePtr eng = out_conf->engine();
3318 MS_EXCEPTION_IF_NULL(eng);
3319 AnfNodeConfigPtr fn_conf = eng->MakeConfig(converted_interpret_node, out_conf->context(), out_conf->func_graph());
3320 return eng->ForwardConfig(out_conf, fn_conf);
3321 }
3322
3323 non_const_err_ = false;
3324 check_list_dict_inplace_ =
3325 node->has_user_data(fallback::kCheckListDictInplace) && *node->user_data<bool>(fallback::kCheckListDictInplace);
3326
3327 constexpr size_t script_index = 0;
3328 const std::string &script = GetScriptStr(args_abs_list[script_index]);
3329 // Make global and local parameters.
3330 py::tuple params = MakeParameters(args_abs_list, script);
3331 // Would convert PyInterpret to PyExecute then.
3332 if (non_const_err_ || fallback::GetJitAnnotationSideEffectFromComment(node)) {
3333 // Make abstract by type and shape.
3334 AbstractBasePtr res = nullptr;
3335 // Support Tensor annotation type. Add list and tuple here later.
3336 TypePtr dtype = nullptr;
3337 TypePtr type = GetAnnotationType(node, args_abs_list);
3338 if (type != nullptr && type->isa<TensorType>()) {
3339 dtype = type->cast<TensorTypePtr>()->element();
3340 }
3341 // Create output abstract.
3342 if (dtype != nullptr) {
3343 res = std::make_shared<AbstractTensor>(dtype, std::make_shared<Shape>(ShapeVector({Shape::kShapeRankAny})));
3344 } else {
3345 const auto any_abstract = std::make_shared<AbstractAny>();
3346 // If no annotation dtype, try to use unique tensor dtype.
3347 dtype = GetLocalArgsUniqueDtype(node, args_abs_list);
3348 if (dtype != nullptr) {
3349 MS_EXCEPTION_IF_NULL(any_abstract->element());
3350 any_abstract->element()->set_type(dtype);
3351 any_abstract->set_supposed_tensor_dtype(true);
3352 }
3353 res = any_abstract;
3354 }
3355 auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3356 evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3357 return infer_result;
3358 }
3359
3360 // Call python script string.
3361 MS_LOG(DEBUG) << "Call script: " << script << ", params: " << py::str(params);
3362 auto obj = parse::data_converter::CallPythonScript(py::str(script), params);
3363 if (py::isinstance<py::none>(obj)) {
3364 AbstractBasePtr res = std::make_shared<abstract::AbstractNone>();
3365 auto infer_result = std::make_shared<EvalResult>(res, nullptr);
3366 evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3367 return infer_result;
3368 }
3369
3370 ValuePtr converted_val = nullptr;
3371 bool converted = false;
3372 // converted_val could be a InterpretedObject.
3373 if (node->has_user_data("__keep_metafg_obj_flag__")) {
3374 converted_val = std::make_shared<parse::InterpretedObject>(obj);
3375 converted = true;
3376 } else {
3377 converted = parse::ConvertData(obj, &converted_val, true);
3378 }
3379 if (!converted) {
3380 MS_LOG(INTERNAL_EXCEPTION) << "Convert the python object failed";
3381 }
3382 MS_EXCEPTION_IF_NULL(converted_val);
3383 auto fg = node->func_graph();
3384 MS_EXCEPTION_IF_NULL(fg);
3385 auto mng = fg->manager();
3386 MS_EXCEPTION_IF_NULL(mng);
3387 AddManagerForFuncGraphValue(converted_val, mng);
3388 if (converted_val->isa<tensor::Tensor>() && HasConstArgAttr(obj)) {
3389 MS_LOG(WARNING) << "The tensor " << converted_val->ToString()
3390 << " which is not used for network input argument should not be set const.";
3391 }
3392 if (converted_val->isa<parse::InterpretedObject>()) {
3393 const auto interpreted_value = dyn_cast<parse::InterpretedObject>(converted_val);
3394 MS_LOG(DEBUG) << "The InterpretedObject(" << converted_val->ToString() << ") is converted by PyInterpret"
3395 << " node: " << node->DebugString();
3396 interpreted_value->set_has_converted(true);
3397 }
3398
3399 AbstractBasePtr res = ToAbstract(converted_val, AnalysisContext::DummyContext(), out_conf);
3400 auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3401 evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3402 return infer_result;
3403 }
3404
AddManagerForFuncGraphValue(const ValuePtr & val,const FuncGraphManagerPtr & mng) const3405 void AddManagerForFuncGraphValue(const ValuePtr &val, const FuncGraphManagerPtr &mng) const {
3406 // mng has been checked before using.
3407 MS_EXCEPTION_IF_NULL(val);
3408 if (val->isa<ValueSequence>()) {
3409 auto val_seq = val->cast<ValueSequencePtr>();
3410 const auto &values = val_seq->value();
3411 std::for_each(values.begin(), values.end(),
3412 [this, &mng](const ValuePtr &e) { AddManagerForFuncGraphValue(e, mng); });
3413 return;
3414 }
3415 if (val->isa<ValueDictionary>()) {
3416 auto val_dict = val->cast<ValueDictionaryPtr>();
3417 const auto &values = val_dict->value();
3418 std::for_each(values.begin(), values.end(), [this, &mng](const std::pair<ValuePtr, ValuePtr> &pair) {
3419 // Key for value dictionary can not have function graph.
3420 AddManagerForFuncGraphValue(pair.second, mng);
3421 });
3422 return;
3423 }
3424 if (val->isa<FuncGraph>()) {
3425 auto val_fg = val->cast<FuncGraphPtr>();
3426 if (val_fg->manager() == nullptr) {
3427 mng->AddFuncGraph(val_fg);
3428 val_fg->set_manager(mng);
3429 }
3430 }
3431 return;
3432 }
3433
CheckInterpretInput(const AbstractDictionaryPtr & abstract_dict,const std::string & script) const3434 void CheckInterpretInput(const AbstractDictionaryPtr &abstract_dict, const std::string &script) const {
3435 // Check whether this node should be interpretive executed.
3436 MS_EXCEPTION_IF_NULL(abstract_dict);
3437 const auto &elements = abstract_dict->elements();
3438 if (elements.empty()) {
3439 return;
3440 }
3441 for (const auto &element : elements) {
3442 const auto &name = element.first;
3443 const auto &local_abs = element.second;
3444 MS_EXCEPTION_IF_NULL(local_abs);
3445 const auto &local_abs_val = local_abs->BuildValue();
3446 MS_EXCEPTION_IF_NULL(local_abs_val);
3447 MS_EXCEPTION_IF_NULL(name);
3448 auto py_data_name = py::str(ValueToPyData(name->BuildValue()));
3449 bool has_python_obj = check_list_dict_inplace_ && fallback::HasObjInExtraInfoHolder(local_abs);
3450 if (local_abs_val->ContainsValueAny() || has_python_obj) {
3451 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
3452 if (allow_fallback_runtime) {
3453 MS_LOG(INFO) << "When using JIT Fallback to handle script '" << script
3454 << "', the inputs should be constant, but found variable '" << py_data_name
3455 << "' to be nonconstant. To convert to PyExecute() afterwards";
3456 non_const_err_ = true;
3457 } else {
3458 MS_EXCEPTION(ValueError) << "When handling script '" << script << " in graph mode"
3459 << "', the inputs should be constant, but found variable '" << py_data_name
3460 << "' to be nonconstant. Try to set jit_syntax_level to LAX.";
3461 }
3462 }
3463 }
3464 }
3465
AddGlobalPythonFunction(const AbstractDictionaryPtr & global_dict,py::object * global_params_dict) const3466 void AddGlobalPythonFunction(const AbstractDictionaryPtr &global_dict, py::object *global_params_dict) const {
3467 MS_EXCEPTION_IF_NULL(global_dict);
3468 MS_EXCEPTION_IF_NULL(global_params_dict);
3469 const auto &global_dict_elements = global_dict->elements();
3470 for (const auto &element : global_dict_elements) {
3471 const auto &element_name = element.first;
3472 const auto &element_abs = element.second;
3473 MS_EXCEPTION_IF_NULL(element_name);
3474 MS_EXCEPTION_IF_NULL(element_abs);
3475 const auto &fn_py_obj = fallback::GetPyObjForFuncGraphAbstractClosure(element_abs);
3476 if (!py::isinstance<py::none>(fn_py_obj)) {
3477 (*global_params_dict)[ValueToPyData(element_name->BuildValue())] = fn_py_obj;
3478 MS_LOG(DEBUG) << "Found global python function object for " << element_name << ", add it to global dict.";
3479 }
3480 }
3481 return;
3482 }
3483
MakeParameters(const AbstractBasePtrList & args_abs_list,const std::string & script) const3484 py::tuple MakeParameters(const AbstractBasePtrList &args_abs_list, const std::string &script) const {
3485 constexpr int params_size = 3;
3486 auto args_size = std::count_if(args_abs_list.begin(), args_abs_list.end(),
3487 [](const AbstractBasePtr &arg) -> bool { return !arg->isa<AbstractMonad>(); });
3488 if (params_size != args_size) {
3489 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected params_size: " << params_size
3490 << ", not equal to arguments.size: " << args_abs_list.size();
3491 }
3492 // The first argument is script string, ignore it.
3493 auto params = py::tuple(params_size - 1);
3494
3495 // Make the global parameters.
3496 constexpr size_t global_index = 1;
3497 auto global_abs = args_abs_list[global_index];
3498 const py::object &global_params_dict = GetGlobalObject(global_abs);
3499 params[0] = global_params_dict;
3500
3501 // Make the local parameters.
3502 constexpr size_t local_index = 2;
3503 auto local_dict = dyn_cast<AbstractDictionary>(args_abs_list[local_index]); // Local parameters dict.
3504 if (local_dict == nullptr) {
3505 MS_EXCEPTION_IF_NULL(args_abs_list[local_index]);
3506 MS_LOG(INTERNAL_EXCEPTION) << "The third argument should be a dictionary, but got "
3507 << args_abs_list[local_index]->ToString();
3508 }
3509 auto filtered_local_dict = FilterParameters(local_dict);
3510 MS_LOG(DEBUG) << "arg_2, local_dict: " << local_dict->ToString()
3511 << ", filtered_local_dict: " << filtered_local_dict->ToString();
3512 ValuePtr local_dict_value = filtered_local_dict->BuildValue();
3513 MS_EXCEPTION_IF_NULL(local_dict_value);
3514 py::dict local_params_dict = ReCheckLocalDict(filtered_local_dict);
3515 MS_LOG(DEBUG) << "arg_2, python local_params_dict: " << local_dict_value->ToString() << " -> "
3516 << py::str(local_params_dict);
3517 params[1] = local_params_dict;
3518 CheckInterpretInput(filtered_local_dict, script);
3519
3520 return params;
3521 }
3522
ReCheckLocalDict(const AbstractDictionaryPtr & filtered_local_dict) const3523 py::dict ReCheckLocalDict(const AbstractDictionaryPtr &filtered_local_dict) const {
3524 const auto &keys_values = filtered_local_dict->elements();
3525 py::dict local_params_dict;
3526 for (auto &key_value : keys_values) {
3527 MS_EXCEPTION_IF_NULL(key_value.second);
3528 ValuePtr element_value = key_value.second->BuildValue();
3529 MS_EXCEPTION_IF_NULL(element_value);
3530 auto py_data = ValueToPyData(element_value);
3531 MS_EXCEPTION_IF_NULL(key_value.first);
3532 local_params_dict[ValueToPyData(key_value.first->BuildValue())] = py_data;
3533 }
3534 return local_params_dict;
3535 }
3536
// Return a new AbstractDictionary holding the entries of `abstract_dict` whose value is not an
// AbstractFunction. The relative order of the kept entries is preserved.
AbstractDictionaryPtr FilterParameters(const AbstractDictionaryPtr &abstract_dict) const {
  MS_EXCEPTION_IF_NULL(abstract_dict);
  std::vector<AbstractElementPair> kept;
  for (const AbstractElementPair &item : abstract_dict->elements()) {
    MS_EXCEPTION_IF_NULL(item.second);
    // Filter out the element of Function type.
    if (item.second->isa<abstract::AbstractFunction>()) {
      continue;
    }
    kept.push_back(item);
  }
  return std::make_shared<AbstractDictionary>(kept);
}
3549
HasConstArgAttr(const py::object & obj) const3550 bool HasConstArgAttr(const py::object &obj) const {
3551 constexpr char const_arg_attr[] = "const_arg";
3552 return py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr));
3553 }
3554
GetScriptStr(const AbstractBasePtr & abs) const3555 std::string GetScriptStr(const AbstractBasePtr &abs) const {
3556 // When PyInterpret node is built in python, the value of script abstract should be StringImm.
3557 // Otherwise, the value of script should be Script type.
3558 MS_EXCEPTION_IF_NULL(abs);
3559 ValuePtr value_track = abs->GetValueTrack();
3560 MS_EXCEPTION_IF_NULL(value_track);
3561 if (value_track->isa<parse::Script>()) {
3562 auto script_value_track = dyn_cast_ptr<parse::Script>(value_track);
3563 return script_value_track->script();
3564 }
3565 if (!value_track->isa<StringImm>()) {
3566 MS_INTERNAL_EXCEPTION(TypeError) << "Wrong script type for PyInterpret node, script abs: " << abs->ToString();
3567 }
3568 return value_track->ToString();
3569 }
3570
GetGlobalObject(const AbstractBasePtr & abs) const3571 py::object GetGlobalObject(const AbstractBasePtr &abs) const {
3572 MS_EXCEPTION_IF_NULL(abs);
3573 if (!abs->isa<abstract::AbstractScalar>() && !abs->isa<abstract::AbstractDictionary>()) {
3574 MS_INTERNAL_EXCEPTION(TypeError) << "The second argument should be a scalar(InterpretedObject) or dictionary, "
3575 << "but got " << abs->ToString();
3576 }
3577 auto val = abs->BuildValue();
3578 MS_EXCEPTION_IF_NULL(val);
3579 AbstractDictionaryPtr global_dict = nullptr;
3580 // Some functions in global_dict are not used and will be released early,
3581 // resulting in the func_graph pointer in AbstractClosure being released.
3582 ValuePtr globals_converted_value = nullptr;
3583 py::object global_params_dict;
3584 if (abs->isa<abstract::AbstractDictionary>()) {
3585 global_dict = abs->cast<abstract::AbstractDictionaryPtr>();
3586 auto filtered_global_dict = FilterParameters(global_dict);
3587 global_params_dict = ValueToPyData(filtered_global_dict->BuildValue());
3588 } else {
3589 auto global_dict_interpreted = dyn_cast<parse::InterpretedObject>(val);
3590 MS_EXCEPTION_IF_NULL(global_dict_interpreted);
3591 const py::object &global_params_dict_obj = global_dict_interpreted->obj();
3592 if (!parse::ConvertData(global_params_dict_obj, &globals_converted_value)) {
3593 MS_LOG(INTERNAL_EXCEPTION) << "Convert data failed";
3594 }
3595 MS_EXCEPTION_IF_NULL(globals_converted_value);
3596 // Filter global parameters dict.
3597 global_dict = dyn_cast<AbstractDictionary>(globals_converted_value->ToAbstract());
3598 if (global_dict == nullptr) {
3599 MS_LOG(INTERNAL_EXCEPTION) << "The second argument should be a dictionary, but got "
3600 << globals_converted_value->ToAbstract()->ToString();
3601 }
3602 auto filtered_global_dict = FilterParameters(global_dict);
3603 MS_LOG(DEBUG) << "arg_1, global_dict: " << global_dict->ToString()
3604 << ", filtered_global_dict: " << filtered_global_dict->ToString();
3605 ValuePtr global_dict_value = filtered_global_dict->BuildValue();
3606 global_params_dict = ValueToPyData(global_dict_value);
3607 }
3608 // Add filtered global python function to global_params_dict.
3609 AddGlobalPythonFunction(global_dict, &global_params_dict);
3610 return global_params_dict;
3611 }
3612
// Return a node equivalent to `local_node` in which inputs backed by a Python function object are replaced
// by ValueNodes holding an InterpretedObject.
//   - MakeTuple/MakeList CNode with a sequence abstract: rebuilt with every input converted recursively.
//   - Anything else: if fallback::GetPyObjForFuncGraphAbstractClosure(local_abs) yields a non-None object,
//     a new ValueNode wrapping it is returned; otherwise `local_node` itself is returned unchanged.
// Newly created nodes inherit the debug info of `local_node`.
AnfNodePtr ConvertLocalValueInputNode(const AnfNodePtr &local_node, const AbstractBasePtr &local_abs) const {
  MS_EXCEPTION_IF_NULL(local_node);
  MS_EXCEPTION_IF_NULL(local_abs);
  // Not consider AbstractDictionary scene yet.
  const bool is_make_sequence = local_abs->isa<abstract::AbstractSequence>() &&
                                IsOneOfPrimitiveCNode(local_node, {prim::kPrimMakeTuple, prim::kPrimMakeList});
  if (!is_make_sequence) {
    auto py_obj = fallback::GetPyObjForFuncGraphAbstractClosure(local_abs);
    if (py::isinstance<py::none>(py_obj)) {
      return local_node;
    }
    AnfNodePtr interpreted_node = NewValueNode(std::make_shared<parse::InterpretedObject>(py_obj));
    MS_EXCEPTION_IF_NULL(interpreted_node);
    interpreted_node->set_debug_info(local_node->debug_info());
    return interpreted_node;
  }

  auto seq_cnode = local_node->cast<CNodePtr>();
  auto seq_abs = local_abs->cast<abstract::AbstractSequencePtr>();
  // Input 0 is the primitive, so the element count is size() - 1.
  if (seq_cnode->size() - 1 != seq_abs->size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "For node: " << local_node->DebugString() << ", input size is "
                               << seq_cnode->size() << " and abstract size is " << seq_abs->size()
                               << ". Size not matched.";
  }
  const auto &elements_abs = seq_abs->elements();
  AnfNodePtrList converted_inputs;
  (void)converted_inputs.emplace_back(seq_cnode->input(0));
  for (size_t index = 1; index < seq_cnode->size(); ++index) {
    (void)converted_inputs.emplace_back(ConvertLocalValueInputNode(seq_cnode->input(index), elements_abs[index - 1]));
  }
  auto fg = seq_cnode->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  AnfNodePtr rebuilt_node = fg->NewCNode(converted_inputs);
  MS_EXCEPTION_IF_NULL(rebuilt_node);
  rebuilt_node->set_debug_info(local_node->debug_info());
  return rebuilt_node;
}
3647
// If the local dictionary of a PyInterpret node (input 3, built by MakeDict) contains a
// FuncGraphAbstractClosure anywhere in its values, build and return a copy of the node whose local values
// have been converted by ConvertLocalValueInputNode(). Returns nullptr when no conversion is needed:
//   - the node was already checked (fallback::kLocalDictCheck user data is true),
//   - the local input is not a MakeDict CNode, or
//   - no FuncGraphAbstractClosure is found in the local abstract.
// node: the PyInterpret CNode; it is marked with kLocalDictCheck so that it is only inspected once.
// args_abs_list: abstracts of the node inputs (index i corresponds to CNode input i + 1).
// Raises an internal exception when the node or the MakeDict node has an unexpected number of inputs.
AnfNodePtr ConvertPyInterpretNode(const AnfNodePtr &node, const AbstractBasePtrList &args_abs_list) const {
  MS_EXCEPTION_IF_NULL(node);
  // Ensure the same node only check local dict once.
  if (node->has_user_data(fallback::kLocalDictCheck) && *node->user_data<bool>(fallback::kLocalDictCheck)) {
    return nullptr;
  }
  node->set_user_data(fallback::kLocalDictCheck, std::make_shared<bool>(true));
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  constexpr size_t interpret_min_len = 4;
  if (cnode->size() < interpret_min_len) {
    MS_LOG(INTERNAL_EXCEPTION) << "The minimum input number for PyInterpret node should be " << interpret_min_len
                               << " but got " << cnode->size();
  }
  if (args_abs_list.size() < interpret_min_len - 1) {
    MS_LOG(INTERNAL_EXCEPTION) << "The minimum number for PyInterpret input abstract should be "
                               << interpret_min_len - 1 << " but got " << args_abs_list.size();
  }
  constexpr size_t local_index = 3;
  auto local_node = cnode->input(local_index);
  auto local_node_abs = args_abs_list[local_index - 1];
  MS_EXCEPTION_IF_NULL(local_node);
  MS_EXCEPTION_IF_NULL(local_node_abs);
  if (!IsPrimitiveCNode(local_node, prim::kPrimMakeDict)) {
    return nullptr;
  }
  auto local_cnode = local_node->cast<CNodePtr>();
  constexpr size_t make_dict_len = 3;
  if (local_cnode->size() != make_dict_len) {
    MS_LOG(INTERNAL_EXCEPTION) << "Make dict mode input size should be " << make_dict_len << " but got "
                               << local_cnode->size();
  }

  // Recursive search for a FuncGraphAbstractClosure inside sequences and dictionary values.
  // The std::function is captured by reference so that recursion does not copy it on every level.
  std::function<bool(const AbstractBasePtr &)> contains_fg_closure;
  contains_fg_closure = [&contains_fg_closure](const AbstractBasePtr &abs) -> bool {
    MS_EXCEPTION_IF_NULL(abs);
    if (abs->isa<abstract::AbstractSequence>()) {
      auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
      const auto &elements = abs_seq->elements();
      return std::any_of(elements.begin(), elements.end(), [&contains_fg_closure](const AbstractBasePtr &inner_abs) {
        return contains_fg_closure(inner_abs);
      });
    }
    if (abs->isa<abstract::AbstractDictionary>()) {
      auto abs_dict = abs->cast<abstract::AbstractDictionaryPtr>();
      const auto &elements = abs_dict->elements();
      return std::any_of(elements.begin(), elements.end(),
                         [&contains_fg_closure](const abstract::AbstractElementPair &inner_abs) {
                           // Dictionary key can not be abstract function, no need to check.
                           return contains_fg_closure(inner_abs.second);
                         });
    }
    return abs->isa<abstract::FuncGraphAbstractClosure>();
  };

  if (!contains_fg_closure(local_node_abs)) {
    return nullptr;
  }

  // Collect the value abstracts of the local dictionary into a tuple abstract that lines up with the
  // value input (index 2) of the MakeDict node.
  auto local_node_abs_dict = local_node_abs->cast<abstract::AbstractDictionaryPtr>();
  MS_EXCEPTION_IF_NULL(local_node_abs_dict);
  const auto &elements_pair = local_node_abs_dict->elements();
  std::vector<abstract::AbstractBasePtr> element_abs{};
  (void)std::transform(elements_pair.begin(), elements_pair.end(), std::back_inserter(element_abs),
                       [](const AbstractElementPair &pairs) { return pairs.second; });
  auto local_value_abs = std::make_shared<abstract::AbstractTuple>(element_abs);
  constexpr size_t value_index = 2;
  auto local_value_node = local_cnode->input(value_index);
  auto new_local_value_node = ConvertLocalValueInputNode(local_value_node, local_value_abs);

  // Rebuild the MakeDict node: primitive and keys are reused, the values are the converted node.
  std::vector<AnfNodePtr> new_local_node_inputs;
  for (size_t i = 0; i < value_index; ++i) {
    new_local_node_inputs.push_back(local_cnode->input(i));
  }
  new_local_node_inputs.push_back(new_local_value_node);
  auto fg = node->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto new_local_cnode = fg->NewCNode(new_local_node_inputs);
  new_local_cnode->set_debug_info(local_cnode->debug_info());

  // Rebuild the PyInterpret node with the new local dictionary in place of input 3.
  std::vector<AnfNodePtr> new_cnode_inputs;
  for (size_t i = 0; i < local_index; ++i) {
    new_cnode_inputs.push_back(cnode->input(i));
  }
  new_cnode_inputs.push_back(new_local_cnode);
  for (size_t i = local_index + 1; i < cnode->size(); ++i) {
    new_cnode_inputs.push_back(cnode->input(i));
  }
  auto new_cnode = fg->NewCNode(new_cnode_inputs);
  new_cnode->set_debug_info(cnode->debug_info());
  // Mark the replacement as checked so that evaluating it does not trigger another conversion.
  new_cnode->set_user_data(fallback::kLocalDictCheck, std::make_shared<bool>(true));
  return new_cnode;
}
3742
private:
// Reset to false at the start of each script evaluation in EvalPrim(); set to true by CheckInterpretInput()
// when a local input is not constant and the JIT syntax level is kLax. Mutable because the helpers that
// write it are const member functions.
mutable bool non_const_err_{false};
// Cached from the node's fallback::kCheckListDictInplace user data in EvalPrim(); when true,
// CheckInterpretInput() additionally consults fallback::HasObjInExtraInfoHolder() for each local input.
mutable bool check_list_dict_inplace_{false};
3746 };
3747
3748 class EmbedEvaluator : public SymbolicPrimEvaluator {
3749 public:
EmbedEvaluator()3750 EmbedEvaluator() : SymbolicPrimEvaluator("EmbedEvaluator") {}
3751 ~EmbedEvaluator() override = default;
3752 MS_DECLARE_PARENT(EmbedEvaluator, SymbolicPrimEvaluator);
EvalPrim(const ConfigPtrList & args_conf_list)3753 EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
3754 // arg: free variable to be embedded
3755 if (args_conf_list.size() != 1) {
3756 MS_LOG(INTERNAL_EXCEPTION) << "EmbedEvaluator requires 1 parameter, but got " << args_conf_list.size();
3757 }
3758 auto node_conf = dyn_cast_ptr<AnfNodeConfig>(args_conf_list[0]);
3759 MS_EXCEPTION_IF_NULL(node_conf);
3760 const auto &eval_result = node_conf->ObtainEvalResult();
3761 MS_EXCEPTION_IF_NULL(eval_result);
3762 AbstractBasePtr x = eval_result->abstract();
3763 x = SensitivityTransform(x);
3764 SymbolicKeyInstancePtr key = std::make_shared<SymbolicKeyInstance>(node_conf->node(), x);
3765 AbstractScalarPtr abs_scalar = std::make_shared<AbstractScalar>(key, std::make_shared<SymbolicKeyType>());
3766 return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
3767 }
3768 };
3769
FindParameterNodeByString(const FuncGraphManagerPtr & manager,const std::string & name)3770 static AnfNodePtr FindParameterNodeByString(const FuncGraphManagerPtr &manager, const std::string &name) {
3771 MS_EXCEPTION_IF_NULL(manager);
3772 auto root_g_set = manager->roots();
3773 if (root_g_set.size() != 1) {
3774 return nullptr;
3775 }
3776 const FuncGraphPtr &root_g = root_g_set.back();
3777 MS_EXCEPTION_IF_NULL(root_g);
3778 for (auto ¶m_node : root_g->parameters()) {
3779 auto param = param_node->cast<ParameterPtr>();
3780 if (param != nullptr && param->name() == name) {
3781 return param;
3782 }
3783 }
3784 return nullptr;
3785 }
3786
3787 class RefToEmbedEvaluator : public SymbolicPrimEvaluator {
3788 public:
RefToEmbedEvaluator()3789 RefToEmbedEvaluator() : SymbolicPrimEvaluator("RefToEmbedEvaluator") {}
3790 ~RefToEmbedEvaluator() override = default;
3791 MS_DECLARE_PARENT(RefToEmbedEvaluator, SymbolicPrimEvaluator);
EvalPrim(const ConfigPtrList & args_conf_list)3792 EvalResultPtr EvalPrim(const ConfigPtrList &args_conf_list) override {
3793 if (args_conf_list.size() != 1) {
3794 MS_LOG(ERROR) << "Requires 1 parameter, but has: " << args_conf_list.size();
3795 return nullptr;
3796 }
3797 static TypePtr type = std::make_shared<SymbolicKeyType>();
3798 auto node_conf = dyn_cast_ptr<AnfNodeConfig>(args_conf_list[0]);
3799 if (node_conf == nullptr) {
3800 MS_LOG(ERROR) << "Conf should be AnfNodeConfig";
3801 return nullptr;
3802 }
3803 const auto &eval_result = node_conf->ObtainEvalResult();
3804 MS_EXCEPTION_IF_NULL(eval_result);
3805 AbstractBasePtr abs = eval_result->abstract();
3806 MS_EXCEPTION_IF_NULL(abs);
3807 auto ref_key_value = abstract::GetRefKeyValue(abs);
3808 if (ref_key_value == nullptr) {
3809 MS_LOG(ERROR) << "The first parameter of RefToEmbed should be Ref, but " << abs->ToString();
3810 return nullptr;
3811 }
3812 // Check if the input of RefEmbed is a weight parameter, if not, don't create the
3813 // specific SymbolicKey.
3814 // Notes: when different weight parameter have same type and shape passed as parameter to same funcgraph
3815 // which has RefToEmbed CNode, that funcgraph will not be specialized to different funcgraph, so the
3816 // RefToEmbed CNode in that funcgraph also should not be evaluated to specific SymbolicKey.
3817 // Only after that funcgrpah is inlined, the RefToEmbed CNode should be evaluated to specific SymbolicKey.
3818 bool embed_is_weight = false;
3819 if (node_conf->node() != nullptr && node_conf->node()->isa<Parameter>()) {
3820 auto param = node_conf->node()->cast_ptr<Parameter>();
3821 MS_EXCEPTION_IF_NULL(param);
3822 embed_is_weight = param->has_default();
3823 }
3824 auto refkey = ref_key_value->cast_ptr<StringImm>();
3825 if (refkey == nullptr || !embed_is_weight) {
3826 auto res = std::make_shared<AbstractScalar>(type);
3827 return std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3828 }
3829
3830 std::string name = refkey->value();
3831 MS_EXCEPTION_IF_NULL(node_conf->node());
3832 if (node_conf->node()->func_graph() == nullptr) {
3833 MS_LOG(INTERNAL_EXCEPTION) << "Should not evaluate a ValueNode, node: " << node_conf->node()->DebugString();
3834 }
3835 const auto &manager = node_conf->node()->func_graph()->manager();
3836 auto node = FindParameterNodeByString(manager, name);
3837 if (node == nullptr) {
3838 MS_LOG(ERROR) << "RefToEmbed input can't find parameter \"" << name << "\" in graph.";
3839 return nullptr;
3840 }
3841 AbstractBasePtr x = SensitivityTransform(abs);
3842 std::shared_ptr<SymbolicKeyInstance> key = std::make_shared<SymbolicKeyInstance>(node, x);
3843 std::shared_ptr<AbstractScalar> abs_scalar = std::make_shared<AbstractScalar>(key, type);
3844 return std::make_shared<EvalResult>(abs_scalar, std::make_shared<AttrValueMap>());
3845 }
3846 };
3847
3848 class GetAttrEvaluator : public TransitionPrimEvaluator {
3849 public:
GetAttrEvaluator()3850 GetAttrEvaluator() : TransitionPrimEvaluator("GetAttrEvaluator") {}
3851 ~GetAttrEvaluator() override = default;
3852 MS_DECLARE_PARENT(GetAttrEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr & in_conf0,const AnfNodeConfigPtr & out_conf)3853 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
3854 const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
3855 constexpr auto args_min_size = 2;
3856 constexpr auto args_max_size = 3;
3857 const auto args_size = args_abs_list.size();
3858 if (args_size != args_min_size && args_size != args_max_size) {
3859 MS_LOG(EXCEPTION) << "For Primitive GetAttr, the input size should be " << args_min_size << " or "
3860 << args_max_size << ", but got size: " << args_size;
3861 }
3862 auto res_abstract = EvalUndeterminedArgs(args_abs_list);
3863 if (res_abstract != nullptr) {
3864 return res_abstract;
3865 }
3866
3867 constexpr auto attr_index = 1;
3868 auto attr_abs = args_abs_list[attr_index];
3869 MS_EXCEPTION_IF_NULL(attr_abs);
3870 auto attr_abs_type = attr_abs->BuildType();
3871 MS_EXCEPTION_IF_NULL(attr_abs_type);
3872 auto type_id = attr_abs_type->type_id();
3873 if (type_id != TypeId::kObjectTypeString) {
3874 MS_EXCEPTION(TypeError) << "getattr(): attribute name must be string but got: " << TypeIdToString(type_id);
3875 }
3876 EvalResultPtr res = nullptr;
3877 if (bound_node() != nullptr) {
3878 TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
3879 res = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
3880 } else {
3881 res = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
3882 }
3883 // Don't lookup from cache, as different out_conf with same node but different context
3884 // may add different entry to anfnode_config_map, like getattr primitive.
3885 evaluator_cache_mgr_->SetValue(args_abs_list, res);
3886 return res;
3887 }
3888 };
3889
3890 class SetAttrEvaluator : public TransitionPrimEvaluator {
3891 public:
SetAttrEvaluator()3892 SetAttrEvaluator() : TransitionPrimEvaluator("SetAttrEvaluator") {}
3893 ~SetAttrEvaluator() override = default;
3894 MS_DECLARE_PARENT(SetAttrEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3895 EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
3896 const AnfNodeConfigPtr &out_conf) override {
3897 constexpr size_t min_args_size = 3;
3898 constexpr size_t max_args_size = 4;
3899 size_t args_size = args_abs_list.size();
3900 if (args_size != min_args_size && args_size != max_args_size) {
3901 MS_LOG(EXCEPTION) << "For Primitive SetAttr, the input size should be " << min_args_size << " or "
3902 << max_args_size << ", but got size: " << args_size;
3903 }
3904 auto res_abstract = EvalUndeterminedArgs(args_abs_list);
3905 if (res_abstract != nullptr) {
3906 return res_abstract;
3907 }
3908
3909 return InterpretSetAttrNode(args_abs_list, out_conf);
3910 }
3911 };
3912
3913 class ResolveEvaluator : public TransitionPrimEvaluator {
3914 public:
ResolveEvaluator()3915 ResolveEvaluator() : TransitionPrimEvaluator("ResolveEvaluator") {}
3916 ~ResolveEvaluator() override = default;
3917 MS_DECLARE_PARENT(ResolveEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr & in_conf0,const AnfNodeConfigPtr & out_conf)3918 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list,
3919 const ConfigPtr &in_conf0, const AnfNodeConfigPtr &out_conf) override {
3920 constexpr auto resolve_args_size = 2; // (namespace, symbol)
3921 constexpr auto resolve_with_args_size = 3; // (namespace, symbol, arguments)
3922 // Inputs: namespace, symbol
3923 if (args_abs_list.size() != resolve_args_size && args_abs_list.size() != resolve_with_args_size) {
3924 MS_LOG(EXCEPTION) << "Expected args_abs_list size is 2 or 3, but has size: " << args_abs_list.size();
3925 }
3926 EvalResultPtr res = nullptr;
3927 if (bound_node() != nullptr) {
3928 TraceGuard trace_guard(std::make_shared<TraceResolve>(bound_node()->debug_info()));
3929 res = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
3930 } else {
3931 res = StaticGetter(engine, args_abs_list, in_conf0, out_conf);
3932 }
3933 return res;
3934 }
3935 };
3936
3937 class CreateInstanceEvaluator : public TransitionPrimEvaluator {
3938 public:
CreateInstanceEvaluator()3939 CreateInstanceEvaluator() : TransitionPrimEvaluator("CreateInstanceEvaluator") {}
3940 ~CreateInstanceEvaluator() override = default;
3941 MS_DECLARE_PARENT(CreateInstanceEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)3942 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
3943 const AnfNodeConfigPtr &out_conf) override {
3944 // Check the type parameter.
3945 if (args_abs_list.empty()) {
3946 MS_LOG(INTERNAL_EXCEPTION) << "'args_abs_list' should not be empty";
3947 }
3948 constexpr size_t class_index = 0;
3949 auto class_obj = GetPythonObject(args_abs_list[class_index]);
3950 py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
3951 std::string class_name =
3952 python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_MS_CLASS_NAME, class_obj).cast<std::string>();
3953 // Get the create instance obj's parameters, `params` may contain tuple(args, kwargs).
3954 auto params = py::tuple(args_abs_list.size() - 1);
3955 bool is_prim_variable = GetParameters(args_abs_list, class_obj, class_name, ¶ms);
3956 if (is_prim_variable) {
3957 return CreatePrimitiveInstanceWithVariableArgs(args_abs_list, class_name, class_obj, engine, out_conf);
3958 }
3959 // Create class instance.
3960 auto obj = parse::data_converter::CreatePythonObject(class_obj, params);
3961 if (py::isinstance<py::none>(obj)) {
3962 MS_LOG(EXCEPTION) << "Create python object `" << py::str(class_obj)
3963 << "` failed, only support to create 'Cell', 'Primitive' or "
3964 << "user-defined Class decorated with 'jit_class'.";
3965 }
3966
3967 // Process the object.
3968 MS_EXCEPTION_IF_NULL(out_conf->node());
3969 TraceGuard guard(std::make_shared<TraceResolve>(out_conf->node()->debug_info()));
3970 ValuePtr converted_res = nullptr;
3971 bool converted = parse::ConvertData(obj, &converted_res, true);
3972 if (!converted) {
3973 MS_LOG(INTERNAL_EXCEPTION) << "Convert the python object failed";
3974 }
3975 MS_EXCEPTION_IF_NULL(converted_res);
3976 // To check isolated side effect for the func graph who returns constant.
3977 HandleSideEffect(obj, converted_res, engine, out_conf);
3978
3979 if (converted_res->isa<FuncGraph>()) {
3980 AddToManager(engine, converted_res->cast<FuncGraphPtr>());
3981 }
3982 AbstractBasePtr res = ToAbstract(converted_res, AnalysisContext::DummyContext(), out_conf);
3983 auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
3984 evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
3985 return infer_result;
3986 }
3987
GetPythonObject(const AbstractBasePtr & arg_class_type) const3988 py::object GetPythonObject(const AbstractBasePtr &arg_class_type) const {
3989 MS_EXCEPTION_IF_NULL(arg_class_type);
3990 TypePtr type = arg_class_type->GetTypeTrack();
3991 MS_EXCEPTION_IF_NULL(type);
3992 if (type->type_id() != kMetaTypeTypeType && type->type_id() != kObjectTypeClass) {
3993 MS_LOG(EXCEPTION)
3994 << "CreateInstanceEvaluator require first parameter should be an object of TypeType or TypeClass, but got "
3995 << type->ToString();
3996 }
3997
3998 ValuePtr value_track = arg_class_type->GetValueTrack();
3999 MS_EXCEPTION_IF_NULL(value_track);
4000 auto type_obj = dyn_cast_ptr<parse::PyObjectWrapper>(value_track);
4001 if (type_obj == nullptr) {
4002 MS_LOG(INTERNAL_EXCEPTION) << "Cast value failed, not PyObjectWrapper: " << value_track->ToString() << ".";
4003 }
4004 if (!type_obj->isa<parse::ClassType>() && !type_obj->isa<parse::MsClassObject>()) {
4005 MS_LOG(EXCEPTION)
4006 << "CreateInstanceEvaluator the type_obj should be an object of ClassType or MsClassObject, but got "
4007 << type_obj->ToString() << ".";
4008 }
4009 MS_LOG(DEBUG) << "Get class type: " << type_obj->ToString() << ".";
4010 return type_obj->obj();
4011 }
4012
HandleSideEffect(const py::object & obj,const ValuePtr & converted_res,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf) const4013 void HandleSideEffect(const py::object &obj, const ValuePtr &converted_res, const AnalysisEnginePtr &engine,
4014 const AnfNodeConfigPtr &out_conf) const {
4015 if (engine->check_side_effect()) {
4016 MS_LOG(DEBUG) << "obj: " << py::str(obj) << ", converted_res: " << converted_res->ToString();
4017 auto prim = GetValueWithoutDoSignature(converted_res)->cast<PrimitivePtr>();
4018 if (prim != nullptr) {
4019 auto effect_info = GetPrimEffectInfo(prim);
4020 if (effect_info.memory || effect_info.io) {
4021 const auto &cnode = dyn_cast<CNode>(out_conf->node());
4022 MS_EXCEPTION_IF_NULL(cnode);
4023 MS_EXCEPTION_IF_NULL(out_conf->func_graph());
4024 MS_LOG(DEBUG) << "Found side-effect, cnode: " << cnode->DebugString()
4025 << ", func_graph: " << out_conf->func_graph()->ToString();
4026 cnode->set_has_side_effect_node(true);
4027 out_conf->func_graph()->set_has_side_effect_node(true);
4028 }
4029 }
4030 }
4031 }
4032
GetParameters(const AbstractBasePtrList & args_abs_list,const py::object & obj,const std::string & cls_name,py::tuple * params)4033 bool GetParameters(const AbstractBasePtrList &args_abs_list, const py::object &obj, const std::string &cls_name,
4034 py::tuple *params) {
4035 auto params_size = (*params).size();
4036 for (size_t i = 0; i < params_size; i++) {
4037 // Only support the Scalar parameters type. Bypass class type by offset with 1.
4038 auto arg = args_abs_list[i + 1];
4039 MS_EXCEPTION_IF_NULL(arg);
4040 auto param_value = arg->BuildValue();
4041 MS_EXCEPTION_IF_NULL(param_value);
4042 if (param_value->ContainsValueAny() && !arg->isa<AbstractFunction>()) {
4043 // If obj is a Primitive class and has variable arguments, just return and go through another process.
4044 if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG) && mindspore::ops::GetOpDef(cls_name) != nullptr) {
4045 return true;
4046 }
4047 MS_EXCEPTION(TypeError) << "When creating an instance of '" << cls_name
4048 << "', all arguments are required to be constants, but input " << i
4049 << " is a variable, which is " << arg->ToString() << ".";
4050 }
4051 py::object param = ValueToPyData(param_value);
4052 (*params)[i] = param;
4053 }
4054 return false;
4055 }
4056
CreatePrimitiveInstanceWithVariableArgs(const AbstractBasePtrList & args_abs_list,const std::string & cls_name,const py::object & cls_obj,const AnalysisEnginePtr & engine,const AnfNodeConfigPtr & out_conf) const4057 EvalResultPtr CreatePrimitiveInstanceWithVariableArgs(const AbstractBasePtrList &args_abs_list,
4058 const std::string &cls_name, const py::object &cls_obj,
4059 const AnalysisEnginePtr &engine,
4060 const AnfNodeConfigPtr &out_conf) const {
4061 // Create Primitive instance with variable arguments.
4062 auto prim_func = std::make_shared<Primitive>(cls_name);
4063 auto do_trans_prim_func = std::make_shared<prim::DoTransPrimitiveFunction>(prim_func);
4064 // Ignore the first input which is ClassType.
4065 AbstractBasePtrList partial_args_abs_list(args_abs_list.begin() + 1, args_abs_list.end());
4066 do_trans_prim_func->set_given_init_size(partial_args_abs_list.size());
4067 auto func_ptr = std::make_shared<abstract::PrimitiveAbstractClosure>(do_trans_prim_func);
4068 auto ret_val =
4069 std::make_shared<abstract::PartialAbstractClosure>(func_ptr, partial_args_abs_list, out_conf->node());
4070 ret_val->set_need_append_to_end(true);
4071 return std::make_shared<EvalResult>(ret_val, std::make_shared<AttrValueMap>());
4072 }
4073 };
4074
4075 class PartialEvaluator : public Evaluator {
4076 public:
PartialEvaluator()4077 PartialEvaluator() : Evaluator("PartialEvaluator") {}
4078 ~PartialEvaluator() override = default;
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)4079 EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
4080 const AnfNodeConfigPtr &out_conf) override {
4081 if (args_conf_list.size() == 0) {
4082 MS_LOG(INTERNAL_EXCEPTION) << "Args size should be greater than 0";
4083 }
4084 MS_EXCEPTION_IF_NULL(out_conf);
4085 MS_EXCEPTION_IF_NULL(out_conf->node());
4086 MS_EXCEPTION_IF_NULL(args_conf_list[0]);
4087 const auto &arg0_eval_result = args_conf_list[0]->ObtainEvalResult();
4088 MS_EXCEPTION_IF_NULL(arg0_eval_result);
4089 auto arg0_value = arg0_eval_result->abstract();
4090 MS_EXCEPTION_IF_NULL(arg0_value);
4091 AbstractBasePtrList args_abs_list{arg0_value};
4092 auto cnode = out_conf->node()->cast<CNodePtr>();
4093 MS_EXCEPTION_IF_NULL(cnode);
4094
4095 // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node.
4096 if (arg0_value->isa<AbstractProblem>()) {
4097 MS_EXCEPTION_IF_NULL(arg0_value->GetValueTrack());
4098 const auto &value_problem = arg0_value->GetValueTrack()->cast<ValueProblemPtr>();
4099 auto res = std::make_shared<AbstractProblem>(value_problem, out_conf->node());
4100 MS_LOG(DEBUG) << "AbstractProblem for node: " << out_conf->node()->DebugString()
4101 << " as func is: " << arg0_value->ToString();
4102 auto eval_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
4103 evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
4104 return eval_result;
4105 }
4106 auto func = CheckArg<AbstractFunction>("partial", args_abs_list, 0);
4107 // Sometimes, node[0] in out_conf becomes phi0;
4108 if (func->isa<PrimitiveAbstractClosure>()) {
4109 auto prim_func = dyn_cast_ptr<PrimitiveAbstractClosure>(func);
4110 MS_EXCEPTION_IF_NULL(prim_func);
4111 MS_EXCEPTION_IF_NULL(prim_func->prim());
4112 if (prim_func->prim()->isa<prim::DoSignaturePrimitive>()) {
4113 auto do_signature_prim = dyn_cast_ptr<prim::DoSignaturePrimitive>(prim_func->prim());
4114 return HandleDoSignature(engine, do_signature_prim->function(), out_conf);
4115 }
4116 }
4117
4118 (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_abs_list),
4119 [](const ConfigPtr &config) -> AbstractBasePtr {
4120 MS_EXCEPTION_IF_NULL(config);
4121 const auto &eval_result = config->ObtainEvalResult();
4122 MS_EXCEPTION_IF_NULL(eval_result);
4123 return eval_result->abstract();
4124 });
4125 AbstractBasePtrList args(args_abs_list.begin() + 1, args_abs_list.end());
4126
4127 if (cnode->size() != (args_conf_list.size() + 1)) {
4128 MS_LOG(INTERNAL_EXCEPTION) << "Out_conf node: " << cnode->DebugString()
4129 << ", args_conf_list: " << mindspore::ToString(args_conf_list);
4130 }
4131 AbstractFuncAtomPtrList partial_funcs_list;
4132 auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
4133 auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
4134 partial_funcs_list.push_back(new_func);
4135 };
4136 func->Visit(build_partial);
4137
4138 auto res = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
4139 auto eval_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
4140 MS_LOG(DEBUG) << "args_abs_list: " << args_abs_list << ", eval_result: " << eval_result->abstract()->ToString();
4141 evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
4142 return eval_result;
4143 }
4144
Eval(AnalysisEnginePtr,const AbstractBasePtrList &,const AnfNodeConfigPtr &)4145 EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &, const AnfNodeConfigPtr &) override {
4146 MS_LOG(INTERNAL_EXCEPTION) << "Eval() should not be called, Run() method should be called";
4147 }
4148
HandleDoSignature(const AnalysisEnginePtr & engine,const ValuePtr & signature_value,const AnfNodeConfigPtr & out_conf) const4149 EvalResultPtr HandleDoSignature(const AnalysisEnginePtr &engine, const ValuePtr &signature_value,
4150 const AnfNodeConfigPtr &out_conf) const {
4151 MS_EXCEPTION_IF_NULL(engine);
4152 MS_EXCEPTION_IF_NULL(out_conf);
4153 MS_EXCEPTION_IF_NULL(out_conf->node());
4154 auto cnode = out_conf->node()->cast_ptr<CNode>();
4155 MS_EXCEPTION_IF_NULL(cnode);
4156
4157 ScopeGuard scope_guard(out_conf->node()->scope());
4158 TraceGuard trace_guard(std::make_shared<TraceDoSignature>(out_conf->node()->debug_info()));
4159 auto new_nodes_inputs = cnode->weak_inputs();
4160 auto new_signature_value = std::make_shared<prim::DoSignatureMetaFuncGraph>("signature", signature_value);
4161 auto new_sig_node = NewValueNode(new_signature_value);
4162 new_nodes_inputs[1] = AnfNodeWeakPtr(new_sig_node);
4163 FuncGraphPtr func_graph = cnode->func_graph();
4164 MS_EXCEPTION_IF_NULL(func_graph);
4165 CNodePtr new_cnode = func_graph->NewCNodeWeak(std::move(new_nodes_inputs));
4166 AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_cnode, out_conf->context(), out_conf->func_graph());
4167 return engine->ForwardConfig(out_conf, fn_conf);
4168 }
4169 };
4170
4171 class RaiseEvaluator : public TransitionPrimEvaluator {
4172 public:
RaiseEvaluator()4173 RaiseEvaluator() : TransitionPrimEvaluator("RaiseEvaluator") {}
4174 ~RaiseEvaluator() override = default;
4175 MS_DECLARE_PARENT(RaiseEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr &,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)4176 EvalResultPtr EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
4177 const AnfNodeConfigPtr &out_conf) override {
4178 MS_EXCEPTION_IF_NULL(out_conf);
4179 // Handle for DDE.
4180 for (size_t i = 0; i < args_abs_list.size(); ++i) {
4181 MS_EXCEPTION_IF_NULL(args_abs_list[i]);
4182 if (args_abs_list[i]->isa<abstract::AbstractSequence>()) {
4183 MS_LOG(DEBUG) << "Primitive \'Raise\' is consuming tuple/list arguments[" << i
4184 << "]: " << args_abs_list[i]->ToString();
4185 SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
4186 }
4187 }
4188 auto node = out_conf->node();
4189 MS_EXCEPTION_IF_NULL(node);
4190 auto cur_graph = node->func_graph();
4191 MS_EXCEPTION_IF_NULL(cur_graph);
4192 if (args_abs_list.empty()) {
4193 // Process raise.
4194 MS_LOG(INTERNAL_EXCEPTION) << "No active exception to reraise.";
4195 }
4196 const auto &cnode = node->cast<CNodePtr>();
4197 MS_EXCEPTION_IF_NULL(cnode);
4198
4199 // Return Any directly if meet variable condition or content.
4200 bool is_variable_condition = raiseutils::HasVariableCondition(cur_graph);
4201 bool has_variable = false;
4202 size_t index_begin = 2;
4203 size_t index_end = cnode->size() - 1;
4204 for (size_t index = index_begin; index < cnode->size(); ++index) {
4205 if (raiseutils::CheckHasVariable(args_abs_list[index - 1])) {
4206 has_variable = true;
4207 break;
4208 }
4209 }
4210 if (is_variable_condition || has_variable) {
4211 AbstractBasePtr res = std::make_shared<AbstractNegligible>();
4212 cnode->set_has_side_effect_node(true);
4213 cur_graph->set_has_side_effect_node(true);
4214 auto infer_result = std::make_shared<EvalResult>(res, std::make_shared<AttrValueMap>());
4215 evaluator_cache_mgr_->SetValue(args_abs_list, infer_result);
4216 return infer_result;
4217 }
4218
4219 // Continue to handle raise in compile time.
4220 std::shared_ptr<raiseutils::KeyValueInfo> key_value = std::make_shared<raiseutils::KeyValueInfo>();
4221 std::string exception_type =
4222 raiseutils::GetExceptionType(args_abs_list[0], cnode->input(index_end), key_value, false);
4223 std::string exception_string;
4224 // Process raise ValueError()
4225 if (args_abs_list.size() == 1) {
4226 RaiseConstant(exception_type);
4227 }
4228 // Processed in units of nodes. Raise ValueError(xxxx)
4229 for (size_t index = index_begin; index < cnode->size() - 1; ++index) {
4230 const auto input = cnode->input(index);
4231 auto input_abs = args_abs_list[index - 1];
4232 MS_EXCEPTION_IF_NULL(input_abs);
4233 const bool need_symbol = raiseutils::CheckNeedSymbol(input_abs);
4234 if (need_symbol) {
4235 exception_string += "'";
4236 }
4237 bool need_comma = !IsPrimitiveCNode(input, prim::kPrimMakeTuple);
4238 exception_string += raiseutils::GetExceptionString(input_abs, input, key_value, need_symbol, need_comma);
4239 if (need_symbol) {
4240 exception_string += "'";
4241 }
4242 constexpr auto end_index = 2;
4243 if (index < cnode->size() - end_index) {
4244 exception_string += ", ";
4245 }
4246 }
4247 bool need_out_symbol = cnode->size() > 4;
4248 if (need_out_symbol) {
4249 exception_string = "(" + exception_string + ")";
4250 }
4251 RaiseConstant(exception_type, exception_string);
4252 MS_LOG(EXCEPTION) << "Constant raise is not raising exception correctly";
4253 }
4254
4255 private:
RaiseConstant(const std::string & type,const std::string & exception_string="")4256 void RaiseConstant(const std::string &type, const std::string &exception_string = "") {
4257 auto iter = exception_types_map.find(type);
4258 if (iter == exception_types_map.end()) {
4259 MS_LOG(EXCEPTION) << "Unsupported exception type: " << type
4260 << ". Raise only support some Python standard exception types: "
4261 << SupportedExceptionsToString();
4262 }
4263 ExceptionType error_type = iter->second;
4264 if (exception_string.empty()) {
4265 MS_EXCEPTION(error_type);
4266 } else {
4267 MS_EXCEPTION(error_type) << exception_string;
4268 }
4269 }
4270 };
4271
4272 class WithEnterEvaluator : public TransitionPrimEvaluator {
4273 public:
WithEnterEvaluator()4274 WithEnterEvaluator() : TransitionPrimEvaluator("WithEnterEvaluator") {}
4275 ~WithEnterEvaluator() override = default;
4276 MS_DECLARE_PARENT(WithEnterEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)4277 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
4278 const AnfNodeConfigPtr &out_conf) override {
4279 MS_EXCEPTION_IF_NULL(out_conf);
4280 MS_EXCEPTION_IF_NULL(out_conf->node());
4281 auto node = out_conf->node()->cast<CNodePtr>();
4282 MS_EXCEPTION_IF_NULL(node);
4283 auto cur_graph = node->func_graph();
4284 MS_EXCEPTION_IF_NULL(cur_graph);
4285
4286 if (args_abs_list.size() != 1) {
4287 MS_LOG(INTERNAL_EXCEPTION) << "The enter node has wrong input." << node->debug_info();
4288 }
4289
4290 // Check class object
4291 constexpr size_t cls_index = 0;
4292 MS_EXCEPTION_IF_NULL(args_abs_list[cls_index]);
4293 auto cls_val = args_abs_list[cls_index]->BuildValue();
4294 MS_EXCEPTION_IF_NULL(cls_val);
4295 auto value_obj = cls_val->cast<parse::MsClassObjectPtr>();
4296 if (value_obj == nullptr) {
4297 MS_EXCEPTION(TypeError) << "Only support jit_class instance, but got " << cls_val->ToString();
4298 }
4299 auto cls_obj = value_obj->obj();
4300
4301 const std::string call_func = "__enter__";
4302 if (!py::hasattr(cls_obj, common::SafeCStr(call_func))) {
4303 MS_LOG(EXCEPTION) << value_obj->name() << " has no " << call_func << " function, please check the code.";
4304 }
4305 py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func));
4306 FuncGraphPtr call_func_graph = parse::ConvertToFuncGraph(call_obj);
4307 if (call_func_graph == nullptr) {
4308 MS_LOG(INTERNAL_EXCEPTION) << "Parse python object " << call_func << " failed.";
4309 }
4310 FuncGraphManagerPtr manager = engine->func_graph_manager();
4311 MS_EXCEPTION_IF_NULL(manager);
4312 manager->AddFuncGraph(call_func_graph);
4313
4314 std::vector<AnfNodePtr> enter_inputs{NewValueNode(call_func_graph)};
4315 // __enter__(self)
4316 auto call_enter_node = cur_graph->NewCNodeInOrder(enter_inputs);
4317 // Continue to eval call_enter_node.
4318 AnfNodeConfigPtr fn_conf = engine->MakeConfig(call_enter_node, out_conf->context(), out_conf->func_graph());
4319 return engine->ForwardConfig(out_conf, fn_conf);
4320 }
4321 };
4322
4323 class WithExitEvaluator : public TransitionPrimEvaluator {
4324 public:
WithExitEvaluator()4325 WithExitEvaluator() : TransitionPrimEvaluator("WithExitEvaluator") {}
4326 ~WithExitEvaluator() override = default;
4327 MS_DECLARE_PARENT(WithExitEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)4328 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
4329 const AnfNodeConfigPtr &out_conf) override {
4330 MS_EXCEPTION_IF_NULL(out_conf);
4331 MS_EXCEPTION_IF_NULL(out_conf->node());
4332 auto node = out_conf->node()->cast<CNodePtr>();
4333 MS_EXCEPTION_IF_NULL(node);
4334 auto cur_graph = node->func_graph();
4335 MS_EXCEPTION_IF_NULL(cur_graph);
4336
4337 if (args_abs_list.size() != 1) {
4338 MS_LOG(INTERNAL_EXCEPTION) << "The exit node has wrong input." << node->debug_info();
4339 }
4340
4341 // Check class object
4342 constexpr size_t cls_index = 0;
4343 MS_EXCEPTION_IF_NULL(args_abs_list[cls_index]);
4344 auto cls_val = args_abs_list[cls_index]->BuildValue();
4345 MS_EXCEPTION_IF_NULL(cls_val);
4346 auto value_obj = cls_val->cast<parse::MsClassObjectPtr>();
4347 if (value_obj == nullptr) {
4348 MS_EXCEPTION(TypeError) << "Only support jit_class instance, but got " << cls_val->ToString();
4349 }
4350 auto cls_obj = value_obj->obj();
4351
4352 const std::string call_func = "__exit__";
4353 if (!py::hasattr(cls_obj, common::SafeCStr(call_func))) {
4354 MS_LOG(EXCEPTION) << value_obj->name() << " has no " << call_func << " function, please check the code.";
4355 }
4356 py::object call_obj = py::getattr(cls_obj, common::SafeCStr(call_func));
4357 FuncGraphPtr call_func_graph = parse::ConvertToFuncGraph(call_obj);
4358 if (call_func_graph == nullptr) {
4359 MS_LOG(INTERNAL_EXCEPTION) << "Parse python object " << call_func << " failed.";
4360 }
4361 FuncGraphManagerPtr manager = engine->func_graph_manager();
4362 MS_EXCEPTION_IF_NULL(manager);
4363 manager->AddFuncGraph(call_func_graph);
4364
4365 std::vector<AnfNodePtr> exit_inputs{NewValueNode(call_func_graph)};
4366 constexpr size_t arg_size = 3;
4367 // __exit__(self, type, value, trace)
4368 for (size_t i = 0; i < arg_size; ++i) {
4369 (void)exit_inputs.emplace_back(NewValueNode(kNone));
4370 }
4371 auto call_exit_node = cur_graph->NewCNodeInOrder(exit_inputs);
4372 // Continue to eval call_exit_node.
4373 AnfNodeConfigPtr fn_conf = engine->MakeConfig(call_exit_node, out_conf->context(), out_conf->func_graph());
4374 return engine->ForwardConfig(out_conf, fn_conf);
4375 }
4376 };
4377
4378 class CondEvaluator : public TransitionPrimEvaluator {
4379 public:
CondEvaluator()4380 CondEvaluator() : TransitionPrimEvaluator("CondEvaluator") {}
4381 ~CondEvaluator() override = default;
4382 MS_DECLARE_PARENT(CondEvaluator, TransitionPrimEvaluator);
EvalPrim(const AnalysisEnginePtr & engine,const AbstractBasePtrList & args_abs_list,const ConfigPtr &,const AnfNodeConfigPtr & out_conf)4383 EvalResultPtr EvalPrim(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_abs_list, const ConfigPtr &,
4384 const AnfNodeConfigPtr &out_conf) override {
4385 auto res_abstract = EvalUndeterminedArgs(args_abs_list);
4386 if (res_abstract != nullptr) {
4387 return res_abstract;
4388 }
4389 MS_EXCEPTION_IF_NULL(out_conf);
4390 MS_EXCEPTION_IF_NULL(out_conf->node());
4391 auto cnode = out_conf->node()->cast<CNodePtr>();
4392 MS_EXCEPTION_IF_NULL(cnode);
4393 auto cur_graph = cnode->func_graph();
4394 MS_EXCEPTION_IF_NULL(cur_graph);
4395 constexpr size_t input_size = 2;
4396 if (args_abs_list.size() != input_size) {
4397 MS_LOG(INTERNAL_EXCEPTION) << "The input size to cond node should be " << std::to_string(input_size)
4398 << ", but got " << std::to_string(args_abs_list.size());
4399 }
4400
4401 AnfNodePtr new_node = nullptr;
4402 constexpr size_t cond_abs_index = 0;
4403 constexpr size_t cond_input_index = 1;
4404 constexpr size_t flag_input_index = 2;
4405 auto cond_abs = args_abs_list[cond_abs_index];
4406 auto cond_node = cnode->input(cond_input_index);
4407 auto flag_node = cnode->input(flag_input_index);
4408 MS_EXCEPTION_IF_NULL(cond_abs);
4409 if (cond_abs->isa<AbstractAny>()) {
4410 // If the input to cond node is AbstractAny, genenrate pyexecute node 'bool(input)';
4411 const auto script_str = std::make_shared<StringImm>("bool(__input__)");
4412
4413 const auto input_str = std::make_shared<StringImm>("__input__");
4414 std::vector<AnfNodePtr> key_value_names_list{NewValueNode(prim::kPrimMakeTuple)};
4415 (void)key_value_names_list.emplace_back(NewValueNode(input_str));
4416 const auto key_value_name_tuple = cur_graph->NewCNode(key_value_names_list);
4417
4418 std::vector<AnfNodePtr> key_value_list{NewValueNode(prim::kPrimMakeTuple), cond_node};
4419 const auto key_value_tuple = cur_graph->NewCNode(key_value_list);
4420 new_node =
4421 fallback::CreatePyExecuteCNodeInOrder(cnode, NewValueNode(script_str), key_value_name_tuple, key_value_tuple);
4422 fallback::SetRealType<AnfNode, Type>(new_node, std::make_shared<TensorType>(kBool));
4423 fallback::SetRealShape(new_node, std::make_shared<abstract::Shape>(std::vector<int64_t>{Shape::kShapeDimAny}));
4424 } else if (cond_abs->isa<AbstractTensor>() && is_while_condition(flag_node)) {
4425 // When the condition of while is a tensor, do not use standard_method.tensor_bool
4426 // to avoid turning the tensor into scalar to cause a loop.
4427 constexpr auto operations_module = "mindspore.ops.operations";
4428 auto cast_op = python_adapter::GetPyFn(operations_module, kCastOpName)();
4429 auto cast_node = NewValueNode(parse::data_converter::PyDataToValue(cast_op));
4430 auto type_node = NewValueNode(TypeIdToType(kNumberTypeBool));
4431 new_node = cur_graph->NewCNodeInOrder({cast_node, cond_node, type_node});
4432 new_node->set_debug_info(cnode->debug_info());
4433 } else if (cond_abs->isa<AbstractFunction>()) {
4434 auto abs = std::make_shared<AbstractScalar>(std::make_shared<BoolImm>(true), kBool);
4435 return std::make_shared<EvalResult>(abs, std::make_shared<AttrValueMap>());
4436 } else {
4437 // The logic of truth value testing:
4438 // 1. If the object has __bool__ attribute, call __bool__()
4439 // 2. Else if the object has __len__ attribute, call __len__()
4440 // 3. Else return true.
4441 auto cond_type = cond_abs->BuildType();
4442 MS_EXCEPTION_IF_NULL(cond_type);
4443 auto cond_type_id = cond_type->type_id();
4444 constexpr auto bool_attr_str = "__bool__";
4445 constexpr auto len_attr_str = "__len__";
4446 ValuePtr prim_func;
4447 if (!pipeline::Resource::GetMethodPtr(cond_type_id, bool_attr_str).empty()) {
4448 prim_func = prim::GetPythonOps(parse::NAMED_PRIMITIVE_BOOL);
4449 } else if (!pipeline::Resource::GetMethodPtr(cond_type_id, len_attr_str).empty()) {
4450 prim_func = prim::GetPythonOps(parse::NAMED_PRIMITIVE_CHECK_LEN);
4451 } else {
4452 prim_func = prim::GetPythonOps(parse::NAMED_PRIMITIVE_REAL_BOOL);
4453 }
4454 auto prim_fg = dyn_cast<FuncGraph>(prim_func);
4455 MS_EXCEPTION_IF_NULL(prim_fg);
4456 auto mng = cur_graph->manager();
4457 MS_EXCEPTION_IF_NULL(mng);
4458 prim_fg->set_manager(mng);
4459 new_node = cur_graph->NewCNodeInOrder({NewValueNode(prim_fg), cond_node});
4460 }
4461 AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context(), out_conf->func_graph());
4462 return engine->ForwardConfig(out_conf, fn_conf);
4463 }
4464
is_while_condition(const AnfNodePtr & flag_node) const4465 bool is_while_condition(const AnfNodePtr &flag_node) const {
4466 MS_EXCEPTION_IF_NULL(flag_node);
4467 auto vnode = GetValueNode(flag_node);
4468 MS_EXCEPTION_IF_NULL(vnode);
4469 return GetValue<bool>(vnode);
4470 }
4471 };
4472
4473 struct PrimitiveImplInferValue {
4474 PrimitiveImpl impl_; // implement function of primitive
4475 bool eval_value_; // whether evaluate value
4476 TypePtr specify_out_type_; // whether specify return type
4477 bool in_white_list_; // true if this Primitive in white list, else false.
4478 };
4479
4480 using PrimitiveToImplMap = mindspore::HashMap<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
GetUniformPrimitiveToImplMap()4481 PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
4482 using R = PrimitiveToImplMap::mapped_type;
4483 static PrimitiveToImplMap uniform_prim_implement_map{
4484 {prim::kPrimScalarPow, R{prim::ScalarPow, true, nullptr, true}},
4485 {prim::kPrimScalarUadd, R{prim::ScalarUAdd, true, nullptr, true}},
4486 {prim::kPrimScalarUsub, R{prim::ScalarUSub, true, nullptr, true}},
4487 {prim::kPrimScalarLog, R{prim::ScalarLog, true, nullptr, true}},
4488 {prim::kPrimBitXor, R{prim::BitXor, true, nullptr, true}},
4489 {prim::kPrimBitLeftShift, R{prim::BitLeftShift, true, nullptr, true}},
4490 {prim::kPrimBitRightShift, R{prim::BitRightShift, true, nullptr, true}},
4491 {prim::kPrimScalarNe, R{prim::ScalarNe, true, std::make_shared<Bool>(), true}},
4492 {prim::kPrimBoolAnd, R{prim::BoolAnd, true, std::make_shared<Bool>(), true}},
4493 {prim::kPrimBoolEq, R{prim::BoolEq, true, std::make_shared<Bool>(), true}},
4494 {prim::kPrimBoolOr, R{prim::BoolOr, true, std::make_shared<Bool>(), true}},
4495 {prim::kPrimStringConcat, R{prim::StringConcat, true, nullptr, true}},
4496 {prim::kPrimStringEq, R{prim::StringEq, true, std::make_shared<Bool>(), true}},
4497 {prim::kPrimStringLt, R{prim::StringLt, true, std::make_shared<Bool>(), true}},
4498 {prim::kPrimStringGt, R{prim::StringGt, true, std::make_shared<Bool>(), true}},
4499 {prim::kPrimStringLe, R{prim::StringLe, true, std::make_shared<Bool>(), true}},
4500 {prim::kPrimStringGe, R{prim::StringGe, true, std::make_shared<Bool>(), true}},
4501 {prim::kPrimStringNot, R{prim::StringNot, true, std::make_shared<Bool>(), true}},
4502 {prim::kPrimStringIn, R{prim::StringIn, true, std::make_shared<Bool>(), true}},
4503 };
4504 return uniform_prim_implement_map;
4505 }
4506
4507 PrimEvaluatorMap prim_evaluator_constructors = PrimEvaluatorMap();
4508 std::mutex PrimEvaluatorConstructorMutex;
4509
InitPrimEvaluatorConstructors()4510 void InitPrimEvaluatorConstructors() {
4511 PrimEvaluatorMap &constructor = prim_evaluator_constructors;
4512
4513 for (const auto &iter : GetPrimitiveInferMap()) {
4514 constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second);
4515 }
4516
4517 for (const auto &iter : GetUniformPrimitiveToImplMap()) {
4518 constructor[iter.first] =
4519 InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
4520 }
4521 constructor[prim::kPrimEmbed] = std::make_shared<EmbedEvaluator>();
4522 constructor[prim::kPrimRefToEmbed] = std::make_shared<RefToEmbedEvaluator>();
4523 constructor[prim::kPrimGetAttr] = std::make_shared<GetAttrEvaluator>();
4524 constructor[prim::kPrimSetAttr] = std::make_shared<SetAttrEvaluator>();
4525 constructor[prim::kPrimResolve] = std::make_shared<ResolveEvaluator>();
4526 constructor[prim::kPrimCreateInstance] = std::make_shared<CreateInstanceEvaluator>();
4527 constructor[prim::kPrimPartial] = std::make_shared<PartialEvaluator>();
4528 constructor[prim::kPrimPyInterpret] = std::make_shared<PyInterpretEvaluator>();
4529 constructor[prim::kPrimMakeTuple] = std::make_shared<MakeTupleEvaluator>();
4530 constructor[prim::kPrimMakeList] = std::make_shared<MakeListEvaluator>();
4531 constructor[prim::kPrimRaise] = std::make_shared<RaiseEvaluator>();
4532 constructor[prim::kPrimWithEnter] = std::make_shared<WithEnterEvaluator>();
4533 constructor[prim::kPrimWithExit] = std::make_shared<WithExitEvaluator>();
4534 constructor[prim::kPrimCond] = std::make_shared<CondEvaluator>();
4535 }
4536
InitBuiltinPrimEvaluatorConstructors()4537 void InitBuiltinPrimEvaluatorConstructors() {
4538 PrimEvaluatorMap &constructor = prim_evaluator_constructors;
4539 constructor[prim::kPrimInnerAbs] = std::make_shared<InnerAbsEvaluator>();
4540 constructor[prim::kPrimInnerRound] = std::make_shared<InnerRoundEvaluator>();
4541 }
4542 } // namespace
4543
ClearPrimEvaluatorMap()4544 void ClearPrimEvaluatorMap() {
4545 prim_evaluator_constructors.clear();
4546 GetFrontendPrimitiveInferMapPtr()->clear();
4547 GetUniformPrimitiveToImplMap().clear();
4548 }
4549
IsInWhiteList(const PrimitivePtr & primitive)4550 bool IsInWhiteList(const PrimitivePtr &primitive) {
4551 MS_EXCEPTION_IF_NULL(primitive);
4552
4553 using WhiteList = mindspore::HashMap<PrimitivePtr, bool, PrimitiveHasher, PrimitiveEqual>;
4554
4555 static WhiteList whitelist = {{prim::kPrimPartial, true}};
4556 auto iter = whitelist.find(primitive);
4557 if (iter != whitelist.end()) {
4558 return iter->second;
4559 }
4560
4561 auto found = abstract::GetFrontendPrimitiveInferImpl(primitive);
4562 if (found.has_value()) {
4563 auto infer = found.value();
4564 return infer.IsInWhiteList();
4565 }
4566
4567 auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
4568 if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
4569 return uni_iter->second.in_white_list_;
4570 }
4571
4572 return true;
4573 }
4574
GetPrimEvaluatorConstructors()4575 PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
4576 PrimEvaluatorMap &constructor = prim_evaluator_constructors;
4577 if (!constructor.empty()) {
4578 return constructor;
4579 }
4580 std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
4581 if (constructor.empty()) {
4582 InitPrimEvaluatorConstructors();
4583 InitBuiltinPrimEvaluatorConstructors();
4584 }
4585
4586 return constructor;
4587 }
4588
4589 namespace {
IsSubtypeTuple(const AbstractBasePtr x,const TypePtr model)4590 bool IsSubtypeTuple(const AbstractBasePtr x, const TypePtr model) {
4591 MS_EXCEPTION_IF_NULL(x);
4592 MS_EXCEPTION_IF_NULL(model);
4593 auto x_tuple = dyn_cast_ptr<AbstractTuple>(x);
4594 auto model_tuple = dyn_cast_ptr<Tuple>(model);
4595
4596 if (x_tuple == nullptr || model_tuple == nullptr) {
4597 return false;
4598 }
4599
4600 if (model->IsGeneric()) {
4601 return true;
4602 }
4603
4604 if (x_tuple->size() != model_tuple->size()) {
4605 return false;
4606 }
4607
4608 for (size_t i = 0; i < x_tuple->size(); i++) {
4609 bool is_subtype = IsSubtype((*x_tuple)[i], (*model_tuple)[i]);
4610 if (!is_subtype) {
4611 return false;
4612 }
4613 }
4614 return true;
4615 }
4616
IsSubtypeArray(const AbstractBasePtr x,const TypePtr model)4617 bool IsSubtypeArray(const AbstractBasePtr x, const TypePtr model) {
4618 MS_EXCEPTION_IF_NULL(x);
4619 MS_EXCEPTION_IF_NULL(model);
4620 auto x_tensor = dyn_cast_ptr<AbstractTensor>(x);
4621 auto model_tensor = dyn_cast_ptr<TensorType>(model);
4622
4623 if (x_tensor == nullptr || model_tensor == nullptr) {
4624 return false;
4625 }
4626
4627 if (model->IsGeneric()) {
4628 return true;
4629 }
4630
4631 return IsSubtype(x_tensor->element(), model_tensor->element());
4632 }
4633
IsSubtypeList(const AbstractBasePtr x,const TypePtr model)4634 bool IsSubtypeList(const AbstractBasePtr x, const TypePtr model) {
4635 MS_EXCEPTION_IF_NULL(x);
4636 MS_EXCEPTION_IF_NULL(model);
4637 auto x_list = dyn_cast_ptr<AbstractList>(x);
4638 auto model_list = dyn_cast_ptr<List>(model);
4639
4640 if (x_list == nullptr || model_list == nullptr) {
4641 return false;
4642 }
4643
4644 if (model->IsGeneric()) {
4645 return true;
4646 }
4647
4648 if (x_list->size() != model_list->size()) {
4649 return false;
4650 }
4651
4652 bool is_subtype = true;
4653 for (size_t i = 0; i < x_list->size(); i++) {
4654 is_subtype = IsSubtype((*x_list)[i], (*model_list)[i]);
4655 if (!is_subtype) {
4656 return false;
4657 }
4658 }
4659 return is_subtype;
4660 }
4661
IsSubtypeScalar(const AbstractBasePtr x,const TypePtr model)4662 inline bool IsSubtypeScalar(const AbstractBasePtr x, const TypePtr model) {
4663 MS_EXCEPTION_IF_NULL(x);
4664 MS_EXCEPTION_IF_NULL(model);
4665 if (dyn_cast_ptr<AbstractScalar>(x) == nullptr) {
4666 return false;
4667 }
4668 auto &x_type = x->GetTypeTrack();
4669 return IsSubType(x_type, model);
4670 }
4671 } // namespace
4672
IsSubtype(const AbstractBasePtr x,const TypePtr model)4673 bool IsSubtype(const AbstractBasePtr x, const TypePtr model) {
4674 MS_EXCEPTION_IF_NULL(x);
4675 MS_EXCEPTION_IF_NULL(model);
4676 TypeId model_typeid = model->type_id();
4677 switch (model_typeid) {
4678 case kMetaTypeObject:
4679 return true;
4680 case kObjectTypeTuple:
4681 return IsSubtypeTuple(x, model);
4682 case kObjectTypeTensorType:
4683 return IsSubtypeArray(x, model);
4684 case kObjectTypeList:
4685 return IsSubtypeList(x, model);
4686 default:
4687 if (IsSubType(model, std::make_shared<Number>())) {
4688 return IsSubtypeScalar(x, model);
4689 }
4690 MS_LOG(EXCEPTION) << "Invalid model type: " << model->ToString() << ".";
4691 }
4692 }
4693
GetPrimitiveInitArgs(const PrimitivePyPtr & prim_py,const ops::OpDef * op_def)4694 AnfNodePtrList GetPrimitiveInitArgs(const PrimitivePyPtr &prim_py, const ops::OpDef *op_def) {
4695 MS_EXCEPTION_IF_NULL(prim_py);
4696 MS_EXCEPTION_IF_NULL(op_def);
4697
4698 std::vector<AnfNodePtr> prim_init_arg_nodes;
4699 auto obj = prim_py->GetPyObj();
4700
4701 for (const auto &op_arg : op_def->args_) {
4702 if (op_arg.as_init_arg_) {
4703 auto arg_name = op_arg.arg_name_;
4704 py::object arg_value = py::getattr(obj, common::SafeCStr(arg_name));
4705 ValuePtr converted_ret = nullptr;
4706 bool converted = parse::ConvertData(arg_value, &converted_ret);
4707 if (!converted) {
4708 MS_LOG(INTERNAL_EXCEPTION) << "Cannot convert initialization arg: (" << arg_name << ": " << py::str(arg_value)
4709 << ") in Primitive '" << prim_py->name() << "'.";
4710 }
4711 (void)prim_init_arg_nodes.emplace_back(NewValueNode(converted_ret));
4712 }
4713 }
4714 MS_LOG(DEBUG) << "PrimitivePy " << prim_py->name() << " has " << prim_init_arg_nodes.size() << " __init__() args";
4715 return prim_init_arg_nodes;
4716 }
4717
GeneratePrimitiveCNode(const PrimitivePtr & primitive,const ops::OpDef * op_def,const FuncGraphPtr & graph,const AnfNodePtrList & init_args_nodes,const AnfNodePtrList & call_args_nodes,const std::function<AbstractBasePtr (const AnfNodePtr &)> & eval_func)4718 CNodePtr GeneratePrimitiveCNode(const PrimitivePtr &primitive, const ops::OpDef *op_def, const FuncGraphPtr &graph,
4719 const AnfNodePtrList &init_args_nodes, const AnfNodePtrList &call_args_nodes,
4720 const std::function<AbstractBasePtr(const AnfNodePtr &)> &eval_func) {
4721 MS_EXCEPTION_IF_NULL(primitive);
4722 MS_EXCEPTION_IF_NULL(op_def);
4723
4724 auto args_pair = std::make_pair(init_args_nodes, call_args_nodes);
4725
4726 // Follow the implementations in PrimitiveArgsToInputsEvaluator, convert to base Primitive, and is_preprocessed=true
4727 auto new_prim = std::make_shared<Primitive>(*primitive);
4728 auto new_cnode = CheckAndConvertPrimitiveArgs(new_prim, graph, args_pair, eval_func, true);
4729
4730 MS_LOG(INFO) << "Convert primitive args: " << primitive->name() << ", new node: " << new_cnode->DebugString();
4731 return new_cnode;
4732 }
4733 } // namespace abstract
4734 } // namespace mindspore
4735