• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/ps/static_analysis/evaluator.h"
18 
19 #include <algorithm>
20 #include <utility>
21 
22 #include "mindspore/core/ops/sequence_ops.h"
23 #include "mindspore/core/ops/framework_ops.h"
24 #include "mindspore/core/ops/structure_ops.h"
25 #include "utils/hash_set.h"
26 #include "ir/func_graph_cloner.h"
27 #include "abstract/utils.h"
28 #include "pipeline/jit/ps/debug/trace.h"
29 #include "utils/ms_context.h"
30 #include "utils/compile_config.h"
31 #include "pipeline/jit/ps/static_analysis/stack_frame.h"
32 #include "pipeline/jit/ps/static_analysis/async_eval_result.h"
33 #include "frontend/expander/bprop/bprop_meta_func_graph.h"
34 #include "frontend/operator/composite/unpack_call.h"
35 #include "frontend/optimizer/ad/dfunctor.h"
36 
37 namespace mindspore {
38 namespace abstract {
39 namespace {
EvalEntryLogging(const EvaluatorPtr & evaluator,const AbstractBasePtrList & arg_abs_list,const AnfNodeConfigPtr & out_conf)40 string EvalEntryLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &arg_abs_list,
41                         const AnfNodeConfigPtr &out_conf) {
42   MS_EXCEPTION_IF_NULL(evaluator);
43   std::stringstream ss;
44   if (out_conf != nullptr) {
45     MS_EXCEPTION_IF_NULL(out_conf->node());
46     MS_EXCEPTION_IF_NULL(out_conf->node()->scope());
47     ss << "Evaluator " << evaluator->ToString() << " run for " << out_conf->node()->scope()->name();
48   }
49   for (size_t i = 0; i < arg_abs_list.size(); i++) {
50     ss << evaluator->ToString() << " input[" << i
51        << "] abstract value: " << (arg_abs_list[i] ? arg_abs_list[i]->ToString() : "null abstract.");
52   }
53   return ss.str();
54 }
55 
EvalFailLogging(const EvaluatorPtr & evaluator,const AbstractBasePtrList &,const AnfNodeConfigPtr & out_conf)56 void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &, const AnfNodeConfigPtr &out_conf) {
57   MS_EXCEPTION_IF_NULL(evaluator);
58   if (out_conf != nullptr) {
59     auto node = out_conf->node();
60     MS_EXCEPTION_IF_NULL(node);
61     if (IsValueNode<Primitive>(node)) {
62       MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->fullname_with_scope()
63                     << ", with debug info: " << trace::GetDebugInfoStr(node->debug_info());
64     } else {
65       MS_LOG(ERROR) << "Evaluator " << evaluator->ToString() << " run failed for node " << node->DebugString()
66                     << ", with debug info: " << trace::GetDebugInfoStr(node->debug_info());
67     }
68   }
69 }
70 
ContainsAbstractAnyInner(const AbstractBasePtr & abs)71 bool ContainsAbstractAnyInner(const AbstractBasePtr &abs) {
72   MS_EXCEPTION_IF_NULL(abs);
73   if (abs->isa<AbstractSequence>()) {
74     auto abs_list = abs->cast<AbstractSequencePtr>();
75     const auto &elements = abs_list->elements();
76     return std::any_of(elements.begin(), elements.end(), [](const AbstractBasePtr &e) {
77       MS_EXCEPTION_IF_NULL(e);
78       return ContainsAbstractAnyInner(e);
79     });
80   }
81   return abs->isa<AbstractAny>();
82 }
83 
GetArgsUniqueDtype(const AbstractBasePtrList & args_abs_list)84 TypePtr GetArgsUniqueDtype(const AbstractBasePtrList &args_abs_list) {
85   TypePtr res = nullptr;
86   for (const auto &arg : args_abs_list) {
87     MS_EXCEPTION_IF_NULL(arg);
88     if (!arg->isa<AbstractTensor>()) {
89       continue;
90     }
91     // Check default dtype if it's AbstractAny(AbstractTensor)
92     if (arg->isa<abstract::AbstractAny>()) {
93       auto any_arg = arg->cast_ptr<abstract::AbstractAny>();
94       MS_EXCEPTION_IF_NULL(any_arg);
95       if (!any_arg->supposed_tensor_dtype()) {
96         continue;
97       }
98     }
99     // Fetch the dtype from item of tensor.
100     auto tensor_abs = arg->cast_ptr<AbstractTensor>();
101     MS_EXCEPTION_IF_NULL(tensor_abs);
102     MS_EXCEPTION_IF_NULL(tensor_abs->element());
103     const auto dtype = tensor_abs->element()->BuildType();
104     MS_EXCEPTION_IF_NULL(dtype);
105     if (res == nullptr) {
106       res = dtype;
107       continue;
108     }
109     if (dtype != res) {
110       return nullptr;
111     }
112   }
113   return res;
114 }
115 
// Clones `generated_func_graph` for the BpropMetaFuncGraph `meta_func_graph`.
// During cloning, the bound CNode's primal attributes and primal debug infos are installed
// through PrimalAttrGuard / PrimalDebugInfoGuard (presumably so the cloned nodes pick them up
// — confirm against the guard implementations).
// The clone's update info is built from `scope` and the bound CNode's debug info.
// Raises an internal exception if `bound_node` is not a CNode.
FuncGraphPtr GetCloneBpropGraph(const MetaFuncGraphPtr &meta_func_graph, const FuncGraphPtr &generated_func_graph,
                                const AnfNodePtr &bound_node, const ScopePtr &scope) {
  MS_EXCEPTION_IF_NULL(meta_func_graph);
  auto bound_cnode = dyn_cast_ptr<CNode>(bound_node);
  if (bound_cnode == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "For BpropMetaFuncGraph '" << meta_func_graph->ToString()
                               << "', the evaluator should have the bound cnode.";
  }
  PrimalAttrGuard primal_attr_guard(bound_cnode->primal_attrs());
  const auto &source_debug_infos = bound_cnode->primal_debug_infos();
  std::vector<NodeDebugInfoPtr> debug_info_copies(source_debug_infos.begin(), source_debug_infos.end());
  PrimalDebugInfoGuard primal_debug_info_guard(debug_info_copies);
  // Both guards stay alive until the clone below has finished.
  auto update_info = std::make_shared<UpdateInfo>(scope, bound_cnode->debug_info());
  return BasicClone(generated_func_graph, false, update_info);
}
133 
IsSideEffectCNode(const AnfNodePtr & node)134 bool IsSideEffectCNode(const AnfNodePtr &node) {
135   MS_EXCEPTION_IF_NULL(node);
136   const auto &primitive = GetCNodePrimitiveWithoutDoSignature(node);
137   if (primitive != nullptr) {
138     auto effect_info = GetPrimEffectInfo(primitive);
139     if (effect_info.memory || effect_info.io) {
140       MS_LOG(DEBUG) << "Side Effect Primitive CNode: " << node->DebugString();
141       return true;
142     }
143   } else if (node->isa<CNode>()) {
144     // Call side effect node.
145     auto first_node = node->cast<CNodePtr>()->input(0);
146     if (first_node->isa<CNode>() && IsSideEffectCNode(first_node)) {
147       return true;
148     }
149   }
150   return false;
151 }
152 
153 bool HasIsolatedSideEffectNode(const FuncGraphPtr &func_graph);
154 
CheckSideEffect(const AnfNodePtr & input)155 bool CheckSideEffect(const AnfNodePtr &input) {
156   if (IsSideEffectCNode(input)) {
157     MS_LOG(DEBUG) << "Multiple side-effect node: " << input->DebugString();
158     return true;
159   }
160   // Process {Depend -> StopGradient -> MakeTuple(call function, ...)}.
161   if (input->isa<CNode>()) {
162     auto fn_input = input->cast<CNodePtr>()->input(0);
163     if (IsValueNode<prim::UnpackCall>(fn_input)) {
164       fn_input = input->cast<CNodePtr>()->input(1);
165     }
166     if (IsValueNode<FuncGraph>(fn_input)) {
167       auto func = GetValueNode<FuncGraphPtr>(fn_input);
168       if (IsSideEffectCNode(func->output()) || HasIsolatedSideEffectNode(func)) {
169         MS_LOG(DEBUG) << "Single nested side-effect node: " << input->DebugString();
170         return true;
171       }
172     }
173   }
174   return false;
175 }
176 
HasIsolatedSideEffectNode(const FuncGraphPtr & func_graph)177 bool HasIsolatedSideEffectNode(const FuncGraphPtr &func_graph) {
178   MS_EXCEPTION_IF_NULL(func_graph);
179   const auto node = func_graph->output();
180   if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
181     return false;
182   }
183   auto cnode = dyn_cast<CNode>(node);
184   MS_EXCEPTION_IF_NULL(cnode);
185   auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
186   auto sort_rhs_first =
187     attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
188   if (!sort_rhs_first) {
189     // Return false if it's definitely not side-effect Depend CNode.
190     return false;
191   }
192 
193   // To check side-effect nodes in {Depend -> StopGradient -> MakeTuple(...)}.
194   constexpr size_t stop_gradient_pos = 2;
195   auto stop_gradient_node = cnode->input(stop_gradient_pos);
196   auto stop_gradient_cnode = dyn_cast<CNode>(stop_gradient_node);
197   MS_EXCEPTION_IF_NULL(stop_gradient_cnode);
198   constexpr size_t isolated_node_pos = 1;
199   auto isolated_node = stop_gradient_cnode->input(isolated_node_pos);
200   MS_EXCEPTION_IF_NULL(isolated_node);
201   if (CheckSideEffect(isolated_node)) {
202     return true;
203   }
204   if (IsPrimitiveCNode(isolated_node, prim::kPrimMakeTuple)) {
205     auto isolated_cnode = dyn_cast<CNode>(isolated_node);
206     MS_EXCEPTION_IF_NULL(isolated_cnode);
207     for (size_t i = 1; i < isolated_cnode->size(); ++i) {
208       auto input = isolated_cnode->input(i);
209       if (CheckSideEffect(input)) {
210         return true;
211       }
212     }
213   }
214   return false;
215 }
216 
217 // Mark the side effect at output and func graph for later constant folding.
PresetCertainSideEffect(const FuncGraphPtr & func_graph)218 void PresetCertainSideEffect(const FuncGraphPtr &func_graph) {
219   MS_EXCEPTION_IF_NULL(func_graph);
220   if (!HasIsolatedSideEffectNode(func_graph)) {
221     return;
222   }
223 
224   auto new_return = func_graph->get_return();
225   new_return->set_has_side_effect_node(true);
226   func_graph->set_has_side_effect_node(true);
227   auto output_cnode = dyn_cast<CNode>(func_graph->output());
228   if (output_cnode != nullptr) {
229     output_cnode->set_has_side_effect_node(true);
230   }
231   MS_LOG(DEBUG) << "Set isolated side-effect node flag for " << func_graph->ToString();
232 }
233 }  // namespace
234 
ContainsAbstractAny(const AbstractBasePtrList & args_abs_list)235 bool ContainsAbstractAny(const AbstractBasePtrList &args_abs_list) {
236   return std::any_of(args_abs_list.cbegin(), args_abs_list.cend(), [](const AbstractBasePtr &item) {
237     MS_EXCEPTION_IF_NULL(item);
238     return ContainsAbstractAnyInner(item);
239   });
240 }
241 
242 // MakeTuple and MakeList will handle AbstractAny in ops infer.
243 const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> ignore_any_type_checking_prims{
244   prim::kPrimReturn,         prim::kPrimDepend,       prim::kPrimSwitch,      prim::kPrimSwitchLayer,
245   prim::kPrimUpdateState,    prim::kPrimLoad,         prim::kPrimIsConstant,  prim::kPrimMakeKeywordArg,
246   prim::kPrimIsShapeUnknown, prim::kPrimIsDimUnknown, prim::kPrimListGetItem, prim::kPrimTupleGetItem,
247   prim::kPrimSequenceLen,    prim::kPrimMakeDict,     prim::kPrimMutable};
248 
EvaluateArguments(const ConfigPtrList & args_conf_list)249 AbstractBasePtrList EvaluateArguments(const ConfigPtrList &args_conf_list) {
250   AbstractBasePtrList args_abs_list;
251   args_abs_list.reserve(args_conf_list.size());
252   for (auto &config : args_conf_list) {
253     MS_EXCEPTION_IF_NULL(config);
254     auto result = config->ObtainEvalResult();
255     MS_EXCEPTION_IF_NULL(result);
256     const auto &abs = result->abstract();
257     // Check if there's an inplace abstract and use it.
258     AbstractBasePtr real_abs;
259     MS_EXCEPTION_IF_NULL(abs);
260     if (abs->inplace_abstract() == nullptr) {
261       real_abs = abs;
262     } else {
263       real_abs = abs->inplace_abstract();
264       MS_LOG(INFO) << "Use inplace abstract, " << abs->ToString() << " -> " << real_abs->ToString();
265     }
266     (void)args_abs_list.emplace_back(real_abs);
267   }
268   return args_abs_list;
269 }
270 
CheckIfAlwaysEval(const AnfNodeConfigPtr & conf,const AbstractBasePtr & arg)271 bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
272   MS_EXCEPTION_IF_NULL(arg);
273   auto new_sequence = dyn_cast_ptr<AbstractSequence>(arg);
274   if (new_sequence != nullptr && !new_sequence->dynamic_len() && new_sequence->sequence_nodes() != nullptr &&
275       new_sequence->size() != 0) {
276     const auto &prev_result = ObtainEvalResultFromCache(conf);
277     if (prev_result == nullptr) {
278       return false;
279     }
280     auto prev_abs = prev_result->abstract();
281     auto old_sequence = dyn_cast_ptr<AbstractSequence>(prev_abs);
282     if (old_sequence != nullptr &&
283         (old_sequence->sequence_nodes() == nullptr || old_sequence->sequence_nodes()->empty()) && *arg == *prev_abs) {
284       MS_LOG(DEBUG) << "Always eval";
285       return true;
286     }
287   }
288   return false;
289 }
290 
EnterStackFrame(const AnalysisEnginePtr & engine,const StackFramePtr & current_stack_frame,const StackFramePtr & new_stack_frame)291 void BaseFuncGraphEvaluator::EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr &current_stack_frame,
292                                              const StackFramePtr &new_stack_frame) {
293   MS_EXCEPTION_IF_NULL(current_stack_frame);
294   MS_EXCEPTION_IF_NULL(new_stack_frame);
295   MS_EXCEPTION_IF_NULL(engine);
296   // Enter new func graph.
297   auto &current_node = current_stack_frame->CurrentNode();
298   auto current_context = current_stack_frame->current_context();
299   AnfNodeConfigPtr call_conf = engine->MakeConfig(current_node, current_context, current_context->func_graph());
300   auto evaluator = new_stack_frame->evaluator();
301   MS_EXCEPTION_IF_NULL(evaluator);
302   auto new_context = new_stack_frame->current_context();
303   trace::TraceGraphEvalEnter(new_context, call_conf);
304 
305   // Increase & Check the func graph call depth.
306   // Don't check it if the user set no_recursive flag.
307   IncreaseFunctionCallDepth();
308   IncreaseStackFrameDepth();
309   const auto &top_graph = parse::Parser::GetTopFuncGraph();
310   bool no_recursive = (top_graph == nullptr ? false : top_graph->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE));
311   const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
312   if (!no_recursive && FunctionCallDepth() > max_depth) {
313     MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
314                       << ", (function call depth: " << FunctionCallDepth()
315                       << ", simulate call depth: " << StackFrameDepth() << ").\n"
316                       << "It's always happened with complex construction of code or infinite recursion or loop.\n"
317                       << "Please check the code if it's has the infinite recursion "
318                       << "or call 'context.set_context(max_call_depth=value)' to adjust this value.\n"
319                       << "If max_call_depth is set larger, the system max stack depth should be set larger too "
320                       << "to avoid stack overflow.\n"
321                       << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
322   }
323   MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
324                 << "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
325 }
326 
LeaveStackFrame(const AnalysisEnginePtr &,const StackFramePtr & current_stack_frame)327 void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &, const StackFramePtr &current_stack_frame) {
328   MS_EXCEPTION_IF_NULL(current_stack_frame);
329   // Leave current func graph.
330   auto current_context = current_stack_frame->current_context();
331   trace::TraceGraphEvalLeave(current_context);
332 
333   // Decrease the func graph call depth.
334   DecreaseFunctionCallDepth();
335   DecreaseStackFrameDepth();
336 
337   auto evaluator = current_stack_frame->evaluator();
338   MS_EXCEPTION_IF_NULL(evaluator);
339   MS_LOG(DEBUG) << evaluator << "(" << evaluator->type_name() << "/" << evaluator->ToString()
340                 << "), leave, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
341 }
342 
343 // Start running stack frames in a Evaluator.
LaunchStackFrame(const AnalysisEnginePtr & engine,const FuncGraphPtr & fg,const AnalysisContextPtr & context)344 AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
345                                                          const AnalysisContextPtr &context) {
346   EvalResultPtr eval_result = nullptr;
347   AbstractBasePtr abstract = nullptr;
348   std::stack<StackFramePtr> stack_frames;
349   auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context, parent_context_);
350   MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame;
351   stack_frames.push(current_stack_frame);
352   while (true) {
353     current_stack_frame = stack_frames.top();
354     MS_EXCEPTION_IF_NULL(current_stack_frame);
355     if (current_stack_frame->Done()) {
356       MS_EXCEPTION_IF_NULL(abstract);
357       MS_EXCEPTION_IF_NULL(current_stack_frame->func_graph());
358       if (current_stack_frame->func_graph()->has_flag(FUNC_GRAPH_FLAG_PRIMAL_OF_BPROP)) {
359         // Set all fprop outputs as used.
360         SetSequenceElementsUseFlagsRecursively(abstract, true);
361       }
362       MS_LOG(DEBUG) << "[" << this << "/StackFrame] Leave from func graph, " << current_stack_frame;
363       stack_frames.pop();
364       if (stack_frames.empty()) {
365         MS_LOG(DEBUG) << "[" << this << "/StackFrame] Finish at func graph, " << current_stack_frame
366                       << ", abstract: " << abstract->ToString();
367         break;
368       }
369       // Leave current func graph.
370       LeaveStackFrame(engine, current_stack_frame);
371       // Switch the stack frame.
372       auto last_stack_frame = current_stack_frame;
373       current_stack_frame = stack_frames.top();
374       MS_LOG(DEBUG) << "[" << this << "/StackFrame] Back to func graph, " << current_stack_frame;
375       current_stack_frame->Back(engine, last_stack_frame, eval_result);
376       continue;
377     }
378 
379     auto new_stack_frame = current_stack_frame->Jump(engine);
380     if (new_stack_frame != nullptr) {
381       // Enter new func graph.
382       EnterStackFrame(engine, current_stack_frame, new_stack_frame);
383       // Update current stack frame.
384       stack_frames.push(new_stack_frame);
385       MS_LOG(DEBUG) << "[" << this << "/StackFrame] Jump to new func graph, " << new_stack_frame;
386       continue;
387     }
388 
389     eval_result = current_stack_frame->Step(engine);
390     MS_EXCEPTION_IF_NULL(eval_result);
391     abstract = eval_result->abstract();
392   }
393   return abstract;
394 }
395 
LaunchRecursiveEval(const AnalysisEnginePtr & engine,const FuncGraphPtr & fg,const AnalysisContextPtr & context) const396 AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg,
397                                                             const AnalysisContextPtr &context) const {
398   MS_EXCEPTION_IF_NULL(fg);
399   MS_EXCEPTION_IF_NULL(engine);
400   const AnfNodePtr &func_node = fg->get_return();
401   const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType {
402     MS_EXCEPTION_IF_NULL(node);
403     static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
404     if (node->isa<ValueNode>() || node->isa<Parameter>() ||
405         (enable_pre_lift && IsPrimitiveCNode(node, prim::kPrimPartial))) {
406       return EXCLUDE;
407     }
408     return FOLLOW;
409   });
410   AbstractBasePtr abstract = nullptr;
411   for (const auto &node : all_nodes) {
412     MS_EXCEPTION_IF_NULL(node);
413     AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context, fg);
414     MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString()
415                   << ", node: " << node->DebugString() << ", node_conf: " << node_conf->ToString();
416     EvalResultPtr node_eval_result = nullptr;
417     if (always_eval_flag()) {
418       MS_LOG(DEBUG) << "Always eval node";
419       node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf);
420     } else {
421       node_eval_result = ObtainEvalResultFromCache(node_conf);
422       if (node_eval_result != nullptr) {
423         static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
424         if (enable_eliminate_unused_element) {
425           const auto &cnode = node->cast<CNodePtr>();
426           MS_EXCEPTION_IF_NULL(cnode);
427           const auto &maybe_func = engine->GetCNodeOperatorAbstract(cnode, context, fg);
428           if (maybe_func->isa<MetaFuncGraphAbstractClosure>() || maybe_func->isa<FuncGraphAbstractClosure>()) {
429             const auto &abs_func_graph = maybe_func->cast<AbstractFunctionPtr>();
430             SynchronizeSequenceElementsUseFlagsForFuncGraphArgs(engine, fg, cnode, abs_func_graph, context);
431           }
432         }
433         if (engine->check_side_effect() && node_eval_result->has_side_effect_node()) {
434           auto cnode = dyn_cast_ptr<CNode>(node);
435           MS_EXCEPTION_IF_NULL(cnode);
436           MS_LOG(DEBUG) << "Found side-effect, cnode: " << cnode->DebugString() << ", func_graph: " << fg->ToString();
437           cnode->set_has_side_effect_node(true);
438           fg->set_has_side_effect_node(true);
439         }
440         MS_LOG(DEBUG) << "No need to jump as found result from cache for node_config";
441       } else {
442         node_eval_result = engine->ObtainEvalResultWithoutCache(node_conf);
443       }
444     }
445     MS_EXCEPTION_IF_NULL(node_eval_result);
446     abstract = node_eval_result->abstract();
447     MS_EXCEPTION_IF_NULL(abstract);
448     MS_LOG(DEBUG) << GetInferThread() << "Eval ( " << node_conf->ToString() << ") = " << abstract->ToString();
449   }
450   MS_EXCEPTION_IF_NULL(abstract);
451   if (fg->has_flag(FUNC_GRAPH_FLAG_PRIMAL_OF_BPROP)) {
452     // Set all fprop outputs as used.
453     SetSequenceElementsUseFlagsRecursively(abstract, true);
454   }
455   return abstract;
456 }
457 
Eval(AnalysisEnginePtr engine,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)458 EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list,
459                                            const AnfNodeConfigPtr &out_conf) {
460   auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
461   if (eval_result != nullptr) {
462     MS_LOG(ERROR) << ToString() << ArgsToString(args_abs_list) << " entered again. There is something wrong.";
463     return eval_result;
464   }
465   MS_LOG(DEBUG) << ToString() << " entered first.";
466   MS_EXCEPTION_IF_NULL(engine);
467   // Increase & Check the func graph call depth.
468   // Don't check it if the user set no_recursive flag.
469   IncreaseFunctionCallDepth();
470   const auto &top_graph = parse::Parser::GetTopFuncGraph();
471   bool no_recursive = (top_graph == nullptr ? false : top_graph->has_flag(FUNC_GRAPH_FLAG_NO_RECURSIVE));
472   const uint32_t max_depth = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH);
473   if (!no_recursive && FunctionCallDepth() > max_depth) {
474     MS_LOG(EXCEPTION) << "Exceed function call depth limit " << max_depth
475                       << ", (function call depth: " << FunctionCallDepth()
476                       << ", simulate call depth: " << StackFrameDepth() << ").\n"
477                       << "It's always happened with complex construction of code or infinite recursion or loop.\n"
478                       << "Please check the code if it's has the infinite recursion "
479                       << "or call 'context.set_context(max_call_depth=value)' to adjust this value.\n"
480                       << "If max_call_depth is set larger, the system max stack depth should be set larger too "
481                       << "to avoid stack overflow.\n"
482                       << "For more details, please refer to the FAQ at https://www.mindspore.cn.";
483   }
484   MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
485                 << "), enter, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
486 
487   FuncGraphPtr fg = GetFuncGraph(engine, args_abs_list);
488   MS_EXCEPTION_IF_NULL(fg);
489   MS_EXCEPTION_IF_NULL(parent_context_);
490   auto context = NewContext(parent_context_, fg, args_abs_list);
491   trace::TraceGraphEvalEnter(context, out_conf);
492 
493   std::size_t nargs = fg->parameters().size();
494   if (args_abs_list.size() != nargs) {
495     MS_EXCEPTION(TypeError) << "The parameters number of the function is " << fg->parameters().size()
496                             << ", but the number of provided arguments is " << args_abs_list.size() << ".\n"
497                             << "FunctionGraph : " << fg->ToString()
498                             << "\nNodeInfo: " << trace::GetDebugInfoStr(fg->debug_info());
499   }
500   MS_LOG(DEBUG) << GetInferThread() << "@" << fg->ToString() << ArgsToString(args_abs_list) << " { ";
501   if (parent_context_->func_graph() != nullptr) {
502     MS_LOG(DEBUG) << GetInferThread() << "graph_: " << AnalysisSchedule::thread_id() << ":"
503                   << parent_context_->func_graph()->ToString() << "()->" << AnalysisSchedule::thread_id() << ":"
504                   << fg->ToString() << "();";
505   }
506 
507   auto func_graph_evaluator = mindspore::cast<FuncGraphEvaluator>(this);
508   if (func_graph_evaluator != nullptr) {
509     MS_EXCEPTION_IF_NULL(engine->root_func_graph());
510     if (engine->root_func_graph() == func_graph_evaluator->func_graph()) {
511       engine->set_root_context(context);
512     }
513   }
514   bool always_eval_flag = false;
515   const auto &parameters = fg->parameters();
516   for (size_t i = 0; i < nargs; i++) {
517     const auto &arg = args_abs_list[i];
518     const auto &node = parameters[i];
519     AnfNodeConfigPtr conf = engine->MakeConfig(node, context, fg);
520     always_eval_flag = always_eval_flag || CheckIfAlwaysEval(conf, arg);
521     auto result = std::make_shared<EvalResult>(arg, nullptr);
522     engine->SaveEvalResultInCache(conf, result);
523     MS_EXCEPTION_IF_NULL(arg);
524     MS_LOG(DEBUG) << GetInferThread() << ", Save argument[" << i << "] result for " << fg->ToString()
525                   << ", NodeConfig: " << conf->ToString() << ", result: " << arg << "/" << arg->ToString();
526   }
527   PushAlwaysEvalFlag(always_eval_flag);
528   if (fg->get_return() == nullptr) {
529     MS_LOG(EXCEPTION) << "The func graph " << fg << "/" << fg->ToString() << " has no return node.";
530   }
531   MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString()
532                 << ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString()
533                 << ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL")
534                 << ", current function call depth: " << FunctionCallDepth();
535   AbstractBasePtr abstract = nullptr;
536   if (engine->enable_recursive_eval()) {
537     abstract = LaunchRecursiveEval(engine, fg, context);
538   } else {
539     abstract = LaunchStackFrame(engine, fg, context);
540   }
541   PopAlwaysEvalFlag();
542 
543   MS_EXCEPTION_IF_NULL(abstract);
544   MS_LOG(DEBUG) << "Analysis FuncGraph end, " << fg << "/" << fg->ToString()
545                 << ", evaluated abstract: " << abstract->ToString() << ", is stub: " << fg->stub();
546   if (fg->stub()) {
547     abstract = std::make_shared<AbstractUndetermined>();
548   }
549   MS_LOG(DEBUG) << GetInferThread() << "} //" << fg->ToString() << " = " << abstract->ToString();
550 
551   SyncFuncGraphSideEffectFlag(fg);
552 
553   trace::TraceGraphEvalLeave(context);
554   // Decrease the func graph call depth.
555   DecreaseFunctionCallDepth();
556   MS_LOG(DEBUG) << this << "(" << type_name() << "/" << ToString()
557                 << "), leave, function call depth: " << FunctionCallDepth() << " - " << StackFrameDepth();
558   auto res = std::make_shared<EvalResult>(abstract, nullptr);
559   return res;
560 }
561 
BroadenArgs(const AbstractBasePtrList & args_abs_list,AbstractBasePtrList * broaded_args,bool broaden_scalar)562 void BroadenArgs(const AbstractBasePtrList &args_abs_list, AbstractBasePtrList *broaded_args, bool broaden_scalar) {
563   MS_EXCEPTION_IF_NULL(broaded_args);
564   (void)std::transform(
565     args_abs_list.begin(), args_abs_list.end(), std::back_inserter(*broaded_args),
566     [&broaden_scalar](const AbstractBasePtr &arg) -> AbstractBasePtr {
567       auto arg_sequence = arg->cast<AbstractSequencePtr>();
568       if (arg_sequence != nullptr && !arg_sequence->dynamic_len() && !arg->isa<AbstractSparseTensor>()) {
569         MS_LOG(DEBUG) << "set as arg of dyn len param, arg:" << arg->ToString();
570         auto dyn_len_arg = arg_sequence->BroadenToDynamicLenSequence();
571         return broaden_scalar ? AbstractBroaden(dyn_len_arg) : dyn_len_arg->Broaden();
572       }
573       if (arg->GetValueTrack() != kValueAny) {
574         return broaden_scalar ? AbstractBroaden(arg) : arg->Broaden();
575       }
576       return arg;
577     });
578 }
579 
// Returns `args_abs_list` unchanged unless the func graph has FUNC_GRAPH_FLAG_IGNORE_VALUE set.
// In that case it returns a broadened copy built with BroadenArgs. broaden_scalar is false for
// graphs flagged FUNC_GRAPH_FLAG_VMAP_TRANSFORMED and true otherwise.
// Throws if the evaluator has no func graph.
AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_abs_list) const {
  MS_EXCEPTION_IF_NULL(func_graph_);
  if (!func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
    return args_abs_list;
  }
  const bool broaden_scalar = !func_graph_->has_flag(FUNC_GRAPH_FLAG_VMAP_TRANSFORMED);
  AbstractBasePtrList broadened_args;
  BroadenArgs(args_abs_list, &broadened_args, broaden_scalar);
  MS_LOG(DEBUG) << func_graph_->ToString() << ", original: " << mindspore::ToString(args_abs_list)
                << ", broadened: " << mindspore::ToString(broadened_args);
  return broadened_args;
}
592 
// Handles arguments for a func graph whose evaluation turned out to be undetermined.
// If the graph already has FUNC_GRAPH_FLAG_IGNORE_VALUE, the arguments are returned unchanged.
// Otherwise it:
//  - records the ignore-value flag on the engine for the current analysis thread;
//  - sets FUNC_GRAPH_FLAG_IGNORE_VALUE on graphs marked kFuncGraphFlagUndetermined;
//  - returns NormalizeArgs(args_abs_list) if the graph now ignores values, else the arguments unchanged.
// Throws if the evaluator has no func graph or `engine` is null.
AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_abs_list,
                                                                const AnalysisEnginePtr &engine) {
  MS_EXCEPTION_IF_NULL(func_graph_);
  if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
    return args_abs_list;
  }
  // Previously `engine` was dereferenced without a null check.
  MS_EXCEPTION_IF_NULL(engine);
  // Set ignore flag for multithread eval.
  engine->SetIgnoreValueFlag(AnalysisSchedule::thread_id(), func_graph_.get());
  // Set ignore flag for recursive eval.
  if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) {
    func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
    MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag in recursive eval.";
  }
  if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE)) {
    auto normalized_args_abs_list = NormalizeArgs(args_abs_list);
    MS_LOG(DEBUG) << "Normalized args " << mindspore::ToString(normalized_args_abs_list);
    return normalized_args_abs_list;
  }
  return args_abs_list;
}
613 
GetFuncGraph(AnalysisEnginePtr engine,const AbstractBasePtrList & args_abs_list)614 FuncGraphPtr FuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
615   auto iter = func_graph_cache_.find(args_abs_list);
616   FuncGraphPtr res;
617   if (iter == func_graph_cache_.end()) {
618     auto fg = func_graph();
619     MS_EXCEPTION_IF_NULL(fg);
620     FuncGraphPtr generated_graph = fg->GenerateFuncGraph(args_abs_list);
621     func_graph_cache_[args_abs_list] = generated_graph;
622     MS_LOG(DEBUG) << "Generate special instance of function graph: " << ToString()
623                   << ", special function: " << generated_graph->ToString() << ", args: " << ArgsToString(args_abs_list);
624 
625     MS_EXCEPTION_IF_NULL(engine);
626     MS_EXCEPTION_IF_NULL(engine->func_graph_manager());
627     engine->func_graph_manager()->AddFuncGraph(generated_graph);
628     if (engine->check_side_effect()) {
629       PresetCertainSideEffect(generated_graph);
630     }
631     res = generated_graph;
632   } else {
633     res = iter->second;
634   }
635 
636   // For the top graph, if it is replaced by generated graph, update the top graph to the new one.
637   if (parse::Parser::GetTopFuncGraph() == func_graph()) {
638     if (res != func_graph()) {
639       parse::Parser::UpdateTopFuncGraph(res);
640     }
641   }
642   return res;
643 }
644 
GetFuncGraph(AnalysisEnginePtr engine,const AbstractBasePtrList & args_abs_list)645 FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, const AbstractBasePtrList &args_abs_list) {
646   auto iter = func_graph_cache_.find(args_abs_list);
647   if (iter != func_graph_cache_.end()) {
648     return iter->second;
649   }
650   MS_EXCEPTION_IF_NULL(meta_func_graph_);
651   (void)meta_func_graph_->GetChecker("check_infer_inputs").Execute(args_abs_list);
652 
653   if (scope_ != nullptr) {
654     meta_func_graph_->set_scope_name(scope_->name());
655   }
656   if (this->bound_node() != nullptr) {
657     auto node_debug_info = bound_node()->debug_info();
658     TraceGuard trace_guard(std::make_shared<TraceGenMetaFuncGraph>(node_debug_info));
659     auto node_location = trace::GetSourceCodeDebugInfo(node_debug_info)->location();
660     if (node_location != nullptr) {
661       meta_func_graph_->set_node_expr_src(node_location->expr_src());
662     }
663     generated_func_graph_ = meta_func_graph_->GenerateFuncGraph(args_abs_list);
664   } else {
665     generated_func_graph_ = meta_func_graph_->GenerateFuncGraph(args_abs_list);
666   }
667 
668   FuncGraphPtr cloned_func_graph;
669   NodeDebugInfoPtr debug_info;
670   if (this->bound_node() != nullptr) {
671     debug_info = this->bound_node()->debug_info();
672   }
673   if (meta_func_graph_->isa<expander::bprop::BpropMetaFuncGraph>()) {
674     cloned_func_graph = GetCloneBpropGraph(meta_func_graph_, generated_func_graph_, this->bound_node(), scope_);
675   } else {
676     cloned_func_graph = BasicClone(generated_func_graph_, false, std::make_shared<UpdateInfo>(scope_, debug_info));
677   }
678   func_graph_cache_[args_abs_list] = cloned_func_graph;
679   MS_EXCEPTION_IF_NULL(engine);
680   MS_EXCEPTION_IF_NULL(engine->func_graph_manager());
681   engine->func_graph_manager()->AddFuncGraph(cloned_func_graph);
682   if (engine->check_side_effect()) {
683     PresetCertainSideEffect(cloned_func_graph);
684   }
685   return cloned_func_graph;
686 }
687 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)688 EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
689                              const AnfNodeConfigPtr &out_conf) {
690   AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
691   args_abs_list = NormalizeArgs(args_abs_list);
692   args_abs_list = BroadenUndeterminedArgs(args_abs_list, engine);
693   MS_LOG(DEBUG) << EvalEntryLogging(shared_from_base<Evaluator>(), args_abs_list, out_conf);
694   EvalResultPtr eval_result = nullptr;
695   const std::string &evaluator_name = ToString();
696   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
697   auto &cache = evaluator_cache_mgr_->GetCache();
698   auto iter = cache.find(args_abs_list);
699   if (iter == cache.end()) {
700     MS_LOG(DEBUG) << "[" << this << "/" << evaluator_name << "] cache miss, call Eval(), args: " << args_abs_list;
701     eval_result = Eval(engine, args_abs_list, out_conf);
702     MS_EXCEPTION_IF_NULL(eval_result);
703     if (eval_result->abstract() == nullptr) {
704       EvalFailLogging(shared_from_base<Evaluator>(), args_abs_list, out_conf);
705       MS_LOG(INTERNAL_EXCEPTION) << "Evaluator " << evaluator_name << " result is nullptr.";
706     }
707     MS_LOG(DEBUG) << "[" << this << "/" << evaluator_name
708                   << "] set cache. result: " << eval_result->abstract()->ToString()
709                   << ", args_abs_list hash: " << AbstractBasePtrListHash(args_abs_list)
710                   << ", args_abs_list: " << args_abs_list;
711     evaluator_cache_mgr_->SetValue(args_abs_list, eval_result);
712   } else {
713     eval_result = iter->second;
714     MS_EXCEPTION_IF_NULL(eval_result->abstract());
715     MS_LOG(DEBUG) << "[" << this << "/" << evaluator_name
716                   << "] cache hit. result: " << eval_result->abstract()->ToString() << ", args: " << args_abs_list;
717     // Update inputs sequence nodes info, if matched in cache.
718     static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
719     if (enable_eliminate_unused_element) {
720       for (size_t i = 0; i < args_abs_list.size(); ++i) {
721         auto new_sequence = dyn_cast<AbstractSequence>(args_abs_list[i]);
722         auto old_sequence = dyn_cast<AbstractSequence>(iter->first[i]);
723         if (old_sequence != nullptr && new_sequence != nullptr) {
724           MS_LOG(DEBUG) << "Before synchronize sequence nodes use flags for NodeConfig: "
725                         << (out_conf ? out_conf->ToString() : "NULL") << "old_sequence: " << old_sequence->ToString()
726                         << ", new_sequence: " << new_sequence->ToString();
727           SynchronizeSequenceElementsUseFlagsRecursively(old_sequence, new_sequence);
728           MS_LOG(DEBUG) << "After synchronize sequence nodes use flags for NodeConfig: "
729                         << (out_conf ? out_conf->ToString() : "NULL") << ", old_sequence: " << old_sequence->ToString()
730                         << ", new_sequence: " << new_sequence->ToString();
731         }
732       }
733     }
734   }
735   return eval_result;
736 }
737 
EvalUndeterminedArgs(const AbstractBasePtrList & args_abs_list)738 EvalResultPtr Evaluator::EvalUndeterminedArgs(const AbstractBasePtrList &args_abs_list) {
739   auto is_undetermined = std::any_of(args_abs_list.begin(), args_abs_list.end(), [](auto &arg) -> bool {
740     return arg->IsSameTypeId(AbstractUndetermined::kTypeId);
741   });
742   if (is_undetermined) {
743     MS_LOG(DEBUG) << "Eval " << identifier_ << " return undetermined abstract result";
744     return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
745   }
746   return nullptr;
747 }
748 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)749 EvalResultPtr TrivialPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
750                                         const AnfNodeConfigPtr &) {
751   AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
752 
753   EvalResultPtr res;
754   // If the arguments contain Any, return Any directly.
755   // Only check in TrivialPrimEvaluator, not in TransitionPrimEvaluator.
756   const auto standard_prim_eval = dyn_cast_ptr<StandardPrimEvaluator>(shared_from_this());
757   bool ignore_any_type_checking =
758     (standard_prim_eval != nullptr &&
759      ignore_any_type_checking_prims.find(standard_prim_eval->prim()) != ignore_any_type_checking_prims.end());
760   if (!ignore_any_type_checking && ContainsAbstractAny(args_abs_list)) {
761     MS_LOG(INFO) << ToString() << " receives arguments that contain Any.";
762     auto any_abstract = std::make_shared<AbstractAny>();
763     const auto &dtype = GetArgsUniqueDtype(args_abs_list);
764     if (dtype != nullptr) {
765       MS_EXCEPTION_IF_NULL(any_abstract->element());
766       any_abstract->element()->set_type(dtype);
767       any_abstract->set_supposed_tensor_dtype(true);
768     }
769     for (const auto &abs : args_abs_list) {
770       MS_EXCEPTION_IF_NULL(abs);
771       if (abs->isa<abstract::AbstractSequence>()) {
772         SetSequenceElementsUseFlagsRecursively(abs, true);
773       }
774     }
775     res = std::make_shared<EvalResult>(any_abstract, std::make_shared<AttrValueMap>());
776   } else {
777     try {
778       res = EvalPrim(engine, args_abs_list);
779     } catch (std::exception &e) {
780       MS_LOG(ERROR) << "Primitive: <" << ToString() << "> infer failed, failed info: " << e.what();
781       std::rethrow_exception(std::current_exception());
782     }
783   }
784   MS_EXCEPTION_IF_NULL(res);
785   // Update the input abstract for inplace primitive.
786   if (inplace_prim() && !args_abs_list.empty() && args_abs_list[0] != res->abstract()) {
787     MS_LOG(DEBUG) << "Set inplace abstract, " << args_abs_list[0]->ToString() << " -> " << res->abstract()->ToString();
788     // Always update the inplace abstract.
789     args_abs_list[0]->set_inplace_abstract(res->abstract());
790   }
791   return res;
792 }
793 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)794 EvalResultPtr TransitionPrimEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
795                                            const AnfNodeConfigPtr &out_conf) {
796   if (args_conf_list.empty() && identifier_ != "MakeTupleEvaluator" && identifier_ != "MakeListEvaluator" &&
797       identifier_ != "RaiseEvaluator" && identifier_ != "ConstexprEvaluator") {
798     MS_LOG(INTERNAL_EXCEPTION) << "Size should be greater than 0, during running " << identifier_;
799   }
800   AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
801   EvalResultPtr res = EvalPrim(engine, args_abs_list, args_conf_list[0], out_conf);
802   MS_EXCEPTION_IF_NULL(res);
803   // Update the input abstract for inplace primitive.
804   if (inplace_prim() && !args_abs_list.empty() && args_abs_list[0] != res->abstract()) {
805     MS_LOG(DEBUG) << "Set inplace abstract, " << args_abs_list[0]->ToString() << " -> " << res->abstract()->ToString();
806     // Always update the inplace abstract.
807     args_abs_list[0]->set_inplace_abstract(res->abstract());
808   }
809   // No need to cache.
810   return res;
811 }
812 
Run(AnalysisEnginePtr,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)813 EvalResultPtr SymbolicPrimEvaluator::Run(AnalysisEnginePtr, const ConfigPtrList &args_conf_list,
814                                          const AnfNodeConfigPtr &) {
815   return EvalPrim(args_conf_list);
816 }
817 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)818 EvalResultPtr TrackedEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
819                                     const AnfNodeConfigPtr &out_conf) {
820   AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
821   EvalResultPtr res = sub_evaluator_->Run(engine, args_conf_list, out_conf);
822   // Don't lookup from cache, as different out_conf with same node but different context
823   // may add different entry to anfnode_config_map_, like getattr primitive.
824   evaluator_cache_mgr_->SetValue(args_abs_list, res);
825   return res;
826 }
827 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)828 EvalResultPtr PartialAppEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
829                                        const AnfNodeConfigPtr &out_conf) {
830   AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
831   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
832   auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
833   if (eval_result != nullptr) {
834     return eval_result;
835   }
836 
837   ConfigPtrList partial_args_conf_list;
838   // Join arguments in partial and the rest arguments from args_conf_list.
839   (void)std::transform(args_abs_list_.begin(), args_abs_list_.end(), std::back_inserter(partial_args_conf_list),
840                        [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
841 
842   (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(partial_args_conf_list),
843                        [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
844   EvalResultPtr res = evaluator_->Run(engine, partial_args_conf_list, out_conf);
845   evaluator_cache_mgr_->SetValue(args_abs_list, res);
846   return res;
847 }
848 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)849 EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
850                               const AnfNodeConfigPtr &out_conf) {
851   AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
852   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
853   auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
854   if (eval_result != nullptr) {
855     return eval_result;
856   }
857 
858   // Call the original evaluator, get the result: y = f(x)
859   EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
860   MS_EXCEPTION_IF_NULL(result);
861   // If the primal func graph's output is sequence, set its elements use flags all true.
862   SetSequenceElementsUseFlagsRecursively(result->abstract(), true);
863   // Build a virtual function: bprop_f which use sense of y as input, return sense of function free variable and input
864   // parameters. (sense_f, sense_x, ...)(*bpro_f) (sense_y)
865   AbstractBasePtrList bparams;
866   bparams.push_back(SensitivityTransform(primal_func_));
867   // Check if primal func graph has the primitive returned sparse result in its bprop().
868   auto real_primal_func = dyn_cast_ptr<FuncGraphAbstractClosure>(primal_func_);
869   MS_EXCEPTION_IF_NULL(real_primal_func);
870   FuncGraphPtr primal_func_graph = real_primal_func->func_graph();
871   MS_EXCEPTION_IF_NULL(primal_func_graph);
872   bool has_sparse_bprop_prim = primal_func_graph->has_flag(FUNC_GRAPH_FLAG_SPARSE_BPROP);
873   (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(bparams),
874                        [&has_sparse_bprop_prim](const AbstractBasePtr &arg_abs) -> AbstractBasePtr {
875                          MS_EXCEPTION_IF_NULL(arg_abs);
876                          if (has_sparse_bprop_prim && arg_abs->isa<AbstractTensor>()) {
877                            return std::make_shared<AbstractUndetermined>();
878                          }
879                          return SensitivityTransform(arg_abs);
880                        });
881   AbstractBasePtr bparams_final = std::make_shared<AbstractTuple>(bparams);
882   AbstractFunctionPtr bprop;
883   MS_EXCEPTION_IF_NULL(out_conf);
884   auto current_node = out_conf->node();
885   MS_EXCEPTION_IF_NULL(current_node);
886   if (current_node->isa<CNode>()) {
887     auto current_cnode = current_node->cast<CNodePtr>();
888     auto effect_info = current_cnode->GetEffectInfo();
889     if (current_cnode->IsEffectHandled() && effect_info.back_mem) {
890       AbstractBasePtrList bprop_inputs{SensitivityTransform(result->abstract()), kUMonad->ToAbstract()};
891       bprop = std::make_shared<VirtualAbstractClosure>(bprop_inputs, bparams_final);
892     } else {
893       bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
894     }
895   } else {
896     bprop = std::make_shared<VirtualAbstractClosure>(SensitivityTransform(result->abstract()), bparams_final);
897   }
898 
899   // J(f)(J(x)) return a tuple (y, bprop_f)
900   AbstractBasePtrList jargs = {result->abstract(), bprop};
901   AbstractBasePtr jtuple = std::make_shared<AbstractTuple>(jargs);
902   auto res = std::make_shared<EvalResult>(jtuple, std::make_shared<AttrValueMap>());
903   evaluator_cache_mgr_->SetValue(args_abs_list, res);
904   return res;
905 }
906 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)907 EvalResultPtr TaylorEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
908                                    const AnfNodeConfigPtr &) {
909   AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
910   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
911   auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
912   if (eval_result != nullptr) {
913     return eval_result;
914   }
915 
916   // Call the original evaluator, get the result: y = f(x)
917   EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
918   MS_EXCEPTION_IF_NULL(result);
919   evaluator_cache_mgr_->SetValue(args_abs_list, result);
920   return result;
921 }
922 
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)923 EvalResultPtr ShardEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
924                                   const AnfNodeConfigPtr &) {
925   AbstractBasePtrList args_abs_list = EvaluateArguments(args_conf_list);
926   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
927   auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
928   if (eval_result != nullptr) {
929     return eval_result;
930   }
931 
932   // Call the original evaluator, get the result: y = f(x)
933   EvalResultPtr result = evaluator_->Run(engine, args_conf_list, nullptr);
934   MS_EXCEPTION_IF_NULL(result);
935   auto res = std::make_shared<EvalResult>(result->abstract(), std::make_shared<AttrValueMap>());
936   evaluator_cache_mgr_->SetValue(args_abs_list, res);
937   return res;
938 }
939 
940 namespace {
ReduceDim(int * axis,const AbstractBasePtr & orig_abs,int * axis_size)941 AbstractBasePtr ReduceDim(int *axis, const AbstractBasePtr &orig_abs, int *axis_size) {
942   MS_EXCEPTION_IF_NULL(axis);
943   MS_EXCEPTION_IF_NULL(orig_abs);
944   MS_EXCEPTION_IF_NULL(axis_size);
945   if (!orig_abs->isa<AbstractTensor>()) {
946     MS_LOG(EXCEPTION) << "The orig_abs should be AbstractTensor when corresponding axis is " << *axis << ", but got a "
947                       << orig_abs->ToString() << ". Tip: Please check the correspondence between "
948                       << "vmap's 'in_axes' and inputs. You may want to explicitly specify the 'in_axes' "
949                       << "corresponding to " << orig_abs->ToString() << " as 'None' to solve this problem.";
950   }
951   auto orig_abs_shape = dyn_cast_ptr<Shape>(orig_abs->BuildShape());
952   MS_EXCEPTION_IF_NULL(orig_abs_shape);
953   ShapeVector orig_shape = orig_abs_shape->shape();
954   int shape_len = SizeToInt(orig_shape.size());
955   if (*axis < -shape_len || *axis >= shape_len) {
956     MS_LOG(EXCEPTION) << "The axis: " << *axis << " in 'in_axes' is out of bounds for array of dimension ["
957                       << -shape_len << "," << shape_len << ").";
958   }
959   *axis = *axis < 0 ? shape_len + *axis : *axis;
960   auto temp_axes_size = orig_shape[IntToSize(*axis)];
961   if (*axis_size == -1) {
962     *axis_size = LongToInt(temp_axes_size);
963   } else if (*axis_size != temp_axes_size) {
964     MS_LOG(EXCEPTION) << "The 'axis_size' of each argument in the scope of 'vmap' should be equal, but got "
965                       << *axis_size << " and " << temp_axes_size << ".";
966   }
967   (void)orig_shape.erase(orig_shape.begin() + *axis);
968   BaseShapePtr new_shape = std::make_shared<Shape>(orig_shape);
969   MS_EXCEPTION_IF_NULL(orig_abs->Clone());
970   AbstractBasePtr abs_clone = orig_abs->Clone()->Broaden();
971   abs_clone->set_shape(new_shape);
972   return abs_clone;
973 }
974 
GetLogicalViewAbs(const AbstractBasePtr & physical_view_abs,const ValuePtr & in_axes,int * axis_size)975 AbstractBasePtr GetLogicalViewAbs(const AbstractBasePtr &physical_view_abs, const ValuePtr &in_axes, int *axis_size) {
976   MS_EXCEPTION_IF_NULL(physical_view_abs);
977   MS_EXCEPTION_IF_NULL(in_axes);
978   auto physical_view_abs_sequence = dyn_cast_ptr<AbstractSequence>(physical_view_abs);
979   if (physical_view_abs_sequence != nullptr) {
980     AbstractBasePtrList abs_list = physical_view_abs_sequence->elements();
981     AbstractBasePtrList logical_view_abs_list;
982     auto in_axes_seq = dyn_cast_ptr<ValueSequeue>(in_axes);
983     int index = 0;
984     (void)std::transform(abs_list.begin(), abs_list.end(), std::back_inserter(logical_view_abs_list),
985                          [&axis_size, &index, in_axes_seq, in_axes](const AbstractBasePtr &sub_abs) -> AbstractBasePtr {
986                            ValuePtr sub_in_axes = in_axes;
987                            if (in_axes->isa<ValueSequeue>()) {
988                              sub_in_axes = (*in_axes_seq)[index];
989                              index++;
990                            }
991                            return GetLogicalViewAbs(sub_abs, sub_in_axes, axis_size);
992                          });
993     if (physical_view_abs->isa<AbstractList>()) {
994       return std::make_shared<AbstractList>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
995     }
996     return std::make_shared<AbstractTuple>(logical_view_abs_list, physical_view_abs_sequence->sequence_nodes());
997   }
998   ValuePtr in_axis = in_axes;
999   if (in_axis->isa<Int64Imm>()) {
1000     int axis = dyn_cast_ptr<Int64Imm>(in_axis)->value();
1001     auto logical_view_abs = ReduceDim(&axis, physical_view_abs, axis_size);
1002     return logical_view_abs;
1003   }
1004   if (!in_axis->isa<None>()) {
1005     MS_LOG(EXCEPTION) << "The axis in vmap's 'in_axes' should be a None or a scalar of type Int64Imm, but got a "
1006                       << in_axis->ToString() << ".";
1007   }
1008   // in_axis is None.
1009   return physical_view_abs;
1010 }
1011 
ExtendDim(int * axis,const AbstractBasePtr & orig_abs,int axis_size)1012 AbstractBasePtr ExtendDim(int *axis, const AbstractBasePtr &orig_abs, int axis_size) {
1013   MS_EXCEPTION_IF_NULL(orig_abs);
1014   MS_EXCEPTION_IF_NULL(axis);
1015   AbstractBasePtr out_abs = nullptr;
1016   ShapeVector orig_shape;
1017   if (orig_abs->isa<AbstractTensor>()) {
1018     auto shape = dyn_cast_ptr<Shape>(orig_abs->BuildShape());
1019     if (shape != nullptr) {
1020       orig_shape = shape->shape();
1021     }
1022     if (std::any_of(orig_shape.begin(), orig_shape.end(),
1023                     [](ShapeValueDType s) { return s == Shape::kShapeRankAny; })) {
1024       return orig_abs;
1025     }
1026   }
1027   int shape_len = SizeToInt(orig_shape.size() + 1);
1028   if (*axis < -shape_len || *axis >= shape_len) {
1029     MS_LOG(EXCEPTION) << "The axis: " << *axis << " in 'out_axes' is out of bounds for array of dimension ["
1030                       << -shape_len << "," << shape_len << ").";
1031   }
1032   *axis = *axis < 0 ? shape_len + *axis : *axis;
1033   (void)orig_shape.insert(orig_shape.begin() + *axis, axis_size);
1034   BaseShapePtr new_shape = std::make_shared<Shape>(orig_shape);
1035   if (orig_abs->isa<AbstractTensor>()) {
1036     auto tmp_abs = orig_abs->Clone();
1037     MS_EXCEPTION_IF_NULL(tmp_abs);
1038     out_abs = tmp_abs->Broaden();
1039     MS_EXCEPTION_IF_NULL(out_abs);
1040     out_abs->set_shape(new_shape);
1041   } else if (orig_abs->isa<AbstractScalar>()) {
1042     out_abs = std::make_shared<AbstractTensor>(orig_abs, new_shape);
1043   } else {
1044     MS_LOG(EXCEPTION) << "The outputs of vmap's 'fn' should be consisting of tensors or constants, but got "
1045                       << orig_abs->ToString() << ".";
1046   }
1047   return out_abs;
1048 }
1049 
GetPhysicalViewAbs(const AbstractBasePtr & logical_view_abs,const ValuePtr & out_axes,int axis_size)1050 AbstractBasePtr GetPhysicalViewAbs(const AbstractBasePtr &logical_view_abs, const ValuePtr &out_axes, int axis_size) {
1051   MS_EXCEPTION_IF_NULL(logical_view_abs);
1052   MS_EXCEPTION_IF_NULL(out_axes);
1053   auto logical_view_abs_sequence = dyn_cast_ptr<AbstractSequence>(logical_view_abs);
1054   if (logical_view_abs_sequence != nullptr) {
1055     AbstractBasePtrList logical_view_abs_list = logical_view_abs_sequence->elements();
1056     AbstractBasePtrList physical_view_abs_list;
1057     auto out_axes_seq = dyn_cast_ptr<ValueSequeue>(out_axes);
1058     if (out_axes_seq != nullptr) {
1059       if (logical_view_abs_list.size() != out_axes_seq->size()) {
1060         MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the number of results of 'fn': "
1061                           << logical_view_abs_list.size() << ", but got size: " << out_axes_seq->size() << ".";
1062       }
1063     }
1064     int index = 0;
1065     (void)std::transform(
1066       logical_view_abs_list.begin(), logical_view_abs_list.end(), std::back_inserter(physical_view_abs_list),
1067       [&axis_size, &index, out_axes_seq, out_axes](const AbstractBasePtr &arg_abs) -> AbstractBasePtr {
1068         ValuePtr sub_out_axes = out_axes;
1069         if (out_axes->isa<ValueSequeue>()) {
1070           sub_out_axes = (*out_axes_seq)[index];
1071           index++;
1072         }
1073         if (arg_abs->isa<AbstractSequence>()) {
1074           return GetPhysicalViewAbs(arg_abs, sub_out_axes, axis_size);
1075         }
1076         if (sub_out_axes->isa<Int64Imm>()) {
1077           int axis = static_cast<int>(dyn_cast_ptr<Int64Imm>(sub_out_axes)->value());
1078           return ExtendDim(&axis, arg_abs, axis_size);
1079         } else if (sub_out_axes->isa<None>()) {
1080           return arg_abs;
1081         }
1082         MS_LOG(EXCEPTION) << "The axis in vmap's 'out_axes' should be a None or a scalar of type Int64Imm, but got a "
1083                           << sub_out_axes->ToString() << ".";
1084       });
1085     if (logical_view_abs->isa<AbstractList>()) {
1086       return std::make_shared<AbstractList>(physical_view_abs_list);
1087     }
1088     return std::make_shared<AbstractTuple>(physical_view_abs_list);
1089   }
1090 
1091   // for the single output case, outputs: A, and out_axes: 1 or (1,).
1092   ValuePtr sub_out_axes = out_axes;
1093   ValueSequeuePtr out_axes_seq = dyn_cast<ValueSequeue>(out_axes);
1094   if (out_axes_seq != nullptr) {
1095     if (out_axes_seq->size() != 1) {
1096       MS_LOG(EXCEPTION) << "The size of vmap's 'out_axes' should be equal to the result size: 1, but got size: "
1097                         << out_axes_seq->size() << ".";
1098     }
1099     sub_out_axes = (*out_axes_seq)[0];
1100   }
1101 
1102   int axis = 0;
1103   auto axis_int_ptr = dyn_cast_ptr<Int64Imm>(sub_out_axes);
1104   if (axis_int_ptr != nullptr) {
1105     axis = LongToInt(axis_int_ptr->value());
1106   } else {
1107     MS_LOG(EXCEPTION) << "The axis in vmap's 'out_axes' should be a None or a scalar of type Int64Imm, but got a "
1108                       << sub_out_axes->ToString() << ".";
1109   }
1110   return ExtendDim(&axis, logical_view_abs, axis_size);
1111 }
1112 }  // namespace
1113 
1114 // According to the in_axes (e.g. (1,(None,3))), the abstraction of input parameters with the
1115 // physical view (e.g. (A,(B,C))) are converted into that with the logical view (e.g.(a,(b,c))),
1116 // more specific, the input `A` with shape (32, 16, 8) fitting the axis index `1` is converted in to
1117 // `a` with shape (32, 8). And then leverage the original graph to perform the evaluation.
1118 // Finally, the outputs with the logical view are converted back into the physical view in
1119 // combination with the out_axes. The inferring result is consistent with that after eliminating
1120 // the VmapOperator.
Run(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr &)1121 EvalResultPtr VmapEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
1122                                  const AnfNodeConfigPtr &) {
1123   AbstractBasePtrList args_abs_list;
1124   int axis_size = -1;
1125   int index = 0;
1126   auto in_axes = in_axes_;
1127   auto in_axes_seq = dyn_cast_ptr<ValueSequeue>(in_axes);
1128   (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_abs_list),
1129                        [&axis_size, &index, in_axes_seq, in_axes](const ConfigPtr &conf) -> AbstractBasePtr {
1130                          MS_EXCEPTION_IF_NULL(conf);
1131                          AbstractBasePtr abs = conf->ObtainEvalResult()->abstract();
1132                          MS_EXCEPTION_IF_NULL(abs);
1133                          // Drop the side effect tag parameters, because it has no mapping axis.
1134                          // e.g. args=(A,(B,C),U), in_axes=(1,(None,3))
1135                          if (abs->isa<AbstractMonad>()) {
1136                            return abs;
1137                          }
1138                          ValuePtr sub_in_axes = in_axes;
1139                          MS_EXCEPTION_IF_NULL(in_axes);
1140                          if (in_axes->isa<ValueSequeue>()) {
1141                            sub_in_axes = (*in_axes_seq)[index];
1142                            index++;
1143                          }
1144                          auto arg_abs = GetLogicalViewAbs(abs, sub_in_axes, &axis_size);
1145                          return arg_abs;
1146                        });
1147   MS_EXCEPTION_IF_NULL(evaluator_cache_mgr_);
1148   auto eval_result = evaluator_cache_mgr_->GetValue(args_abs_list);
1149   if (eval_result != nullptr) {
1150     return eval_result;
1151   }
1152   ConfigPtrList virtual_conf_list;
1153   (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(virtual_conf_list),
1154                        [](const AbstractBasePtr &arg) -> ConfigPtr { return std::make_shared<VirtualConfig>(arg); });
1155 
1156   // Call the original evaluator, get the result: y = f(x)
1157   EvalResultPtr result = evaluator_->Run(engine, virtual_conf_list, nullptr);
1158   MS_EXCEPTION_IF_NULL(result);
1159 
1160   // If the primal func graph's output is sequence, set its elements use flags all true.
1161   SetSequenceElementsUseFlagsRecursively(result->abstract(), true);
1162 
1163   if (axis_size == -1 && cell_size_ != 0) {
1164     axis_size = SizeToInt(cell_size_);
1165   } else if (axis_size != -1 && cell_size_ != 0 && axis_size != SizeToInt(cell_size_)) {
1166     MS_EXCEPTION(ValueError) << "If you want to execute the model ensembling parallel training, please make sure "
1167                              << "the 'axis_size' in the scope of vmap consistent with the cell size of the input "
1168                              << "'CellList', otherwise, please do not enter 'CellList' as the first argument, "
1169                              << "but we get axis_size: " << axis_size << " and the cell size: " << cell_size_ << ".";
1170   }
1171 
1172   AbstractBasePtr result_abs = result->abstract();
1173   AbstractBasePtr after_vmap = GetPhysicalViewAbs(result_abs, out_axes_, axis_size);
1174 
1175   auto res = std::make_shared<EvalResult>(after_vmap, std::make_shared<AttrValueMap>());
1176   evaluator_cache_mgr_->SetValue(args_abs_list, res);
1177   return res;
1178 }
1179 
Eval(AnalysisEnginePtr,const AbstractBasePtrList & args_abs_list,const AnfNodeConfigPtr & out_conf)1180 EvalResultPtr VirtualEvaluator::Eval(AnalysisEnginePtr, const AbstractBasePtrList &args_abs_list,
1181                                      const AnfNodeConfigPtr &out_conf) {
1182   if (args_abs_list.size() != args_abs_list_.size()) {
1183     MS_LOG(INTERNAL_EXCEPTION) << "Arguments mismatch, parameters no: " << args_abs_list_.size()
1184                                << ", arguments no: " << args_abs_list.size();
1185   }
1186   const auto sense_param_index = args_abs_list.size() - 1;
1187   bool sense_param_flag = false;
1188   MS_EXCEPTION_IF_NULL(this->bound_node());
1189   if (this->bound_node()->isa<CNode>()) {
1190     sense_param_flag = this->bound_node()->cast<CNodePtr>()->HasAttr("sens_param_");
1191   }
1192   static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
1193   // Check each parameter and argument match;
1194   for (std::size_t i = 0; i < args_abs_list.size(); i++) {
1195     MS_EXCEPTION_IF_NULL(args_abs_list[i]);
1196     // For VirtualAbstractClosure, likely J's bprop, we just set its tuple arguments as used before really grad.
1197     if (enable_eliminate_unused_element && args_abs_list[i]->isa<AbstractSequence>()) {
1198       MS_LOG(INFO) << "Notice: For VirtualAbstractClosure, update all use flags as true for arguments[" << i
1199                    << "]: " << args_abs_list[i]->ToString();
1200       SetSequenceElementsUseFlagsRecursively(args_abs_list[i], true);
1201     }
1202     if (i == sense_param_index && sense_param_flag) {
1203       const auto &sense_shape = args_abs_list[i]->BuildShape();
1204       MS_EXCEPTION_IF_NULL(sense_shape);
1205       if (sense_shape->IsDynamic()) {
1206         MS_EXCEPTION(ValueError) << "The shape of sense must not be dynamic shape."
1207                                  << "\nFor more details with 'sense', please refer to "
1208                                  << "https://www.mindspore.cn/docs/zh-CN/master/faq/network_compilation.html.";
1209       }
1210     }
1211     (void)args_abs_list[i]->Join(args_abs_list_[i]);
1212   }
1213   return std::make_shared<EvalResult>(output_, std::make_shared<AttrValueMap>());
1214 }
1215 
SingleRun(AnalysisEnginePtr engine,const ConfigPtrList & args_conf_list,const AnfNodeConfigPtr & out_conf)1216 EvalResultPtr Evaluator::SingleRun(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
1217                                    const AnfNodeConfigPtr &out_conf) {
1218   EvalResultPtr result;
1219   try {
1220     result = this->Run(engine, args_conf_list, out_conf);
1221   } catch (const std::exception &ex) {
1222     MS_LOG(INFO) << "Eval " << ToString() << " throw exception.";
1223     AnalysisSchedule::GetInstance().HandleException(ex);
1224   }
1225   AnalysisSchedule::GetInstance().Wait();
1226   return result;
1227 }
1228 }  // namespace abstract
1229 }  // namespace mindspore
1230