1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/ps/static_analysis/evaluator.h"
18
19 #include <algorithm>
20 #include <utility>
21
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "mindspore/core/ops/structure_ops.h"
25 #include "utils/hash_set.h"
26 #include "ir/func_graph_cloner.h"
27 #include "abstract/utils.h"
28 #include "pipeline/jit/ps/debug/trace.h"
29 #include "utils/ms_context.h"
30 #include "utils/compile_config.h"
31 #include "pipeline/jit/ps/static_analysis/stack_frame.h"
32 #include "pipeline/jit/ps/static_analysis/async_eval_result.h"
33 #include "frontend/expander/bprop/bprop_meta_func_graph.h"
34 #include "frontend/operator/composite/unpack_call.h"
35 #include "frontend/optimizer/ad/dfunctor.h"
36
37 namespace mindspore {
38 namespace abstract {
39 namespace {
EvalEntryLogging(const EvaluatorPtr & evaluator,const AbstractBasePtrList & arg_abs_list,const AnfNodeConfigPtr & out_conf)40 string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_abs_list,
41 const AnfNodeConfigPtr &out_conf) {
42 MS_EXCEPTION_IF_NULL(evaluator);
43 std::stringstream ss;
44 if (out_conf != nullptr) {
45 MS_EXCEPTION_IF_NULL(out_conf->node());
46 MS_EXCEPTION_IF_NULL(out_conf->node()->scope());
47 ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
48 }
49 for (size_t i = 0; i < arg_abs_list.size(); i++) {
50 ss << evaluator->ToString() << " input[" << i
51 << "] abstract value: " << (arg_abs_list[i] ? arg_abs_list[i]->ToString() : "null abstract.");
52 }
53 return ss.str();
54 }
55
EvalFailLogging(const EvaluatorPtr & evaluator,const AbstractBasePtrList &,const AnfNodeConfigPtr & out_conf)56 void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) {
57 MS_EXCEPTION_IF_NULL(evaluator);
58 if (out_conf != nullptr) {
59 auto node = out_conf->node();
60 MS_EXCEPTION_IF_NULL(node);
61 if (IsValueNode<Primitive>(node)) {
62 MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->fullname_with_scope()
63 << ", with debug info: " << trace::GetDebugInfoStr(node->debug_info());
64 } else {
65 MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->DebugString()
66 << ", with debug info: " << trace::GetDebugInfoStr(node->debug_info());
67 }
68 }
69 }
70
ContainsAbstractAnyInner(const AbstractBasePtr & abs)71 bool ContainsAbstractAnyInner(const AbstractBasePtr &abs) {
72 MS_EXCEPTION_IF_NULL(abs);
73 if (abs->isa<AbstractSequence>()) {
74 auto abs_list = abs->cast<AbstractSequencePtr>();
75 const auto &elements = abs_list->elements();
76 return std::any_of(elements.begin(), elements.end(), [](const AbstractBasePtr &e) {
77 MS_EXCEPTION_IF_NULL(e);
78 return ContainsAbstractAnyInner(e);
79 });
80 }
81 return abs->isa<AbstractAny>();
82 }
83
GetArgsUniqueDtype(const AbstractBasePtrList & args_abs_list)84 TypePtr GetArgsUniqueDtype(const AbstractBasePtrList &args_abs_list) {
85 TypePtr res = nullptr;
86 for (const auto &arg : args_abs_list) {
87 MS_EXCEPTION_IF_NULL(arg);
88 if (!arg->isa<AbstractTensor>()) {
89 continue;
90 }
91 // Check default dtype if it's AbstractAny(AbstractTensor)
92 if (arg->isa<abstract::AbstractAny>()) {
93 auto any_arg = arg->cast_ptr<abstract::AbstractAny>();
94 MS_EXCEPTION_IF_NULL(any_arg);
95 if (!any_arg->supposed_tensor_dtype()) {
96 continue;
97 }
98 }
99 // Fetch the dtype from item of tensor.
100 auto tensor_abs = arg->cast_ptr<AbstractTensor>();
101 MS_EXCEPTION_IF_NULL(tensor_abs);
102 MS_EXCEPTION_IF_NULL(tensor_abs->element());
103 const auto dtype = tensor_abs->element()->BuildType();
104 MS_EXCEPTION_IF_NULL(dtype);
105 if (res == nullptr) {
106 res = dtype;
107 continue;
108 }
109 if (dtype != res) {
110 return nullptr;
111 }
112 }
113 return res;
114 }
115
// Clones `generated_func_graph` (produced by a BpropMetaFuncGraph) for the cnode it is bound to.
// The bound cnode's primal attributes and primal debug infos are installed via guards for the
// duration of the clone, and the clone is updated with `scope` and the bound cnode's debug info.
// Raises an internal exception if `bound_node` is not a CNode.
FuncGraphPtr GetCloneBpropGraph(const MetaFuncGraphPtr &meta_func_graph, const FuncGraphPtr &generated_func_graph,
                                const AnfNodePtr &bound_node, const ScopePtr &scope) {
  MS_EXCEPTION_IF_NULL(meta_func_graph);
  const auto bound_cnode = dyn_cast_ptr<CNode>(bound_node);
  if (bound_cnode == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "For BpropMetaFuncGraph '" << meta_func_graph->ToString()
                               << "', the evaluator should have the bound cnode.";
  }
  // Both guards stay in scope across the BasicClone call below.
  PrimalAttrGuard primal_attr_guard(bound_cnode->primal_attrs());
  const auto &debug_infos = bound_cnode->primal_debug_infos();
  std::vector<NodeDebugInfoPtr> debug_info_list(debug_infos.begin(), debug_infos.end());
  PrimalDebugInfoGuard primal_debug_info_guard(debug_info_list);
  auto update_info = std::make_shared<UpdateInfo>(scope, bound_cnode->debug_info());
  return BasicClone(generated_func_graph, false, update_info);
}
133
IsSideEffectCNode(const AnfNodePtr & node)134 bool IsSideEffectCNode(const AnfNodePtr &node) {
135 MS_EXCEPTION_IF_NULL(node);
136 const auto &primitive = GetCNodePrimitiveWithoutDoSignature(node);
137 if (primitive != nullptr) {
138 auto effect_info = GetPrimEffectInfo(primitive);
139 if (effect_info.memory || effect_info.io) {
140 MS_LOG(DEBUG) << "Side Effect Primitive CNode: " << node->DebugString();
141 return true;
142 }
143 } else if (node->isa<CNode>()) {
144 // Call side effect node.
145 auto first_node = node->cast<CNodePtr>()->input(0);
146 if (first_node->isa<CNode>() && IsSideEffectCNode(first_node)) {
147 return true;
148 }
149 }
150 return false;
151 }
152
153 bool HasIsolatedSideEffectNode(const FuncGraphPtr &func_graph);
154
CheckSideEffect(const AnfNodePtr & input)155 bool CheckSideEffect(const AnfNodePtr &input) {
156 if (IsSideEffectCNode(input)) {
157 MS_LOG(DEBUG) << "Multiple side-effect node: " << input->DebugString();
158 return true;
159 }
160 // Process {Depend -> StopGradient -> MakeTuple(call function, ...)}.
161 if (input->isa<CNode>()) {
162 auto fn_input = input->cast<CNodePtr>()->input(0);
163 if (IsValueNode<prim::UnpackCall>(fn_input)) {
164 fn_input = input->cast<CNodePtr>()->input(1);
165 }
166 if (IsValueNode<FuncGraph>(fn_input)) {
167 auto func = GetValueNode<FuncGraphPtr>(fn_input);
168 if (IsSideEffectCNode(func->output()) || HasIsolatedSideEffectNode(func)) {
169 MS_LOG(DEBUG) << "Single nested side-effect node: " << input->DebugString();
170 return true;
171 }
172 }
173 }
174 return false;
175 }
176
HasIsolatedSideEffectNode(const FuncGraphPtr & func_graph)177 bool HasIsolatedSideEffectNode(const FuncGraphPtr &func_graph) {
178 MS_EXCEPTION_IF_NULL(func_graph);
179 const auto node = func_graph->output();
180 if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
181 return false;
182 }
183 auto cnode = dyn_cast<CNode>(node);
184 MS_EXCEPTION_IF_NULL(cnode);
185 auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
186 auto sort_rhs_first =
187 attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
188 if (!sort_rhs_first) {
189 // Return false if it's definitely not side-effect Depend CNode.
190 return false;
191 }
192
193 // To check side-effect nodes in {Depend -> StopGradient -> MakeTuple(...)}.
194 constexpr size_t stop_gradient_pos = 2;
195 auto stop_gradient_node = cnode->input(stop_gradient_pos);
196 auto stop_gradient_cnode = dyn_cast<CNode>(stop_gradient_node);
197 MS_EXCEPTION_IF_NULL(stop_gradient_cnode);
198 constexpr size_t isolated_node_pos = 1;
199 auto isolated_node = stop_gradient_cnode->input(isolated_node_pos);
200 MS_EXCEPTION_IF_NULL(isolated_node);
201 if (CheckSideEffect(isolated_node)) {
202 return true;
203 }
204 if (IsPrimitiveCNode(isolated_node, prim::kPrimMakeTuple)) {
205 auto isolated_cnode = dyn_cast<CNode>(isolated_node);
206 MS_EXCEPTION_IF_NULL(isolated_cnode);
207 for (size_t i = 1; i < isolated_cnode->size(); ++i) {
208 auto input = isolated_cnode->input(i);
209 if (CheckSideEffect(input)) {
210 return true;
211 }
212 }
213 }
214 return false;
215 }
216
217 // Mark the side effect at output and func graph for later constant folding.
PresetCertainSideEffect(const FuncGraphPtr & func_graph)218 void PresetCertainSideEffect(const FuncGraphPtr &func_graph) {
219 MS_EXCEPTION_IF_NULL(func_graph);
220 if (!HasIsolatedSideEffectNode(func_graph)) {
221 return;
222 }
223
224 auto new_return = func_graph->get_return();
225 new_return->set_has_side_effect_node(true);
226 func_graph->set_has_side_effect_node(true);
227 auto output_cnode = dyn_cast<CNode>(func_graph->output());
228 if (output_cnode != nullptr) {
229 output_cnode->set_has_side_effect_node(true);
230 }
231 MS_LOG(DEBUG) << "Set isolated side-effect node flag for " << func_graph->ToString();
232 }
233 } // namespace
234
ContainsAbstractAny(const AbstractBasePtrList & args_abs_list)235 bool ContainsAbstractAny(const AbstractBasePtrList &args_abs_list) {
236 return std::any_of(args_abs_list.cbegin(), args_abs_list.cend(), [](const AbstractBasePtr &item) {
237 MS_EXCEPTION_IF_NULL(item);
238 return ContainsAbstractAnyInner(item);
239 });
240 }
241
242 // MakeTuple and MakeList will handle AbstractAny in ops infer.
243 const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> ignore_any_type_checking_prims{
244 prim::kPrimReturn, prim::kPrimDepend, prim::kPrimSwitch, prim::kPrimSwitchLayer,
245 prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimIsConstant, prim::kPrimMakeKeywordArg,
246 prim::kPrimIsShapeUnknown, prim::kPrimIsDimUnknown, prim::kPrimListGetItem, prim::kPrimTupleGetItem,
247 prim::kPrimSequenceLen, prim::kPrimMakeDict, prim::kPrimMutable};
248
EvaluateArguments(const ConfigPtrList & args_conf_list)249 AbstractBasePtrList EvaluateArguments(const ConfigPtrList &args_conf_list) {
250 AbstractBasePtrList args_abs_list;
251 args_abs_list.reserve(args_conf_list.size());
252 for (auto &config : args_conf_list) {
253 MS_EXCEPTION_IF_NULL(config);
254 auto result = config->ObtainEvalResult();
255 MS_EXCEPTION_IF_NULL(result);
256 const auto &abs = result->abstract();
257 // Check if there's an inplace abstract and use it.
258 AbstractBasePtr real_abs;
259 MS_EXCEPTION_IF_NULL(abs);
260 if (abs->inplace_abstract() == nullptr) {
261 real_abs = abs;
262 } else {
263 real_abs = abs->inplace_abstract();
264 MS_LOG(INFO) << "Use inplace abstract, " << abs->ToString() << " -> " << real_abs->ToString();
265 }
266 (void)args_abs_list.emplace_back(real_abs);
267 }
268 return args_abs_list;
269 }
270
CheckIfAlwaysEval(const AnfNodeConfigPtr & conf,const AbstractBasePtr & arg)271 bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
272 MS_EXCEPTION_IF_NULL(arg);
273 auto new_sequence = dyn_cast_ptr<AbstractSequence>(arg);
274 if (new_sequence != nullptr && !new_sequence->dynamic_len() && new_sequence->sequence_nodes() != nullptr &&
275 new_sequence->size() != 0) {
276 const auto &prev_result = ObtainEvalResultFromCache(conf);
277 if (prev_result == nullptr) {
278 return false;
279 }
280 auto prev_abs = prev_result->abstract();
281 auto old_sequence = dyn_cast_ptr<AbstractSequence>(prev_abs);
282 if (old_sequence != nullptr &&
283 (old_sequence->sequence_nodes() == nullptr || old_sequence->sequence_nodes()->empty()) && *arg == *prev_abs) {
284 MS_LOG(DEBUG) << "Always eval";
285 return true;
286 }
287 }
288 return false;
289 }
290
EnterStackFrame(const AnalysisEnginePtr & engine,const StackFramePtr & current_stack_frame,const StackFramePtr & new_stack_frame)291 void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame,
292 const StackFramePtr &new_stack_frame) {
293 MS_EXCEPTION_IF_NULL(current_stack_frame);
294 MS_EXCEPTION_IF_NULL(new_stack_frame);
295 MS_EXCEPTION_IF_NULL(engine);
296 // Enter new func graph.
297 auto ¤t_node = current_stack_frame->CurrentNode();
298 auto current_context = current_stack_frame->current_context();
299 AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context, current_context->func_graph());
300 auto evaluator = new_stack_frame->evaluator();
301 MS_EXCEPTION_IF_NULL(evaluator);
302 auto new_context = new_stack_frame->current_context();
303 trace::TraceGraphEvalEnter(new_context, call_conf);
304
305 // Increase & Check the func graph call depth.
306 // Don't check it if the user set no_recursive flag.
307 IncreaseFunctionCallDepth();
308 IncreaseStackFrameDepth();
309 const auto &top_graph = parse::Parser::GetTopFuncGraph();
310 bool no_recursive = (top_graph == nullptr ? false : top_graph->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE));
311 const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
312 if (!no_recursive && FunctionCallDepth() > max_depth) {
313 MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
314 << ", (function call depth: " << FunctionCallDepth()
315 << ", simulate call depth: " << StackFrameDepth() << ").\n"
316 << "It's always happened with complex construction of code or infinite recursion or loop.\n"
317 << "Please check the code if it's has the infinite recursion "
318 << "or call 'context.set_context(max_call_depth=value)' to adjust this value.\n"
319 << "If max_call_depth is set larger, the system max stack depth should be set larger too "
320 << "to avoid stack overflow.\n"
321 << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
322 }
323 MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
324 << "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
325 }
326
LeaveStackFrame(const AnalysisEnginePtr &,const StackFramePtr & current_stack_frame)327 void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr ¤t_stack_frame) {
328 MS_EXCEPTION_IF_NULL(current_stack_frame);
329 // Leave current func graph.
330 auto current_context = current_stack_frame->current_context();
331 trace::TraceGraphEvalLeave(current_context);
332
333 // Decrease the func graph call depth.
334 DecreaseFunctionCallDepth();
335 DecreaseStackFrameDepth();
336
337 auto evaluator = current_stack_frame->evaluator();
338 MS_EXCEPTION_IF_NULL(evaluator);
339 MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
340 << "), leave, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
341 }
342
// Start running stack frames in an Evaluator.
//
// Evaluates `fg` under `context` without native recursion: callee func graphs are handled by pushing
// StackFrames onto an explicit stack. Each loop iteration does exactly one of the following:
//  - the top frame is Done(): pop it. If it was the bottom frame, stop. Otherwise leave it and hand
//    the latest eval result back to the caller frame through Back();
//  - the top frame Jump()s into a callee: enter and push the new frame;
//  - otherwise Step() the top frame and remember the abstract of the resulting eval result.
// Returns the abstract of the most recent Step() once the bottom frame is done.
// NOTE(review): only pushed callee frames go through EnterStackFrame/LeaveStackFrame. Depth bookkeeping
// for the bottom frame presumably happens in the caller (Eval does it around this call); confirm
// before changing.
AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
                                                         const AnalysisContextPtr &context) {
  EvalResultPtr eval_result = nullptr;
  AbstractBasePtr abstract = nullptr;
  std::stack<StackFramePtr> stack_frames;
  auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context, parent_context_);
  MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
  stack_frames.push(current_stack_frame);
  while (true) {
    current_stack_frame = stack_frames.top();
    MS_EXCEPTION_IF_NULL(current_stack_frame);
    if (current_stack_frame->Done()) {
      // `abstract` is only assigned by Step() below; null here means nothing was stepped.
      MS_EXCEPTION_IF_NULL(abstract);
      MS_EXCEPTION_IF_NULL(current_stack_frame->func_graph());
      if (current_stack_frame->func_graph()->has_flag(FUNC_GRAPH_FLAG_PRIMAL_OF_BPROP)) {
        // Set all fprop outputs as used.
        SetSequenceElementsUseFlagsRecursively(abstract, true);
      }
      MS_LOG(DEBUG) << "[" << this << "/StackFrame] Leave from func graph, " << current_stack_frame;
      stack_frames.pop();
      if (stack_frames.empty()) {
        // The bottom frame is finished: the whole evaluation is done.
        MS_LOG(DEBUG) << "[" << this << "/StackFrame] Finish at func graph, " << current_stack_frame
                      << ", abstract: " << abstract->ToString();
        break;
      }
      // Leave current func graph.
      LeaveStackFrame(engine, current_stack_frame);
      // Switch the stack frame.
      auto last_stack_frame = current_stack_frame;
      current_stack_frame = stack_frames.top();
      MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame;
      // Return the callee's last eval result to the caller frame.
      current_stack_frame->Back(engine, last_stack_frame, eval_result);
      continue;
    }

    auto new_stack_frame = current_stack_frame->Jump(engine);
    if (new_stack_frame != nullptr) {
      // Enter new func graph.
      EnterStackFrame(engine, current_stack_frame, new_stack_frame);
      // Update current stack frame.
      stack_frames.push(new_stack_frame);
      MS_LOG(DEBUG) << "[" << this << "/StackFrame] Jump to new func graph, " << new_stack_frame;
      continue;
    }

    // No jump: advance the current frame by one step and keep its result.
    eval_result = current_stack_frame->Step(engine);
    MS_EXCEPTION_IF_NULL(eval_result);
    abstract = eval_result->abstract();
  }
  return abstract;
}
395
// Evaluates `fg` under `context` by walking, in TopoSort order, the nodes reachable from its return
// node through incoming edges. Value nodes and parameters are excluded from the walk, and so are
// Partial cnodes when the PRE_LIFT compile config is "1".
// Returns the abstract of the last node evaluated (presumably the return node, emitted last by
// TopoSort; TODO confirm).
AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
                                                            const AnalysisContextPtr &context) const {
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(engine);
  const AnfNodePtr &func_node = fg->get_return();
  const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
    MS_EXCEPTION_IF_NULL(node);
    // Function-local static: the PRE_LIFT config is read only on first use.
    static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
    if (node->isa<ValueNode>() || node->isa<Parameter>() ||
        (enable_pre_lift && IsPrimitiveCNode(node, prim::kPrimPartial))) {
      return EXCLUDE;
    }
    return FOLLOW;
  });
  AbstractBasePtr abstract = nullptr;
  for (const auto &node : all_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context, fg);
    MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
                  << ", node: " << node->DebugString() << ", node_conf: " << node_conf->ToString();
    EvalResultPtr node_eval_result = nullptr;
    if (always_eval_flag()) {
      // Always-eval mode: bypass the cache and re-evaluate the node.
      MS_LOG(DEBUG) << "Always eval node";
      node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf);
    } else {
      node_eval_result = ObtainEvalResultFromCache(node_conf);
      if (node_eval_result != nullptr) {
        // Cache hit: the node is not re-evaluated, but some flags are still propagated.
        // Function-local static: the ENABLE_DDE config is read only on first use.
        static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
        if (enable_eliminate_unused_element) {
          // Value nodes and parameters were excluded above, so the node is expected to be a CNode.
          const auto &cnode = node->cast<CNodePtr>();
          MS_EXCEPTION_IF_NULL(cnode);
          const auto &maybe_func = engine->GetCNodeOperatorAbstract(cnode, context, fg);
          if (maybe_func->isa<MetaFuncGraphAbstractClosure>() || maybe_func->isa<FuncGraphAbstractClosure>()) {
            const auto &abs_func_graph = maybe_func->cast<AbstractFunctionPtr>();
            SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(engine, fg, cnode, abs_func_graph, context);
          }
        }
        // Propagate a cached side-effect result to the cnode and to the enclosing func graph.
        if (engine->check_side_effect() && node_eval_result->has_side_effect_node()) {
          auto cnode = dyn_cast_ptr<CNode>(node);
          MS_EXCEPTION_IF_NULL(cnode);
          MS_LOG(DEBUG) << "Found side-effect, cnode: " << cnode->DebugString() << ", func_graph: " << fg->ToString();
          cnode->set_has_side_effect_node(true);
          fg->set_has_side_effect_node(true);
        }
        MS_LOG(DEBUG) << "No need to jump as found result from cache for node_config";
      } else {
        // Cache miss: evaluate the node now.
        node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf);
      }
    }
    MS_EXCEPTION_IF_NULL(node_eval_result);
    abstract = node_eval_result->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << abstract->ToString();
  }
  MS_EXCEPTION_IF_NULL(abstract);
  if (fg->has_flag(FUNC_GRAPH_FLAG_PRIMAL_OF_BPROP)) {
    // Set all fprop outputs as used.
    SetSequenceElementsUseFlagsRecursively(abstract, true);
  }
  return abstract;
}
457
// Evaluates this evaluator's func graph with the argument abstracts `args_abs_list`:
//  1. increases the function call depth. Unless the top func graph has FUNC_GRAPH_FLAG_NO_RECURSIVE,
//     it raises an exception once the depth exceeds MS_CTX_MAX_CALL_DEPTH;
//  2. obtains the func graph from GetFuncGraph and creates a new context for it under parent_context_;
//  3. checks the argument count (TypeError on mismatch). It stores each argument abstract as the
//     cached eval result of the matching parameter, deciding along the way whether always-eval mode is
//     needed;
//  4. evaluates the graph body, recursively or via explicit stack frames depending on the engine;
//  5. returns the output abstract, replaced by AbstractUndetermined for stub graphs.
// Evaluator::Run checks the evaluator cache before calling Eval, so a cache hit here is logged as an error.
EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list,
                                           const AnfNodeConfigPtr &out_conf) {
  auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
  if (eval_result != nullptr) {
    MS_LOG(ERROR) << ToString() << ArgsToString(args_abs_list) << " entered again. There is something wrong.";
    return eval_result;
  }
  MS_LOG(DEBUG) << ToString() << " entered first.";
  MS_EXCEPTION_IF_NULL(engine);
  // Increase & Check the func graph call depth.
  // Don't check it if the user set no_recursive flag.
  // NOTE(review): this depth check duplicates the one in EnterStackFrame; keep the two in sync.
  IncreaseFunctionCallDepth();
  const auto &top_graph = parse::Parser::GetTopFuncGraph();
  bool no_recursive = (top_graph == nullptr ? false : top_graph->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE));
  const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
  if (!no_recursive && FunctionCallDepth() > max_depth) {
    MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
                      << ", (function call depth: " << FunctionCallDepth()
                      << ", simulate call depth: " << StackFrameDepth() << ").\n"
                      << "It's always happened with complex construction of code or infinite recursion or loop.\n"
                      << "Please check the code if it's has the infinite recursion "
                      << "or call 'context.set_context(max_call_depth=value)' to adjust this value.\n"
                      << "If max_call_depth is set larger, the system max stack depth should be set larger too "
                      << "to avoid stack overflow.\n"
                      << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
  }
  MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
                << "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();

  FuncGraphPtr fg = GetFuncGraph(engine, args_abs_list);
  MS_EXCEPTION_IF_NULL(fg);
  MS_EXCEPTION_IF_NULL(parent_context_);
  auto context = NewContext(parent_context_, fg, args_abs_list);
  trace::TraceGraphEvalEnter(context, out_conf);

  std::size_t nargs = fg->parameters().size();
  if (args_abs_list.size() != nargs) {
    MS_EXCEPTION(TypeError) << "The parameters number of the function is " << fg->parameters().size()
                            << ", but the number of provided arguments is " << args_abs_list.size() << ".\n"
                            << "FunctionGraph : " << fg->ToString()
                            << "\nNodeInfo: " << trace::GetDebugInfoStr(fg->debug_info());
  }
  MS_LOG(DEBUG) << GetInferThread() << "@" << fg->ToString() << ArgsToString(args_abs_list) << " { ";
  if (parent_context_->func_graph() != nullptr) {
    MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisSchedule::thread_id() << ":"
                  << parent_context_->func_graph()->ToString() << "()->" << AnalysisSchedule::thread_id() << ":"
                  << fg->ToString() << "();";
  }

  // Remember the context of the engine's root func graph.
  auto func_graph_evaluator = mindspore::cast<FuncGraphEvaluator>(this);
  if (func_graph_evaluator != nullptr) {
    MS_EXCEPTION_IF_NULL(engine->root_func_graph());
    if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
      engine->set_root_context(context);
    }
  }
  // Bind each argument abstract to its parameter by caching it as the parameter's eval result.
  bool always_eval_flag = false;
  const auto &parameters = fg->parameters();
  for (size_t i = 0; i < nargs; i++) {
    const auto &arg = args_abs_list[i];
    const auto &node = parameters[i];
    AnfNodeConfigPtr conf = engine->MakeConfig(node, context, fg);
    always_eval_flag = always_eval_flag || CheckIfAlwaysEval(conf, arg);
    auto result = std::make_shared<EvalResult>(arg, nullptr);
    engine->SaveEvalResultInCache(conf, result);
    MS_EXCEPTION_IF_NULL(arg);
    MS_LOG(DEBUG) << GetInferThread() << ", Save argument[" << i << "] result for " << fg->ToString()
                  << ", NodeConfig: " << conf->ToString() << ", result: " << arg << "/" << arg->ToString();
  }
  // The flag is pushed for the duration of the body evaluation and popped right after it.
  PushAlwaysEvalFlag(always_eval_flag);
  if (fg->get_return() == nullptr) {
    MS_LOG(EXCEPTION) << "The func graph " << fg << "/" << fg->ToString() << " has no return node.";
  }
  MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
                << ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString()
                << ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
                << ", current function call depth: " << FunctionCallDepth();
  AbstractBasePtr abstract = nullptr;
  if (engine->enable_recursive_eval()) {
    abstract = LaunchRecursiveEval(engine, fg, context);
  } else {
    abstract = LaunchStackFrame(engine, fg, context);
  }
  PopAlwaysEvalFlag();

  MS_EXCEPTION_IF_NULL(abstract);
  MS_LOG(DEBUG) << "Analysis FuncGraph end, " << fg << "/" << fg->ToString()
                << ", evaluated abstract: " << abstract->ToString() << ", is stub: " << fg->stub();
  if (fg->stub()) {
    abstract = std::make_shared<AbstractUndetermined>();
  }
  MS_LOG(DEBUG) << GetInferThread() << "} //" << fg->ToString() << " = " << abstract->ToString();

  SyncFuncGraphSideEffectFlag(fg);

  trace::TraceGraphEvalLeave(context);
  // Decrease the func graph call depth.
  DecreaseFunctionCallDepth();
  MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
                << "), leave, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
  auto res = std::make_shared<EvalResult>(abstract, nullptr);
  return res;
}
561
BroadenArgs(const AbstractBasePtrList & args_abs_list,AbstractBasePtrList * broaded_args,bool broaden_scalar)562 void BroadenArgs(const AbstractBasePtrList &args_abs_list, AbstractBasePtrList *broaded_args, bool broaden_scalar) {
563 MS_EXCEPTION_IF_NULL(broaded_args);
564 (void)std::transform(
565 args_abs_list.begin(), args_abs_list.end(), std::back_inserter(*broaded_args),
566 [&broaden_scalar](const AbstractBasePtr &arg) -> AbstractBasePtr {
567 auto arg_sequence = arg->cast<AbstractSequencePtr>();
568 if (arg_sequence != nullptr && !arg_sequence->dynamic_len() && !arg->isa<AbstractSparseTensor>()) {
569 MS_LOG(DEBUG) << "set as arg of dyn len param, arg:" << arg->ToString();
570 auto dyn_len_arg = arg_sequence->BroadenToDynamicLenSequence();
571 return broaden_scalar ? AbstractBroaden(dyn_len_arg) : dyn_len_arg->Broaden();
572 }
573 if (arg->GetValueTrack() != kValueAny) {
574 return broaden_scalar ? AbstractBroaden(arg) : arg->Broaden();
575 }
576 return arg;
577 });
578 }
579
// Returns the arguments unchanged unless the func graph has FUNC_GRAPH_FLAG_IGNORE_VALUE; in that case
// returns the result of BroadenArgs. BroadenArgs is called with broaden_scalar enabled unless the graph
// also has FUNC_GRAPH_FLAG_VMAP_TRANSFORMED.
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_abs_list) const {
  MS_EXCEPTION_IF_NULL(func_graph_);
  if (!func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
    return args_abs_list;
  }
  const bool broaden_scalar = !func_graph_->has_flag(FUNC_GRAPH_FLAG_VMAP_TRANSFORMED);
  AbstractBasePtrList broadened_args;
  BroadenArgs(args_abs_list, &broadened_args, broaden_scalar);
  MS_LOG(DEBUG) << func_graph_->ToString() << ", original: " << mindspore::ToString(args_abs_list)
                << ", broadened: " << mindspore::ToString(broadened_args);
  return broadened_args;
}
592
// Broadens the arguments of func graphs whose argument values must be ignored.
//  - If the graph already has FUNC_GRAPH_FLAG_IGNORE_VALUE, the arguments are returned unchanged
//    (in Evaluator::Run, NormalizeArgs has already been applied before this call).
//  - Otherwise the engine's ignore-value flag is set for this thread and graph. A graph marked
//    kFuncGraphFlagUndetermined is flagged FUNC_GRAPH_FLAG_IGNORE_VALUE. If the graph now has that
//    flag, the arguments are normalized through NormalizeArgs.
// Throws if func_graph_ or, when it is needed, `engine` is null.
AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_abs_list,
                                                                const AnalysisEnginePtr &engine) {
  MS_EXCEPTION_IF_NULL(func_graph_);
  if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
    return args_abs_list;
  }
  // `engine` is dereferenced below; report a null engine instead of crashing.
  MS_EXCEPTION_IF_NULL(engine);
  // Set ignore flag for multithread eval.
  engine->SetIgnoreValueFlag(AnalysisSchedule::thread_id(), func_graph_.get());
  // Set ignore flag for recursive eval.
  if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
    func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
    MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag in recursive eval.";
  }
  if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
    auto normalized_args_abs_list = NormalizeArgs(args_abs_list);
    MS_LOG(DEBUG) << "Normalized args " << mindspore::ToString(normalized_args_abs_list);
    return normalized_args_abs_list;
  }
  return args_abs_list;
}
613
GetFuncGraph(AnalysisEnginePtr engine,const AbstractBasePtrList & args_abs_list)614 FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
615 auto iter = func_graph_cache_.find(args_abs_list);
616 FuncGraphPtr res;
617 if (iter == func_graph_cache_.end()) {
618 auto fg = func_graph();
619 MS_EXCEPTION_IF_NULL(fg);
620 FuncGraphPtr generated_graph = fg->GenerateFuncGraph(args_abs_list);
621 func_graph_cache_[args_abs_list] = generated_graph;
622 MS_LOG(DEBUG) << "Generate special instance of function graph: " << ToString()
623 << ", special function: " << generated_graph->ToString() << ", args: " << ArgsToString(args_abs_list);
624
625 MS_EXCEPTION_IF_NULL(engine);
626 MS_EXCEPTION_IF_NULL(engine->func_graph_manager());
627 engine->func_graph_manager()->AddFuncGraph(generated_graph);
628 if (engine->check_side_effect()) {
629 PresetCertainSideEffect(generated_graph);
630 }
631 res = generated_graph;
632 } else {
633 res = iter->second;
634 }
635
636 // For the top graph, if it is replaced by generated graph, update the top graph to the new one.
637 if (parse::Parser::GetTopFuncGraph() == func_graph()) {
638 if (res != func_graph()) {
639 parse::Parser::UpdateTopFuncGraph(res);
640 }
641 }
642 return res;
643 }
644
GetFuncGraph(AnalysisEnginePtr engine,const AbstractBasePtrList & args_abs_list)645 FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
646 auto iter = func_graph_cache_.find(args_abs_list);
647 if (iter != func_graph_cache_.end()) {
648 return iter->second;
649 }
650 MS_EXCEPTION_IF_NULL(meta_func_graph_);
651 (void)meta_func_graph_->GetChecker("check_infer_inputs").Execute(args_abs_list);
652
653 if (scope_ != nullptr) {
654 meta_func_graph_->set_scope_name(scope_->name());
655 }
656 if (this->bound_node() != nullptr) {
657 auto node_debug_info = bound_node()->debug_info();
658 TraceGuard trace_guard(std::make_shared<TraceGenMetaFuncGraph>(node_debug_info));
659 auto node_location = trace::GetSourceCodeDebugInfo(node_debug_info)->location();
660 if (node_location != nullptr) {
661 meta_func_graph_->set_node_expr_src(node_location->expr_src());
662 }
663 generated_func_graph_ = meta_func_graph_->GenerateFuncGraph(args_abs_list);
664 } else {
665 generated_func_graph_ = meta_func_graph_->GenerateFuncGraph(args_abs_list);
666 }
667
668 FuncGraphPtr cloned_func_graph;
669 NodeDebugInfoPtr debug_info;
670 if (this->bound_node() != nullptr) {
671 debug_info = this->bound_node()->debug_info();
672 }
673 if (meta_func_graph_->isa<expander::bprop::BpropMetaFuncGraph>()) {
674 cloned_func_graph = GetCloneBpropGraph(meta_func_graph_, generated_func_graph_, this->bound_node(), scope_);
675 } else {
676 cloned_func_graph = BasicClone(generated_func_graph_, false, std::make_shared<UpdateInfo>(scope_, debug_info));
677 }
678 func_graph_cache_[args_abs_list] = cloned_func_graph;
679 MS_EXCEPTION_IF_NULL(engine);
680 MS_EXCEPTION_IF_NULL(engine->func_graph_manager());
681 engine->func_graph_manager()->AddFuncGraph(cloned_func_graph);
682 if (engine->check_side_effect()) {
683 PresetCertainSideEffect(cloned_func_graph);
684 }
685 return cloned_func_graph;
686 }
687
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)688 EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
689 const AnfNodeConfigPtr &out_conf) {
690 AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
691 args_abs_list = NormalizeArgs(args_abs_list);
692 args_abs_list = BroadenUndeterminedArgs(args_abs_list, engine);
693 MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_abs_list, out_conf);
694 EvalResultPtr eval_result = nullptr;
695 const std::string &evaluator_name = ToString();
696 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
697 auto &cache = evaluator_cache_mgr_->GetCache();
698 auto iter = cache.find(args_abs_list);
699 if (iter == cache.end()) {
700 MS_LOG(DEBUG) << "[" << this << "/" << evaluator_name << "] cache miss, call Eval(), args: " << args_abs_list;
701 eval_result = Eval(engine, args_abs_list, out_conf);
702 MS_EXCEPTION_IF_NULL(eval_result);
703 if (eval_result->abstract() == nullptr) {
704 EvalFailLogging(shared_from_base<Evaluator>(), args_abs_list, out_conf);
705 MS_LOG(INTERNAL_EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
706 }
707 MS_LOG(DEBUG) << "[" << this << "/" << evaluator_name
708 << "] set cache. result: " << eval_result->abstract()->ToString()
709 << ", args_abs_list hash: " << AbstractBasePtrListHash(args_abs_list)
710 << ", args_abs_list: " << args_abs_list;
711 evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
712 } else {
713 eval_result = iter->second;
714 MS_EXCEPTION_IF_NULL(eval_result->abstract());
715 MS_LOG(DEBUG) << "[" << this << "/" << evaluator_name
716 << "] cache hit. result: " << eval_result->abstract()->ToString() << ", args: " << args_abs_list;
717 // Update inputs sequence nodes info, if matched in cache.
718 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
719 if (enable_eliminate_unused_element) {
720 for (size_t i = 0; i < args_abs_list.size(); ++i) {
721 auto new_sequence = dyn_cast<AbstractSequence>(args_abs_list[i]);
722 auto old_sequence = dyn_cast<AbstractSequence>(iter->first[i]);
723 if (old_sequence != nullptr && new_sequence != nullptr) {
724 MS_LOG(DEBUG) << "Before synchronize sequence nodes use flags for NodeConfig: "
725 << (out_conf ? out_conf->ToString() : "NULL") << "old_sequence: " << old_sequence->ToString()
726 << ", new_sequence: " << new_sequence->ToString();
727 SynchronizeSequenceElementsUseFlagsRecursively(old_sequence, new_sequence);
728 MS_LOG(DEBUG) << "After synchronize sequence nodes use flags for NodeConfig: "
729 << (out_conf ? out_conf->ToString() : "NULL") << ", old_sequence: " << old_sequence->ToString()
730 << ", new_sequence: " << new_sequence->ToString();
731 }
732 }
733 }
734 }
735 return eval_result;
736 }
737
EvalUndeterminedArgs(const AbstractBasePtrList & args_abs_list)738 EvalResultPtr Evaluator::EvalUndeterminedArgs(const AbstractBasePtrList &args_abs_list) {
739 auto is_undetermined = std::any_of(args_abs_list.begin(), args_abs_list.end(), [](auto &arg) -> bool {
740 return arg->IsSameTypeId(AbstractUndetermined::kTypeId);
741 });
742 if (is_undetermined) {
743 MS_LOG(DEBUG) << "Eval " << identifier_ << " return undetermined abstract result";
744 return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
745 }
746 return nullptr;
747 }
748
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)749 EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
750 const AnfNodeConfigPtr &) {
751 AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
752
753 EvalResultPtr res;
754 // If the arguments contain Any, return Any directly.
755 // Only check in TrivialPrimEvaluator, not in TransitionPrimEvaluator.
756 const auto standard_prim_eval = dyn_cast_ptr<StandardPrimEvaluator>(shared_from_this());
757 bool ignore_any_type_checking =
758 (standard_prim_eval != nullptr &&
759 ignore_any_type_checking_prims.find(standard_prim_eval->prim()) != ignore_any_type_checking_prims.end());
760 if (!ignore_any_type_checking && ContainsAbstractAny(args_abs_list)) {
761 MS_LOG(INFO) << ToString() << " receives arguments that contain Any.";
762 auto any_abstract = std::make_shared<AbstractAny>();
763 const auto &dtype = GetArgsUniqueDtype(args_abs_list);
764 if (dtype != nullptr) {
765 MS_EXCEPTION_IF_NULL(any_abstract->element());
766 any_abstract->element()->set_type(dtype);
767 any_abstract->set_supposed_tensor_dtype(true);
768 }
769 for (const auto &abs : args_abs_list) {
770 MS_EXCEPTION_IF_NULL(abs);
771 if (abs->isa<abstract::AbstractSequence>()) {
772 SetSequenceElementsUseFlagsRecursively(abs, true);
773 }
774 }
775 res = std::make_shared<EvalResult>(any_abstract, std::make_shared<AttrValueMap>());
776 } else {
777 try {
778 res = EvalPrim(engine, args_abs_list);
779 } catch (std::exception &e) {
780 MS_LOG(ERROR) << "Primitive: <" << ToString() << "> infer failed, failed info: " << e.what();
781 std::rethrow_exception(std::current_exception());
782 }
783 }
784 MS_EXCEPTION_IF_NULL(res);
785 // Update the input abstract for inplace primitive.
786 if (inplace_prim() && !args_abs_list.empty() && args_abs_list[0] != res->abstract()) {
787 MS_LOG(DEBUG) << "Set inplace abstract, " << args_abs_list[0]->ToString() << " -> " << res->abstract()->ToString();
788 // Always update the inplace abstract.
789 args_abs_list[0]->set_inplace_abstract(res->abstract());
790 }
791 return res;
792 }
793
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)794 EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
795 const AnfNodeConfigPtr &out_conf) {
796 if (args_conf_list.empty() && identifier_ != "MakeTupleEvaluator" && identifier_ != "MakeListEvaluator" &&
797 identifier_ != "RaiseEvaluator" && identifier_ != "ConstexprEvaluator") {
798 MS_LOG(INTERNAL_EXCEPTION) << "Size should be greater than 0, during running " << identifier_;
799 }
800 AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
801 EvalResultPtr res = EvalPrim(engine, args_abs_list, args_conf_list[0], out_conf);
802 MS_EXCEPTION_IF_NULL(res);
803 // Update the input abstract for inplace primitive.
804 if (inplace_prim() && !args_abs_list.empty() && args_abs_list[0] != res->abstract()) {
805 MS_LOG(DEBUG) << "Set inplace abstract, " << args_abs_list[0]->ToString() << " -> " << res->abstract()->ToString();
806 // Always update the inplace abstract.
807 args_abs_list[0]->set_inplace_abstract(res->abstract());
808 }
809 // No need to cache.
810 return res;
811 }
812
Run(AnalysisEnginePtr,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)813 EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list,
814 const AnfNodeConfigPtr &) {
815 return EvalPrim(args_conf_list);
816 }
817
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)818 EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
819 const AnfNodeConfigPtr &out_conf) {
820 AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
821 EvalResultPtr res = sub_evaluator_->Run(engine, args_conf_list, out_conf);
822 // Don't lookup from cache, as different out_conf with same node but different context
823 // may add different entry to anfnode_config_map_, like getattr primitive.
824 evaluator_cache_mgr_->SetValue(args_abs_list, res);
825 return res;
826 }
827
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)828 EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
829 const AnfNodeConfigPtr &out_conf) {
830 AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
831 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
832 auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
833 if (eval_result != nullptr) {
834 return eval_result;
835 }
836
837 ConfigPtrList partial_args_conf_list;
838 // Join arguments in partial and the rest arguments from args_conf_list.
839 (void)std::transform(args_abs_list_.begin(), args_abs_list_.end(), std::back_inserter(partial_args_conf_list),
840 [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
841
842 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(partial_args_conf_list),
843 [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
844 EvalResultPtr res = evaluator_->Run(engine, partial_args_conf_list, out_conf);
845 evaluator_cache_mgr_->SetValue(args_abs_list, res);
846 return res;
847 }
848
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)849 EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
850 const AnfNodeConfigPtr &out_conf) {
851 AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
852 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
853 auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
854 if (eval_result != nullptr) {
855 return eval_result;
856 }
857
858 // Call the original evaluator, get the result: y = f(x)
859 EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
860 MS_EXCEPTION_IF_NULL(result);
861 // If the primal func graph's output is sequence, set its elements use flags all true.
862 SetSequenceElementsUseFlagsRecursively(result->abstract(), true);
863 // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
864 // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
865 AbstractBasePtrList bparams;
866 bparams.push_back(SensitivityTransform(primal_func_));
867 // Check if primal func graph has the primitive returned sparse result in its bprop().
868 auto real_primal_func = dyn_cast_ptr<FuncGraphAbstractClosure>(primal_func_);
869 MS_EXCEPTION_IF_NULL(real_primal_func);
870 FuncGraphPtr primal_func_graph = real_primal_func->func_graph();
871 MS_EXCEPTION_IF_NULL(primal_func_graph);
872 bool has_sparse_bprop_prim = primal_func_graph->has_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP);
873 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(bparams),
874 [&has_sparse_bprop_prim](const AbstractBasePtr &arg_abs) -> AbstractBasePtr {
875 MS_EXCEPTION_IF_NULL(arg_abs);
876 if (has_sparse_bprop_prim && arg_abs->isa<AbstractTensor>()) {
877 return std::make_shared<AbstractUndetermined>();
878 }
879 return SensitivityTransform(arg_abs);
880 });
881 AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
882 AbstractFunctionPtr bprop;
883 MS_EXCEPTION_IF_NULL(out_conf);
884 auto current_node = out_conf->node();
885 MS_EXCEPTION_IF_NULL(current_node);
886 if (current_node->isa<CNode>()) {
887 auto current_cnode = current_node->cast<CNodePtr>();
888 auto effect_info = current_cnode->GetEffectInfo();
889 if (current_cnode->IsEffectHandled() && effect_info.back_mem) {
890 AbstractBasePtrList bprop_inputs{SensitivityTransform(result->abstract()), kUMonad->ToAbstract()};
891 bprop = std::make_shared<VirtualAbstractClosure>(bprop_inputs, bparams_final);
892 } else {
893 bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
894 }
895 } else {
896 bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
897 }
898
899 // J(f)(J(x)) return a tuple (y, bprop_f)
900 AbstractBasePtrList jargs = {result->abstract(), bprop};
901 AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
902 auto res = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
903 evaluator_cache_mgr_->SetValue(args_abs_list, res);
904 return res;
905 }
906
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)907 EvalResultPtr TaylorEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
908 const AnfNodeConfigPtr &) {
909 AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
910 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
911 auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
912 if (eval_result != nullptr) {
913 return eval_result;
914 }
915
916 // Call the original evaluator, get the result: y = f(x)
917 EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
918 MS_EXCEPTION_IF_NULL(result);
919 evaluator_cache_mgr_->SetValue(args_abs_list, result);
920 return result;
921 }
922
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)923 EvalResultPtr ShardEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
924 const AnfNodeConfigPtr &) {
925 AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
926 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
927 auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
928 if (eval_result != nullptr) {
929 return eval_result;
930 }
931
932 // Call the original evaluator, get the result: y = f(x)
933 EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
934 MS_EXCEPTION_IF_NULL(result);
935 auto res = std::make_shared<EvalResult>(result->abstract(), std::make_shared<AttrValueMap>());
936 evaluator_cache_mgr_->SetValue(args_abs_list, res);
937 return res;
938 }
939
940 namespace {
ReduceDim(int * axis,const AbstractBasePtr & orig_abs,int * axis_size)941 AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_size) {
942 MS_EXCEPTION_IF_NULL(axis);
943 MS_EXCEPTION_IF_NULL(orig_abs);
944 MS_EXCEPTION_IF_NULL(axis_size);
945 if (!orig_abs->isa<AbstractTensor>()) {
946 MS_LOG(EXCEPTION) << "The orig_abs should be AbstractTensor when corresponding axis is " << *axis << ", but got a "
947 << orig_abs->ToString() << ". Tip: Please check the correspondence between "
948 << "vmap's 'in_axes' and inputs. You may want to explicitly specify the 'in_axes' "
949 << "corresponding to " << orig_abs->ToString() << " as 'None' to solve this problem.";
950 }
951 auto orig_abs_shape = dyn_cast_ptr<Shape>(orig_abs->BuildShape());
952 MS_EXCEPTION_IF_NULL(orig_abs_shape);
953 ShapeVector orig_shape = orig_abs_shape->shape();
954 int shape_len = SizeToInt(orig_shape.size());
955 if (*axis < -shape_len || *axis >= shape_len) {
956 MS_LOG(EXCEPTION) << "The axis: " << *axis << " in 'in_axes' is out of bounds for array of dimension ["
957 << -shape_len << "," << shape_len << ").";
958 }
959 *axis = *axis < 0 ? shape_len + *axis : *axis;
960 auto temp_axes_size = orig_shape[IntToSize(*axis)];
961 if (*axis_size == -1) {
962 *axis_size = LongToInt(temp_axes_size);
963 } else if (*axis_size != temp_axes_size) {
964 MS_LOG(EXCEPTION) << "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got "
965 << *axis_size << " and " << temp_axes_size << ".";
966 }
967 (void)orig_shape.erase(orig_shape.begin() + *axis);
968 BaseShapePtr new_shape = std::make_shared<Shape>(orig_shape);
969 MS_EXCEPTION_IF_NULL(orig_abs->Clone());
970 AbstractBasePtr abs_clone = orig_abs->Clone()->Broaden();
971 abs_clone->set_shape(new_shape);
972 return abs_clone;
973 }
974
GetLogicalViewAbs(const AbstractBasePtr & physical_view_abs,const ValuePtr & in_axes,int * axis_size)975 AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, const ValuePtr &in_axes, int *axis_size) {
976 MS_EXCEPTION_IF_NULL(physical_view_abs);
977 MS_EXCEPTION_IF_NULL(in_axes);
978 auto physical_view_abs_sequence = dyn_cast_ptr<AbstractSequence>(physical_view_abs);
979 if (physical_view_abs_sequence != nullptr) {
980 AbstractBasePtrList abs_list = physical_view_abs_sequence->elements();
981 AbstractBasePtrList logical_view_abs_list;
982 auto in_axes_seq = dyn_cast_ptr<ValueSequeue>(in_axes);
983 int index = 0;
984 (void)std::transform(abs_list.begin(), abs_list.end(), std::back_inserter(logical_view_abs_list),
985 [&axis_size, &index, in_axes_seq, in_axes](const AbstractBasePtr &sub_abs) -> AbstractBasePtr {
986 ValuePtr sub_in_axes = in_axes;
987 if (in_axes->isa<ValueSequeue>()) {
988 sub_in_axes = (*in_axes_seq)[index];
989 index++;
990 }
991 return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size);
992 });
993 if (physical_view_abs->isa<AbstractList>()) {
994 return std::make_shared<AbstractList>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
995 }
996 return std::make_shared<AbstractTuple>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
997 }
998 ValuePtr in_axis = in_axes;
999 if (in_axis->isa<Int64Imm>()) {
1000 int axis = dyn_cast_ptr<Int64Imm>(in_axis)->value();
1001 auto logical_view_abs = ReduceDim(&axis, physical_view_abs, axis_size);
1002 return logical_view_abs;
1003 }
1004 if (!in_axis->isa<None>()) {
1005 MS_LOG(EXCEPTION) << "The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm, but got a "
1006 << in_axis->ToString() << ".";
1007 }
1008 // in_axis is None.
1009 return physical_view_abs;
1010 }
1011
ExtendDim(int * axis,const AbstractBasePtr & orig_abs,int axis_size)1012 AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_size) {
1013 MS_EXCEPTION_IF_NULL(orig_abs);
1014 MS_EXCEPTION_IF_NULL(axis);
1015 AbstractBasePtr out_abs = nullptr;
1016 ShapeVector orig_shape;
1017 if (orig_abs->isa<AbstractTensor>()) {
1018 auto shape = dyn_cast_ptr<Shape>(orig_abs->BuildShape());
1019 if (shape != nullptr) {
1020 orig_shape = shape->shape();
1021 }
1022 if (std::any_of(orig_shape.begin(), orig_shape.end(),
1023 [](ShapeValueDType s) { return s == Shape::kShapeRankAny; })) {
1024 return orig_abs;
1025 }
1026 }
1027 int shape_len = SizeToInt(orig_shape.size() + 1);
1028 if (*axis < -shape_len || *axis >= shape_len) {
1029 MS_LOG(EXCEPTION) << "The axis: " << *axis << " in 'out_axes' is out of bounds for array of dimension ["
1030 << -shape_len << "," << shape_len << ").";
1031 }
1032 *axis = *axis < 0 ? shape_len + *axis : *axis;
1033 (void)orig_shape.insert(orig_shape.begin() + *axis, axis_size);
1034 BaseShapePtr new_shape = std::make_shared<Shape>(orig_shape);
1035 if (orig_abs->isa<AbstractTensor>()) {
1036 auto tmp_abs = orig_abs->Clone();
1037 MS_EXCEPTION_IF_NULL(tmp_abs);
1038 out_abs = tmp_abs->Broaden();
1039 MS_EXCEPTION_IF_NULL(out_abs);
1040 out_abs->set_shape(new_shape);
1041 } else if (orig_abs->isa<AbstractScalar>()) {
1042 out_abs = std::make_shared<AbstractTensor>(orig_abs, new_shape);
1043 } else {
1044 MS_LOG(EXCEPTION) << "The outputs of vmap's 'fn' should be consisting of tensors or constants, but got "
1045 << orig_abs->ToString() << ".";
1046 }
1047 return out_abs;
1048 }
1049
GetPhysicalViewAbs(const AbstractBasePtr & logical_view_abs,const ValuePtr & out_axes,int axis_size)1050 AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, const ValuePtr &out_axes, int axis_size) {
1051 MS_EXCEPTION_IF_NULL(logical_view_abs);
1052 MS_EXCEPTION_IF_NULL(out_axes);
1053 auto logical_view_abs_sequence = dyn_cast_ptr<AbstractSequence>(logical_view_abs);
1054 if (logical_view_abs_sequence != nullptr) {
1055 AbstractBasePtrList logical_view_abs_list = logical_view_abs_sequence->elements();
1056 AbstractBasePtrList physical_view_abs_list;
1057 auto out_axes_seq = dyn_cast_ptr<ValueSequeue>(out_axes);
1058 if (out_axes_seq != nullptr) {
1059 if (logical_view_abs_list.size() != out_axes_seq->size()) {
1060 MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': "
1061 << logical_view_abs_list.size() << ", but got size: " << out_axes_seq->size() << ".";
1062 }
1063 }
1064 int index = 0;
1065 (void)std::transform(
1066 logical_view_abs_list.begin(), logical_view_abs_list.end(), std::back_inserter(physical_view_abs_list),
1067 [&axis_size, &index, out_axes_seq, out_axes](const AbstractBasePtr &arg_abs) -> AbstractBasePtr {
1068 ValuePtr sub_out_axes = out_axes;
1069 if (out_axes->isa<ValueSequeue>()) {
1070 sub_out_axes = (*out_axes_seq)[index];
1071 index++;
1072 }
1073 if (arg_abs->isa<AbstractSequence>()) {
1074 return GetPhysicalViewAbs(arg_abs, sub_out_axes, axis_size);
1075 }
1076 if (sub_out_axes->isa<Int64Imm>()) {
1077 int axis = static_cast<int>(dyn_cast_ptr<Int64Imm>(sub_out_axes)->value());
1078 return ExtendDim(&axis, arg_abs, axis_size);
1079 } else if (sub_out_axes->isa<None>()) {
1080 return arg_abs;
1081 }
1082 MS_LOG(EXCEPTION) << "The axis in vmap's 'out_axes' should be a None or a scalar of type Int64Imm, but got a "
1083 << sub_out_axes->ToString() << ".";
1084 });
1085 if (logical_view_abs->isa<AbstractList>()) {
1086 return std::make_shared<AbstractList>(physical_view_abs_list);
1087 }
1088 return std::make_shared<AbstractTuple>(physical_view_abs_list);
1089 }
1090
1091 // for the single output case, outputs: A, and out_axes: 1 or (1,).
1092 ValuePtr sub_out_axes = out_axes;
1093 ValueSequeuePtr out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
1094 if (out_axes_seq != nullptr) {
1095 if (out_axes_seq->size() != 1) {
1096 MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the result size: 1, but got size: "
1097 << out_axes_seq->size() << ".";
1098 }
1099 sub_out_axes = (*out_axes_seq)[0];
1100 }
1101
1102 int axis = 0;
1103 auto axis_int_ptr = dyn_cast_ptr<Int64Imm>(sub_out_axes);
1104 if (axis_int_ptr != nullptr) {
1105 axis = LongToInt(axis_int_ptr->value());
1106 } else {
1107 MS_LOG(EXCEPTION) << "The axis in vmap's 'out_axes' should be a None or a scalar of type Int64Imm, but got a "
1108 << sub_out_axes->ToString() << ".";
1109 }
1110 return ExtendDim(&axis, logical_view_abs, axis_size);
1111 }
1112 } // namespace
1113
1114 // According to the in_axes (e.g. (1,(None,3))), the abstraction of input parameters with the
1115 // physical view (e.g. (A,(B,C))) are converted into that with the logical view (e.g.(a,(b,c))),
1116 // more specific, the input `A` with shape (32, 16, 8) fitting the axis index `1` is converted in to
1117 // `a` with shape (32, 8). And then leverage the original graph to perform the evaluation.
1118 // Finally, the outputs with the logical view are converted back into the physical view in
1119 // combination with the out_axes. The inferring result is consistent with that after eliminating
1120 // the VmapOperator.
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)1121 EvalResultPtr VmapEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
1122 const AnfNodeConfigPtr &) {
1123 AbstractBasePtrList args_abs_list;
1124 int axis_size = -1;
1125 int index = 0;
1126 auto in_axes = in_axes_;
1127 auto in_axes_seq = dyn_cast_ptr<ValueSequeue>(in_axes);
1128 (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
1129 [&axis_size, &index, in_axes_seq, in_axes](const ConfigPtr &conf) -> AbstractBasePtr {
1130 MS_EXCEPTION_IF_NULL(conf);
1131 AbstractBasePtr abs = conf->ObtainEvalResult()->abstract();
1132 MS_EXCEPTION_IF_NULL(abs);
1133 // Drop the side effect tag parameters, because it has no mapping axis.
1134 // e.g. args=(A,(B,C),U), in_axes=(1,(None,3))
1135 if (abs->isa<AbstractMonad>()) {
1136 return abs;
1137 }
1138 ValuePtr sub_in_axes = in_axes;
1139 MS_EXCEPTION_IF_NULL(in_axes);
1140 if (in_axes->isa<ValueSequeue>()) {
1141 sub_in_axes = (*in_axes_seq)[index];
1142 index++;
1143 }
1144 auto arg_abs = GetLogicalViewAbs(abs, sub_in_axes, &axis_size);
1145 return arg_abs;
1146 });
1147 MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
1148 auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
1149 if (eval_result != nullptr) {
1150 return eval_result;
1151 }
1152 ConfigPtrList virtual_conf_list;
1153 (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(virtual_conf_list),
1154 [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
1155
1156 // Call the original evaluator, get the result: y = f(x)
1157 EvalResultPtr result = evaluator_->Run(engine, virtual_conf_list, nullptr);
1158 MS_EXCEPTION_IF_NULL(result);
1159
1160 // If the primal func graph's output is sequence, set its elements use flags all true.
1161 SetSequenceElementsUseFlagsRecursively(result->abstract(), true);
1162
1163 if (axis_size == -1 && cell_size_ != 0) {
1164 axis_size = SizeToInt(cell_size_);
1165 } else if (axis_size != -1 && cell_size_ != 0 && axis_size != SizeToInt(cell_size_)) {
1166 MS_EXCEPTION(ValueError) << "If you want to execute the model ensembling parallel training, please make sure "
1167 << "the 'axis_size' in the scope of vmap consistent with the cell size of the input "
1168 << "'CellList', otherwise, please do not enter 'CellList' as the first argument, "
1169 << "but we get axis_size: " << axis_size << " and the cell size: " << cell_size_ << ".";
1170 }
1171
1172 AbstractBasePtr result_abs = result->abstract();
1173 AbstractBasePtr after_vmap = GetPhysicalViewAbs(result_abs, out_axes_, axis_size);
1174
1175 auto res = std::make_shared<EvalResult>(after_vmap, std::make_shared<AttrValueMap>());
1176 evaluator_cache_mgr_->SetValue(args_abs_list, res);
1177 return res;
1178 }
1179
Eval(AnalysisEnginePtr,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)1180 EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_abs_list,
1181 const AnfNodeConfigPtr &out_conf) {
1182 if (args_abs_list.size() != args_abs_list_.size()) {
1183 MS_LOG(INTERNAL_EXCEPTION) << "Arguments mismatch, parameters no: " << args_abs_list_.size()
1184 << ", arguments no: " << args_abs_list.size();
1185 }
1186 const auto sense_param_index = args_abs_list.size() - 1;
1187 bool sense_param_flag = false;
1188 MS_EXCEPTION_IF_NULL(this->bound_node());
1189 if (this->bound_node()->isa<CNode>()) {
1190 sense_param_flag = this->bound_node()->cast<CNodePtr>()->HasAttr("sens_param_");
1191 }
1192 static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1193 // Check each parameter and argument match;
1194 for (std::size_t i = 0; i < args_abs_list.size(); i++) {
1195 MS_EXCEPTION_IF_NULL(args_abs_list[i]);
1196 // For VirtualAbstractClosure, likely J's bprop, we just set its tuple arguments as used before really grad.
1197 if (enable_eliminate_unused_element && args_abs_list[i]->isa<AbstractSequence>()) {
1198 MS_LOG(INFO) << "Notice: For VirtualAbstractClosure, update all use flags as true for arguments[" << i
1199 << "]: " << args_abs_list[i]->ToString();
1200 SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
1201 }
1202 if (i == sense_param_index && sense_param_flag) {
1203 const auto &sense_shape = args_abs_list[i]->BuildShape();
1204 MS_EXCEPTION_IF_NULL(sense_shape);
1205 if (sense_shape->IsDynamic()) {
1206 MS_EXCEPTION(ValueError) << "The shape of sense must not be dynamic shape."
1207 << "\nFor more details with 'sense', please refer to "
1208 << "https://www.mindspore.cn/docs/zh-CN/master/faq/network_compilation.html.";
1209 }
1210 }
1211 (void)args_abs_list[i]->Join(args_abs_list_[i]);
1212 }
1213 return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
1214 }
1215
SingleRun(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)1216 EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
1217 const AnfNodeConfigPtr &out_conf) {
1218 EvalResultPtr result;
1219 try {
1220 result = this->Run(engine, args_conf_list, out_conf);
1221 } catch (const std::exception &ex) {
1222 MS_LOG(INFO) << "Eval " << ToString() << " throw exception.";
1223 AnalysisSchedule::GetInstance().HandleException(ex);
1224 }
1225 AnalysisSchedule::GetInstance().Wait();
1226 return result;
1227 }
1228 } // namespace abstract
1229 } // namespace mindspore
1230