• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2024 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "pipeline/jit/ps/parse/parse.h"
20 
21 #include <utility>
22 #include <string>
23 #include <memory>
24 #include <sstream>
25 #include <algorithm>
26 #include <stack>
27 #include <regex>
28 #include "mindspore/core/ops/structure_ops.h"
29 #include "mindspore/core/ops/sequence_ops.h"
30 #include "mindspore/core/ops/framework_ops.h"
31 #include "utils/hash_map.h"
32 #include "pipeline/jit/ps/fallback.h"
33 #include "pipeline/jit/ps/parse/resolve.h"
34 #include "pipeline/jit/ps/parse/data_converter.h"
35 #include "frontend/operator/ops.h"
36 #include "frontend/operator/composite/composite.h"
37 #include "utils/ms_context.h"
38 #include "utils/log_adapter.h"
39 #include "utils/compile_config.h"
40 #include "utils/interpret_node_recorder.h"
41 #include "pipeline/jit/ps/debug/trace.h"
42 #include "mindspore/core/ir/cell.h"
43 #include "include/common/fallback.h"
44 #include "include/common/utils/utils.h"
45 #include "include/common/utils/python_adapter.h"
46 #include "include/common/utils/convert_utils_py.h"
47 
48 namespace mindspore {
49 namespace parse {
ParsePythonCode(const py::object & obj,const std::string & python_mod_get_parse_method,const ValuePtrList & args_value_list)50 FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method,
51                              const ValuePtrList &args_value_list) {
52   (void)python_adapter::set_python_scoped();
53   py::gil_scoped_acquire gil;
54 
55   if (!obj || py::isinstance<py::none>(obj)) {
56     MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
57     return nullptr;
58   }
59   MS_LOG(DEBUG) << "Parse ast obj: " << py::str(obj)
60                 << ", python_mod_get_parse_method: " << python_mod_get_parse_method;
61 
62   auto ast = std::make_shared<ParseFunctionAst>(obj);
63   bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
64   if (!success) {
65     MS_LOG(ERROR) << "Parse code to ast tree failed. obj: " << py::str(obj)
66                   << ", python_mod_get_parse_method: " << python_mod_get_parse_method;
67     return nullptr;
68   }
69 
70   auto parser = std::make_shared<Parser>(ast, args_value_list);
71 
72   FuncGraphPtr func_graph = parser->ParseFuncGraph();
73   if (func_graph == nullptr) {
74     MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
75     py::object node = ast->GetAstNode();
76     const auto &location = parser->GetLocation(node);
77     py::str desc = python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(),
78                                                location->file_name(), location->line());
79     MS_LOG(ERROR) << "\nlocation:" << desc.cast<std::string>();
80     return nullptr;
81   }
82 
83   // Handle no_inline function
84   auto no_inline_value = py::getattr(obj, FUNC_GRAPH_FLAG_NO_INLINE, py::none());
85   if (no_inline_value != py::none()) {
86     func_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, py::cast<bool>(no_inline_value));
87   }
88   // Handle cell_reusing function
89   auto cell_reuse_value = py::getattr(obj, FUNC_GRAPH_FLAG_CELL_REUSE, py::none());
90   if (cell_reuse_value != py::none()) {
91     func_graph->set_flag(FUNC_GRAPH_FLAG_CELL_REUSE, py::cast<bool>(cell_reuse_value));
92   }
93 
94   MS_LOG(DEBUG) << "Finish Parsing " << py::str(obj);
95   return func_graph;
96 }
97 
98 FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
99 
Parser(const std::shared_ptr<ParseFunctionAst> & ast,ValuePtrList args_value_list)100 Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast, ValuePtrList args_value_list)
101     : ast_(ast), errcode_(PARSE_SUCCESS), args_value_list_(std::move(args_value_list)) {
102   BuildMethodMap();
103 }
104 
BuildMethodMap()105 void Parser::BuildMethodMap() {
106   stmt_method_map_["Return"] = &Parser::ParseReturn;
107   stmt_method_map_["Expr"] = &Parser::ParseExpr;
108   stmt_method_map_["If"] = &Parser::ParseIf;
109   stmt_method_map_["Assign"] = &Parser::ParseAssign;
110   stmt_method_map_["AnnAssign"] = &Parser::ParseAnnAssign;
111   stmt_method_map_["While"] = &Parser::ParseWhile;
112   stmt_method_map_["For"] = &Parser::ParseFor;
113   stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
114   stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
115   stmt_method_map_["Global"] = &Parser::ParseGlobal;
116   stmt_method_map_["Break"] = &Parser::ParseBreak;
117   stmt_method_map_["Continue"] = &Parser::ParseContinue;
118   stmt_method_map_["Pass"] = &Parser::ParsePass;
119   stmt_method_map_["Raise"] = &Parser::ParseRaise;
120   stmt_method_map_["Assert"] = &Parser::ParseAssert;
121   stmt_method_map_["With"] = &Parser::ParseWith;
122   expr_method_map_["NoneType"] = &Parser::ParseNone;
123   expr_method_map_["BinOp"] = &Parser::ParseBinOp;
124   expr_method_map_["Name"] = &Parser::ParseName;
125   expr_method_map_["Num"] = &Parser::ParseNum;
126   expr_method_map_["Str"] = &Parser::ParseStr;
127   expr_method_map_["Constant"] = &Parser::ParseConstant;
128   expr_method_map_["NameConstant"] = &Parser::ParseNameConstant;
129   expr_method_map_["Call"] = &Parser::ParseCall;
130   expr_method_map_["IfExp"] = &Parser::ParseIfExp;
131   expr_method_map_["Attribute"] = &Parser::ParseAttribute;
132   expr_method_map_["Compare"] = &Parser::ParseCompare;
133   expr_method_map_["BoolOp"] = &Parser::ParseBoolOp;
134   expr_method_map_["Lambda"] = &Parser::ParseLambda;
135   expr_method_map_["Tuple"] = &Parser::ParseTuple;
136   expr_method_map_["List"] = &Parser::ParseList;
137   expr_method_map_["Subscript"] = &Parser::ParseSubscript;
138   expr_method_map_["Slice"] = &Parser::ParseSlice;
139   expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice;
140   expr_method_map_["Index"] = &Parser::ParseIndex;
141   expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
142   expr_method_map_["Dict"] = &Parser::ParseDict;
143   expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
144   expr_method_map_["DictComp"] = &Parser::ParseDictComp;
145   expr_method_map_["ListComp"] = &Parser::ParseListComp;
146   expr_method_map_["GeneratorExp"] = &Parser::ParseListComp;  // We treat 'GeneratorExp' the same as 'ListComp'.
147   expr_method_map_["JoinedStr"] = &Parser::ParseJoinedStr;
148   expr_method_map_["FormattedValue"] = &Parser::ParseFormattedValue;
149   expr_method_map_["Starred"] = &Parser::ParseStarred;
150   condition_method_map_["Attribute"] = &Parser::CheckAttributeConstantCond;
151   condition_method_map_["Name"] = &Parser::CheckNameConstantCond;
152   condition_method_map_["UnaryOp"] = &Parser::CheckUnaryOpConstantCond;
153   condition_method_map_["Compare"] = &Parser::CheckCompareConstantCond;
154   condition_method_map_["BoolOp"] = &Parser::CheckBoolOpConstantCond;
155   compare_method_map_["is"] = &Parser::CompareIs;
156   compare_method_map_["is not"] = &Parser::CompareIsNot;
157   compare_method_map_["=="] = &Parser::CompareEqual;
158   compare_method_map_["!="] = &Parser::CompareNotEqual;
159   compare_method_map_[">"] = &Parser::CompareGreater;
160   compare_method_map_[">="] = &Parser::CompareGreaterEqual;
161   compare_method_map_["<"] = &Parser::CompareLess;
162   compare_method_map_["<="] = &Parser::CompareLessEqual;
163 }
164 
UpdateTopFuncGraph(const FuncGraphPtr & func_graph)165 void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
166 
InitParserEnvironment(const py::object & obj)167 void Parser::InitParserEnvironment(const py::object &obj) {
168   Parser::top_func_graph_ = FuncGraphWeakPtr();
169   ScopeManager::GetInstance().ClearScope();
170   (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj);
171   // CellList need convert to FuncGraph in Parse, add flag for input from top graph.
172   if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
173     py::setattr(obj, PYTHON_CELL_LIST_FROM_TOP, py::bool_(true));
174   }
175 }
176 
CleanParserResource()177 void Parser::CleanParserResource() {
178   Parser::top_func_graph_ = FuncGraphWeakPtr();
179   ScopeManager::GetInstance().ClearScope();
180   parse::CleanParameterNameCache();
181 }
182 
CheckFuncReturn(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fn)183 void Parser::CheckFuncReturn(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fn) {
184   // Check whether the functions referred by this function and itself are missing 'return' statement
185   MS_EXCEPTION_IF_NULL(manager);
186   MS_EXCEPTION_IF_NULL(ast_);
187   for (const auto &func_graph : manager->func_graphs()) {
188     MS_EXCEPTION_IF_NULL(func_graph);
189     if (func_graph->get_return() != nullptr) {
190       continue;
191     }
192     py::object node = ast_->GetAstNode();
193     const auto &location = GetLocation(node);
194     MS_EXCEPTION_IF_NULL(location);
195     py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
196                                                location->file_name(), location->line());
197     MS_LOG(INFO) << "Function must has 'return' statement, but missing in " << desc.cast<std::string>()
198                  << ". FuncGraph: " << func_graph->ToString() << ", location: " << location->ToString()
199                  << "\nWe will add a 'return None' statement automatically.";
200     // If the def function has no return statement, mean that return none.
201     TraceGuard trace_guard_none(location);
202     auto none_node = NewValueNode(kNone);
203     auto return_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), none_node});
204     func_graph->set_return(return_node);
205   }
206 }
207 
GetFreeVariable(const FuncGraphPtr & func_graph)208 std::vector<std::pair<CNodePtr, size_t>> GetFreeVariable(const FuncGraphPtr &func_graph) {
209   // Considering the performance, we didn't use Manager here.
210   std::vector<std::pair<CNodePtr, size_t>> free_variables;
211   MS_EXCEPTION_IF_NULL(func_graph);
212   std::vector<AnfNodePtr> nodes =
213     TopoSort(func_graph->get_return(), SuccIncoming, [&func_graph](const AnfNodePtr &node) -> IncludeType {
214       MS_EXCEPTION_IF_NULL(node);
215       // Not follow FV's inputs.
216       if (node->func_graph() != nullptr && node->func_graph() != func_graph) {
217         return NOFOLLOW;
218       }
219       return FOLLOW;
220     });
221   for (auto &node : nodes) {
222     // Only check Non-FV CNode.
223     auto cnode = dyn_cast<CNode>(node);
224     if (cnode == nullptr || cnode->func_graph() != func_graph) {
225       continue;
226     }
227 
228     for (size_t i = 0; i < cnode->size(); ++i) {
229       auto &input = cnode->input(i);
230       if (input->func_graph() != nullptr && input->func_graph() != func_graph) {
231         (void)free_variables.emplace_back(std::make_pair(cnode, i));
232         constexpr auto recur_2 = 2;
233         MS_LOG(DEBUG) << "Found FV: input[" << i << "] of " << cnode->DebugString(recur_2);
234       }
235     }
236   }
237   return free_variables;
238 }
239 
LiftRolledBodyGraphFV()240 void Parser::LiftRolledBodyGraphFV() {
241   for (auto &rolled_call_pair : rolled_body_calls_) {
242     auto rolled_call_cnode = rolled_call_pair.first;
243     auto rolled_graph = rolled_call_pair.second->func_graph();
244     MS_EXCEPTION_IF_NULL(rolled_graph);
245     const auto &free_variables = GetFreeVariable(rolled_graph);
246     for (auto &free_node_pair : free_variables) {
247       auto &cnode = free_node_pair.first;
248       auto index = free_node_pair.second;
249       // Move the free variable to parent.
250       auto &free_node = cnode->input(index);
251       rolled_call_cnode->add_input(free_node);
252       // Change the free variable to the parameter.
253       auto parameter = rolled_graph->add_parameter();
254       cnode->set_input(index, parameter);
255       constexpr auto recur_2 = 2;
256       MS_LOG(DEBUG) << "Change FV: " << cnode->DebugString(recur_2);
257     }
258   }
259 }
260 
LiftIfBranchGraphFV()261 void Parser::LiftIfBranchGraphFV() {
262   for (auto &branch_call_tuple : if_branch_calls_) {
263     auto call_cnode = std::get<0>(branch_call_tuple);
264     auto true_branch_graph = std::get<1>(branch_call_tuple)->func_graph();
265     MS_EXCEPTION_IF_NULL(true_branch_graph);
266     auto false_branch_graph = std::get<2>(branch_call_tuple)->func_graph();
267     MS_EXCEPTION_IF_NULL(false_branch_graph);
268     const auto &true_free_variables = GetFreeVariable(true_branch_graph);
269     const auto &false_free_variables = GetFreeVariable(false_branch_graph);
270     // Handle true branch.
271     for (auto &free_node_pair : true_free_variables) {
272       auto &cnode = free_node_pair.first;
273       MS_EXCEPTION_IF_NULL(cnode);
274       auto index = free_node_pair.second;
275       // Move the free variable to parent.
276       auto &free_node = cnode->input(index);
277       call_cnode->add_input(free_node);
278       // Change the free variable to the parameter.
279       auto parameter = true_branch_graph->add_parameter();
280       cnode->set_input(index, parameter);
281       // Add a unused parameter in other branch.
282       (void)false_branch_graph->add_parameter();
283       constexpr auto recur_2 = 2;
284       MS_LOG(DEBUG) << "True branch, change FV: " << cnode->DebugString(recur_2);
285     }
286     // Handle false branch.
287     for (auto &free_node_pair : false_free_variables) {
288       auto &cnode = free_node_pair.first;
289       MS_EXCEPTION_IF_NULL(cnode);
290       auto index = free_node_pair.second;
291       // Move the free variable to parent.
292       auto &free_node = cnode->input(index);
293       call_cnode->add_input(free_node);
294       // Change the free variable to the parameter.
295       auto parameter = false_branch_graph->add_parameter();
296       cnode->set_input(index, parameter);
297       // Add a unused parameter in other branch.
298       (void)true_branch_graph->add_parameter();
299       constexpr auto recur_2 = 2;
300       MS_LOG(DEBUG) << "False branch, change FV: " << cnode->DebugString(recur_2);
301     }
302   }
303 }
304 
305 namespace {
IsDependOfIsolatedNodes(const AnfNodePtr & node)306 bool IsDependOfIsolatedNodes(const AnfNodePtr &node) {
307   if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
308     return false;
309   }
310   auto cnode = dyn_cast<CNode>(node);
311   MS_EXCEPTION_IF_NULL(cnode);
312   auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
313   auto sort_rhs_first =
314     attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
315   return sort_rhs_first;
316 }
317 
GetRealOutputNodes(const FuncGraphPtr & call_graph)318 std::pair<CNodePtr, AnfNodePtr> GetRealOutputNodes(const FuncGraphPtr &call_graph) {
319   MS_EXCEPTION_IF_NULL(call_graph);
320   auto graph_output = call_graph->output();
321   if (graph_output == nullptr) {
322     MS_LOG(INTERNAL_EXCEPTION) << "graph_output is null, call_graph: " << call_graph->ToString();
323   }
324   auto graph_output_cnode = dyn_cast<CNode>(graph_output);
325   MS_EXCEPTION_IF_NULL(graph_output_cnode);
326   // If output cnode is not the tail call but a Depend CNode, keep the dependency node for later use.
327   AnfNodePtr graph_dependency_node = nullptr;
328   if (IsDependOfIsolatedNodes(graph_output_cnode)) {
329     auto graph_real_output_cnode = dyn_cast<CNode>(graph_output_cnode->input(1));
330     // Get the dependency node;
331     constexpr auto dependency_node_index = 2;
332     graph_dependency_node = graph_output_cnode->input(dependency_node_index);
333     MS_EXCEPTION_IF_NULL(graph_real_output_cnode);
334     graph_output_cnode = graph_real_output_cnode;
335   }
336   return {graph_output_cnode, graph_dependency_node};
337 }
338 
TransformParallelCallFormerToMiddle(const FuncGraphPtr & former_call_graph,const FuncGraphPtr & latter_call_graph,size_t middle_graph_output_cnode_size,bool use_arguments_pack)339 void TransformParallelCallFormerToMiddle(const FuncGraphPtr &former_call_graph, const FuncGraphPtr &latter_call_graph,
340                                          size_t middle_graph_output_cnode_size, bool use_arguments_pack) {
341   // The 'former_graph_output' is middle graph call or depend.
342   const auto &[former_graph_output_cnode, former_graph_dependency_node] = GetRealOutputNodes(former_call_graph);
343   MS_EXCEPTION_IF_NULL(former_graph_output_cnode);
344   MS_EXCEPTION_IF_NULL(former_call_graph);
345   std::vector<AnfNodePtr> inputs({NewValueNode(latter_call_graph)});
346   if (use_arguments_pack) {
347     for (size_t i = 0; i < middle_graph_output_cnode_size - 1; ++i) {
348       auto getitem_input = former_call_graph->NewCNodeInOrder(
349         {NewValueNode(prim::kPrimTupleGetItem), former_graph_output_cnode, NewValueNode(SizeToLong(i))});
350       (void)inputs.emplace_back(getitem_input);
351     }
352   } else {
353     (void)inputs.emplace_back(former_graph_output_cnode);
354   }
355   auto new_output = former_call_graph->NewCNodeBefore(former_call_graph->return_node(), std::move(inputs));
356   if (former_graph_dependency_node != nullptr) {
357     // Adjust the former funcgraph output with Depend.
358     new_output = former_call_graph->NewCNodeAfter(
359       new_output, {NewValueNode(prim::kPrimDepend), new_output, former_graph_dependency_node});
360     // Origin dependency_node has this attribute(refer to function IsDependOfIsolatedNodes), so we keep it.
361     new_output->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
362   }
363   former_call_graph->set_output(new_output);
364 }
365 
TransformParallelCallMiddleToLatter(const FuncGraphPtr & middle_call_graph,const CNodePtr & middle_graph_output_cnode,const AnfNodePtr & middle_graph_dependency_node,size_t middle_graph_output_cnode_size)366 bool TransformParallelCallMiddleToLatter(const FuncGraphPtr &middle_call_graph,
367                                          const CNodePtr &middle_graph_output_cnode,
368                                          const AnfNodePtr &middle_graph_dependency_node,
369                                          size_t middle_graph_output_cnode_size) {
370   MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
371   MS_EXCEPTION_IF_NULL(middle_call_graph);
372   bool use_arguments_pack = false;
373   constexpr auto output_inputs_num = 2;
374   AnfNodePtr new_middle_graph_output = nullptr;
375   if (middle_graph_output_cnode_size == output_inputs_num) {  // Only one argument.
376     new_middle_graph_output = middle_graph_output_cnode->input(1);
377   } else {  // More than one argument, pack them with tuple.
378     use_arguments_pack = true;
379     middle_graph_output_cnode->set_input(0, NewValueNode(prim::kPrimMakeTuple));
380     new_middle_graph_output = middle_graph_output_cnode;
381   }
382   // Adjust the middle funcgraph output with Depend.
383   if (middle_graph_dependency_node != nullptr) {
384     new_middle_graph_output = middle_graph_output_cnode->func_graph()->NewCNode(
385       {NewValueNode(prim::kPrimDepend), new_middle_graph_output, middle_graph_dependency_node});
386   }
387   middle_call_graph->set_output(new_middle_graph_output);
388   return use_arguments_pack;
389 }
390 
IsValueContainScalar(const ValuePtr & value)391 bool IsValueContainScalar(const ValuePtr &value) {
392   if (value->isa<Scalar>()) {
393     return true;
394   }
395   return false;
396 }
397 
IsOutputContainScalar(const CNodePtr & output_cnode)398 bool IsOutputContainScalar(const CNodePtr &output_cnode) {
399   return std::any_of(output_cnode->weak_inputs().cbegin() + 1, output_cnode->weak_inputs().end(),
400                      [](const AnfNodeWeakPtr &weak_node) {
401                        auto node = weak_node.lock();
402                        MS_EXCEPTION_IF_NULL(node);
403                        if (node->isa<ValueNode>()) {
404                          auto value_node = node->cast<ValueNodePtr>();
405                          return IsValueContainScalar(value_node->value());
406                        }
407                        return false;
408                      });
409 }
410 
CheckMiddleGraphOutputContainScalar(const std::vector<std::pair<FunctionBlockPtr,FunctionBlockPtr>> & parallel_call_vec)411 bool CheckMiddleGraphOutputContainScalar(
412   const std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> &parallel_call_vec) {
413   std::vector<bool> contains_scalar;
414   for (auto &call_graphs_pair : parallel_call_vec) {
415     MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
416     auto middle_call_graph = call_graphs_pair.second->func_graph();
417     MS_EXCEPTION_IF_NULL(middle_call_graph);
418     if (middle_call_graph->get_return() == nullptr) {
419       continue;
420     }
421     constexpr auto recur_2 = 2;
422     const auto &middle_graph_output_pair = GetRealOutputNodes(middle_call_graph);
423     const auto middle_graph_output_cnode = middle_graph_output_pair.first;
424     MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
425     auto middle_graph_output_cnode_size = middle_graph_output_cnode->size();
426     if (middle_graph_output_cnode_size <= 1) {
427       MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
428       return false;
429     }
430 
431     static const auto transform_if_const_scalar = (common::GetCompileConfig("IF_PARALLEL_CALL") == "2");
432     if (!transform_if_const_scalar && IsOutputContainScalar(middle_graph_output_cnode)) {
433       MS_LOG(DEBUG) << "CNode's inputs contain const scalar, " << middle_graph_output_cnode->DebugString(recur_2);
434       contains_scalar.push_back(true);
435     } else {
436       contains_scalar.push_back(false);
437     }
438   }
439 
440   return std::all_of(contains_scalar.cbegin(), contains_scalar.cend(), [](bool is_scalar) { return is_scalar; });
441 }
442 
CheckMiddleGraphOutputPyInterpret(const std::vector<std::pair<FunctionBlockPtr,FunctionBlockPtr>> & parallel_call_vec)443 bool CheckMiddleGraphOutputPyInterpret(
444   const std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> &parallel_call_vec) {
445   bool contain_py_interpret = false;
446   for (auto &call_graphs_pair : parallel_call_vec) {
447     MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
448     auto middle_call_graph = call_graphs_pair.second->func_graph();
449     MS_EXCEPTION_IF_NULL(middle_call_graph);
450     if (middle_call_graph->get_return() == nullptr) {
451       continue;
452     }
453     constexpr auto recur_2 = 2;
454     const auto &middle_graph_output_pair = GetRealOutputNodes(middle_call_graph);
455     const auto middle_graph_output_cnode = middle_graph_output_pair.first;
456     MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
457     auto middle_graph_output_cnode_size = middle_graph_output_cnode->size();
458     if (middle_graph_output_cnode_size <= 1) {
459       MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
460       return false;
461     }
462     bool exist_interpret = std::any_of(
463       middle_graph_output_cnode->weak_inputs().cbegin() + 1, middle_graph_output_cnode->weak_inputs().cend(),
464       [](const AnfNodeWeakPtr &weak_node) { return IsPrimitiveCNode(weak_node.lock(), prim::kPrimPyInterpret); });
465     contain_py_interpret |= exist_interpret;
466     if (contain_py_interpret) {
467       return true;
468     }
469   }
470 
471   return false;
472 }
473 }  // namespace
474 
475 // Transform tail call to parallel call.
TransformParallelCall()476 void Parser::TransformParallelCall() {
477   mindspore::HashSet<FuncGraphPtr> latter_call_graphs_set;
478   for (auto &parallel_call_vec : parallel_call_graphs_) {
479     bool all_middle_graphs_output_scalar = CheckMiddleGraphOutputContainScalar(parallel_call_vec);
480     if (all_middle_graphs_output_scalar) {
481       MS_LOG(DEBUG) << "All middle func graph's output contain const scalar, cannot transform to Parallel_If.";
482       continue;
483     }
484     // After Join, Value in Abstract of PyInterpret CNode will be kValueAny, it cannot be PyInterpreted again, so
485     // ignore the transformation.
486     bool is_middle_graphs_output_py_interpret = CheckMiddleGraphOutputPyInterpret(parallel_call_vec);
487     if (is_middle_graphs_output_py_interpret) {
488       MS_LOG(DEBUG) << "Middle func graph's output contain PyInterpret CNode, cannot transform to Parallel_If.";
489       continue;
490     }
491     for (auto &call_graphs_pair : parallel_call_vec) {
492       MS_EXCEPTION_IF_NULL(call_graphs_pair.first);
493       auto former_call_graph = call_graphs_pair.first->func_graph();
494       MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
495       auto middle_call_graph = call_graphs_pair.second->func_graph();
496       // Transform the call of {middle_graph -> latter_graph}.
497       auto middle_graph_return = middle_call_graph->get_return();
498       if (middle_graph_return == nullptr) {
499         MS_LOG(INFO) << "middle_graph_return is null, middle_call_graph: " << middle_call_graph->ToString();
500         continue;
501       }
502       constexpr auto recur_3 = 3;
503       constexpr auto recur_2 = 2;
504       MS_LOG(DEBUG) << "Tail call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
505                     << ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
506       const auto &[middle_graph_output_cnode, middle_graph_dependency_node] = GetRealOutputNodes(middle_call_graph);
507       auto middle_graph_output_cnode_size = middle_graph_output_cnode->size();
508       if (middle_graph_output_cnode_size <= 1) {
509         MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
510         continue;
511       }
512 
513       auto latter_graph_node = middle_graph_output_cnode->input(0);
514       bool use_arguments_pack = TransformParallelCallMiddleToLatter(
515         middle_call_graph, middle_graph_output_cnode, middle_graph_dependency_node, middle_graph_output_cnode_size);
516 
517       // Transform the call of {former_graph -> middle_graph}.
518       auto latter_call_graph = GetValueNode<FuncGraphPtr>(latter_graph_node);
519       if (latter_call_graph == nullptr) {
520         MS_LOG(ERROR) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
521         continue;
522       }
523       if (latter_call_graphs_set.find(latter_call_graph) != latter_call_graphs_set.end()) {
524         MS_LOG(DEBUG) << "The latter graph is handled before, " << latter_call_graph->ToString();
525         continue;
526       }
527       (void)latter_call_graphs_set.emplace(latter_call_graph);
528       TransformParallelCallFormerToMiddle(former_call_graph, latter_call_graph, middle_graph_output_cnode_size,
529                                           use_arguments_pack);
530 
531       MS_LOG(DEBUG) << "Parallel call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
532                     << ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
533     }
534   }
535 
536   // Lift inner, then lift outer.
537   LiftIfBranchGraphFV();
538   LiftRolledBodyGraphFV();
539 }
540 
ParseFuncGraph()541 FuncGraphPtr Parser::ParseFuncGraph() {
542   // Get ast FunctionDef node
543   py::object node = ast_->GetAstNode();
544   constexpr char function_def_name[] = "FunctionDef";
545   constexpr char lambda_name[] = "Lambda";
546   FunctionBlockPtr fn_block = nullptr;
547   MS_EXCEPTION_IF_NULL(ast_->GetNodeType(node));
548   if (ast_->GetNodeType(node)->node_name() == function_def_name) {
549     fn_block = ParseDefFunction(node);
550   } else {
551     auto lambda_node = python_adapter::GetPyObjAttr(node, "value");
552     if (py::isinstance<py::none>(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) {
553       MS_INTERNAL_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got "
554                                        << ast_->GetNodeType(lambda_node)->node_name() << ". Please check lambda"
555                                        << " expression to make sure it is defined on a separate line.\n For example, "
556                                        << "the code 'func = nn.ReLU() if y < 1 else lambda x: x + 1' rewritten as\n"
557                                        << "'if y < 1:\n    func = nn.ReLU()\nelse:\n    func = lambda x: x + 1\n'"
558                                        << "will solve the problem.";
559     }
560     fn_block = ParseLambdaFunction(lambda_node);
561   }
562   if (errcode() != PARSE_SUCCESS) {
563     MS_LOG(ERROR) << "Parse function error, code is " << errcode();
564     return nullptr;
565   }
566   for (auto &func_block_item : func_block_list_) {
567     MS_EXCEPTION_IF_NULL(func_block_item);
568     MS_EXCEPTION_IF_NULL(func_block_item->func_graph());
569     if (!func_block_item->isolated_nodes().empty()) {
570       // Find unused variables.
571       func_block_item->FindIsolatedNodes();
572       // Attach all isolated nodes.
573       func_block_item->AttachIsolatedNodesBeforeReturn();
574     }
575   }
576   MS_EXCEPTION_IF_NULL(fn_block);
577   auto manager = Manage(fn_block->func_graph(), false);
578   RemoveUnnecessaryPhis(manager);
579   CheckFuncReturn(manager, fn_block->func_graph());
580   TransformParallelCall();
581   return fn_block->func_graph();
582 }
583 
584 // If any mixed precision flag add a cast node after the parameter node.
GetMixedPrecisionCastHelp(const FuncGraphPtr & func_graph,const AnfNodePtr & param)585 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param) {
586   MS_EXCEPTION_IF_NULL(func_graph);
587   TypePtr dst_type;
588   if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
589     dst_type = kFloat32;
590   } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
591     dst_type = kFloat16;
592   } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_BF16)) {
593     dst_type = kBFloat16;
594   } else {
595     return param;
596   }
597   auto cast_helper = prim::kPrimMixedPrecisionCast;
598   auto cast = func_graph->NewCNodeAfter(param, {NewValueNode(cast_helper), NewValueNode(dst_type), param});
599   return cast;
600 }
601 
GenerateArgsNodeForFunction(const FunctionBlockPtr & block,const py::object & fn_node)602 void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
603   py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
604   py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
605   MS_EXCEPTION_IF_NULL(block);
606   auto block_fg = block->func_graph();
607   block_fg->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
608 
609   py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
610   block_fg->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
611 
612   py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
613   block_fg->set_kwonlyargs_count(SizeToInt(kwonly_args.size()));
614 
615   MS_EXCEPTION_IF_NULL(ast_);
616   py::list args = ast_->GetArgs(fn_node);
617   for (std::size_t i = 0; i < args.size(); i++) {
618     std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
619     if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
620       if (arg_name == "self") {
621         continue;
622       }
623     }
624     TraceGuard guard(GetLocation(args[i]));
625     auto para_node = std::make_shared<Parameter>(block_fg);
626     MS_EXCEPTION_IF_NULL(para_node);
627     para_node->set_name(arg_name);
628     MS_EXCEPTION_IF_NULL(para_node->debug_info());
629     para_node->debug_info()->set_name(arg_name);
630     block_fg->add_parameter(para_node);
631     AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block_fg, para_node);
632     MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
633     block->WriteVariable(arg_name, para_after_cast);
634   }
635 }
636 
GenerateArgsDefaultValueForFunction(const FunctionBlockPtr & block,const py::object & fn_node)637 void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
638   MS_EXCEPTION_IF_NULL(block);
639   py::list defaults = ast_->GetArgsDefaultValues(fn_node);
640   py::list args = ast_->GetArgs(fn_node);
641   std::vector<std::string> namelist_for_default_value;
642   std::vector<AnfNodePtr> default_values;
643   for (std::size_t i = 0; i < args.size(); i++) {
644     std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
645     if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
646       if (arg_name == "self") {
647         continue;
648       }
649     }
650 
651     namelist_for_default_value.push_back(arg_name);
652     if (i >= defaults.size()) {
653       MS_LOG(INTERNAL_EXCEPTION) << "Index: " << i << " out of range: " << defaults.size();
654     }
655     if (py::isinstance<py::none>(defaults[i])) {
656       default_values.push_back(NewValueNode(kNull));
657     } else {
658       AnfNodePtr arg_node = ParseExprNode(block, defaults[i]);
659       default_values.push_back(arg_node);
660     }
661   }
662   MS_EXCEPTION_IF_NULL(block->func_graph());
663   block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values);
664 }
665 
GetScopeForParseFunction()666 ScopePtr Parser::GetScopeForParseFunction() {
667   ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
668   if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
669     py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
670     if (!py::isinstance<py::none>(scope_str)) {
671       auto scope_name = py::cast<std::string>(scope_str);
672       scope = std::make_shared<Scope>(scope_name);
673     }
674   }
675   return scope;
676 }
677 
ConvertGetattrNodes()678 void Parser::ConvertGetattrNodes() {
679   // If obj.attr has been set a new value in graph, convert all getattr node to PyExecute.
680   AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
681   for (const auto &setattr_node_pair : setattr_nodes_map_) {
682     const auto &obj_str = setattr_node_pair.first;
683     const auto &attr_map = setattr_node_pair.second;
684     auto getattr_nodes_map_iter = getattr_nodes_map_.find(obj_str);
685     // If the same object is not in both setattr map and getattr map, no need to convert getattr node.
686     if (getattr_nodes_map_iter == getattr_nodes_map_.end()) {
687       continue;
688     }
689     const auto &getattr_map = getattr_nodes_map_iter->second;
690     for (const auto &attr_pair : attr_map) {
691       const auto &attr_str = attr_pair.first;
692       auto getattr_map_iter = getattr_map.find(attr_str);
693       // If the same attr for the same obj is not in both setattr map and getattr map, no need to convert getattr node.
694       if (getattr_map_iter == getattr_map.end()) {
695         continue;
696       }
697       const auto &setattr_node = attr_pair.second;
698       auto setattr_cnode = setattr_node->cast<CNodePtr>();
699       MS_EXCEPTION_IF_NULL(setattr_cnode);
700       const auto getattr_nodes = getattr_map_iter->second;
701       constexpr size_t obj_index = 1;
702       const auto &setattr_cnode_obj_node = setattr_cnode->input(obj_index);
703       AnfNodePtr cur_getattr_node = nullptr;
704       for (const auto &getattr_node : getattr_nodes) {
705         auto getattr_node_fg = getattr_node->func_graph();
706         if (getattr_node_fg == nullptr) {
707           MS_LOG(DEBUG) << "Has no func graph, getattr_node: " << getattr_node->DebugString();
708           continue;
709         }
710         std::vector<AnfNodePtr> new_getattr_node_inputs = {op_node, setattr_cnode_obj_node, NewValueNode(attr_str)};
711         if (cur_getattr_node != nullptr && cur_getattr_node->func_graph() == getattr_node_fg) {
712           (void)new_getattr_node_inputs.emplace_back(cur_getattr_node);
713         }
714         auto new_getattr_node = getattr_node_fg->NewCNode(new_getattr_node_inputs);
715         new_getattr_node->set_user_data<bool>(fallback::kObjectAttrChange, std::make_shared<bool>(true));
716         new_getattr_node->set_debug_info(getattr_node->debug_info());
717         MS_LOG(DEBUG) << "Generate new getattr node: " << new_getattr_node->DebugString();
718         const auto &manager = Manage(getattr_node_fg, false);
719         MS_EXCEPTION_IF_NULL(manager);
720         (void)manager->Replace(getattr_node, new_getattr_node);
721         cur_getattr_node = new_getattr_node;
722       }
723     }
724   }
725 }
726 
ParseDefFunction(const py::object & node,const FunctionBlockPtr & block)727 FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) {
728   ScopePtr scope = GetScopeForParseFunction();
729   // The node created in the parsefunction context, will inherit the scope created using scope_guard
730   ScopeGuard scope_guard(scope);
731   const auto debug_info = std::make_shared<DebugInfo>(GetLocation(node));
732   TraceGuard trace_guard(std::make_shared<TraceParse>(debug_info));
733   FunctionBlockPtr func_block = MakeFunctionBlock();
734   if (block != nullptr) {
735     func_block->AddPrevBlock(block);
736   } else {
737     func_graph_ = func_block->func_graph();
738   }
739   func_block->Mature();
740   auto current_fg = func_block->func_graph();
741   auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
742   MS_LOG(DEBUG) << "The function name is " << function_name << ", loc: " << GetLocation(node)->ToString();
743   // Replace the construct function name with the cell name
744   constexpr auto cell_construct = "construct";
745   bool is_construct_function = false;
746   if (function_name == cell_construct) {
747     is_construct_function = true;
748     // 'py_class_name' format is like: <class 'x.x.xxx'>
749     std::string py_class_name = py::cast<std::string>(py::str(ast()->obj().get_type()));
750     constexpr auto py_class_prefix_len = 8;  // <class '
751     constexpr auto py_class_suffix_len = 2;  // '>
752     auto py_class_len = py_class_name.length();
753     // Exclude class prefix and suffix.
754     auto class_name =
755       py_class_name.substr(py_class_prefix_len, py_class_len - py_class_prefix_len - py_class_suffix_len);
756     function_name = class_name + '_' + cell_construct;
757     MS_LOG(DEBUG) << "The generated function full name: " << function_name;
758   }
759   // Normalize the name.
760   std::replace(function_name.begin(), function_name.end(), '.', '_');
761   std::replace(function_name.begin(), function_name.end(), '<', '_');
762   std::replace(function_name.begin(), function_name.end(), '>', '_');
763 
764   // Save the function node to block
765   func_block->WriteVariable(function_name, NewValueNode(current_fg));
766   MS_EXCEPTION_IF_NULL(current_fg->debug_info());
767   current_fg->debug_info()->set_name(function_name);
768   py::list deco_list = node.attr("decorator_list");
769   if (!deco_list.empty()) {
770     current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
771   }
772   MS_EXCEPTION_IF_NULL(ast_);
773   bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
774   if (!ast_->obj().is(ast_->function())) {
775     set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg, is_construct_function);
776   }
777 
778   if (!set_flag) {
779     MS_LOG(ERROR) << "Set flags failed";
780     return nullptr;
781   }
782   GenerateArgsNodeForFunction(func_block, node);
783 
784   // When parsing the top graph of construct, save the top graph
785   if (GetTopFuncGraph() == nullptr) {
786     UpdateTopFuncGraph(func_block->func_graph());
787   }
788 
789   py::object func_obj = python_adapter::GetPyObjAttr(node, "body");
790   (void)ParseStatements(func_block, func_obj);
791   if (current_fg->get_return() == nullptr) {
792     // If the def function has no return statement, mean that return none.
793     py::object location_node = ast_->GetAstNode();
794     const auto &location = GetLocation(location_node);
795     MS_EXCEPTION_IF_NULL(location);
796     py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
797                                                location->file_name(), location->line());
798     MS_LOG(INFO) << "Function must has 'return' statement, but missing in " << desc.cast<std::string>()
799                  << ". FuncGraph: " << current_fg->ToString() << ", location: " << location->ToString()
800                  << ". We will add a 'return None' statement automatically.";
801     TraceGuard trace_guard_none(location);
802     auto none_node = NewValueNode(kNone);
803     auto return_node = current_fg->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), none_node});
804     current_fg->set_return(return_node);
805   }
806 
807   // Add unused variables as isolate nodes.
808   for (auto &func_block_item : func_block_list_) {
809     MS_EXCEPTION_IF_NULL(func_block_item);
810     MS_EXCEPTION_IF_NULL(func_block_item->func_graph());
811     if (func_block_item->func_graph()->get_return() != nullptr) {
812       // Find unused variables.
813       func_block_item->FindIsolatedNodes();
814       // Attach all isolated nodes.
815       func_block_item->AttachIsolatedNodesBeforeReturn();
816     }
817   }
818 
819   ConvertGetattrNodes();
820   GenerateArgsDefaultValueForFunction(func_block, node);
821   return func_block;
822 }
823 
ParseLambdaFunction(const py::object & node,const FunctionBlockPtr & block)824 FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) {
825   MS_EXCEPTION_IF_NULL(ast_);
826   ScopePtr scope = GetScopeForParseFunction();
827   ScopeGuard scope_guard(scope);
828   const auto debug_info = std::make_shared<DebugInfo>(GetLocation(node));
829   TraceGuard trace_guard(std::make_shared<TraceParse>(debug_info));
830   FunctionBlockPtr func_block = MakeFunctionBlock();
831   MS_EXCEPTION_IF_NULL(func_block);
832   if (block != nullptr) {
833     func_block->AddPrevBlock(block);
834   } else {
835     func_graph_ = func_block->func_graph();
836   }
837   func_block->Mature();
838   auto current_fg = func_block->func_graph();
839 
840   MS_EXCEPTION_IF_NULL(current_fg);
841   auto lambda_function_name = ast_->function_name();
842   // Normalize the name.
843   std::replace(lambda_function_name.begin(), lambda_function_name.end(), '.', '_');
844   std::replace(lambda_function_name.begin(), lambda_function_name.end(), '<', '_');
845   std::replace(lambda_function_name.begin(), lambda_function_name.end(), '>', '_');
846   constexpr auto lambda_suffix = "_lambda_";  // Represent <lambda>.
847   auto function_name = lambda_function_name + "_" + lambda_suffix;
848   MS_LOG(DEBUG) << "The function name is " << function_name;
849   MS_EXCEPTION_IF_NULL(current_fg->debug_info());
850   current_fg->debug_info()->set_name(function_name);
851   GenerateArgsNodeForFunction(func_block, node);
852 
853   // When parsing the top graph of construct, save the top graph
854   if (GetTopFuncGraph() == nullptr) {
855     UpdateTopFuncGraph(func_block->func_graph());
856   }
857 
858   py::object body_node = python_adapter::GetPyObjAttr(node, "body");
859   AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
860   current_fg->set_output(lambda_body_node);
861 
862   // Add unused variables as isolate nodes.
863   for (auto &func_block_item : func_block_list_) {
864     MS_EXCEPTION_IF_NULL(func_block_item);
865     MS_EXCEPTION_IF_NULL(func_block_item->func_graph());
866     if (!func_block_item->isolated_nodes().empty()) {
867       // Find unused variables.
868       func_block_item->FindIsolatedNodes();
869       // Attach all isolated nodes.
870       func_block_item->AttachIsolatedNodesBeforeReturn();
871     }
872   }
873 
874   GenerateArgsDefaultValueForFunction(func_block, node);
875   return func_block;
876 }
877 
ParseStatements(const FunctionBlockPtr & block,const py::object & nodes)878 FunctionBlockPtr Parser::ParseStatements(const FunctionBlockPtr &block, const py::object &nodes) {
879   auto node_list = py::cast<py::list>(nodes);
880   size_t count = py::len(node_list);
881   MS_LOG(DEBUG) << "The nodes count is " << count;
882   auto sub_block = block;
883   for (size_t i = 0; i < count; ++i) {
884     MS_LOG(DEBUG) << "Start parse statement[" << i << "]: " << py::str(node_list[i])
885                   << ", block: " << sub_block->ToString();
886     auto node = node_list[i];
887     // Flag of return statement is set on sub_block inside ParseStatement, so use next_block
888     // to store the returned block temporarily.
889     auto next_block = ParseStatement(sub_block, node);
890     MS_EXCEPTION_IF_NULL(next_block);
891     MS_EXCEPTION_IF_NULL(next_block->func_graph());
892     // Propagate flag of return statement back;
893     if (sub_block != block && sub_block->is_return_statement_inside()) {
894       MS_LOG(DEBUG) << "Sub block: " << sub_block->ToString()
895                     << " has return statement inside, propagate flag back to block: " << block->ToString();
896       block->set_is_return_statement_inside();
897     }
898     // Propagate flag of break or continue statement back;
899     if (sub_block != block && sub_block->is_break_continue_statement_inside()) {
900       MS_LOG(DEBUG) << "Sub block: " << sub_block->ToString()
901                     << " has break or continue statement inside, propagate flag back to block: " << block->ToString();
902       block->set_break_continue_statement_inside();
903     }
904     sub_block = next_block;
905 
906     static const auto boost_parse = common::GetCompileConfig("BOOST_PARSE");
907     if (boost_parse != "0" && sub_block->is_dead_block()) {
908       break;
909     }
910     if (boost_parse == "0") {
911       // Insert appropriate depended items for the function block if it has a return node
912       if (sub_block->func_graph()->get_return() != nullptr || sub_block->is_dead_block()) {
913         // If break is not the last expr.
914         if (i != count - 1) {
915           TraceGuard trace_guard(GetLocation(node_list[i + 1]));
916           MS_LOG(EXCEPTION) << "Dead code exist, please remove it. [" << (i + 1) << "/" << count
917                             << "], node: " << py::str(node_list[i]) << ", block: " << sub_block->ToString()
918                             << ", has_return: " << (sub_block->func_graph()->get_return() != nullptr)
919                             << ", is_dead_block: " << sub_block->is_dead_block();
920         }
921         // Skip statements after 'return' (or 'break', 'continue').
922         break;
923       }
924     }
925     // If the current block has multi return statements,
926     // only parse the statements before first return statement.
927     // Statements after the continue and break statements are also not parsed.
928     if (ast_->GetNodeType(node)->node_name() == "Break" || ast_->GetNodeType(node)->node_name() == "Continue" ||
929         ast_->GetNodeType(node)->node_name() == "Return") {
930       break;
931     }
932   }
933   return sub_block;
934 }
935 
// Dispatches one statement AST node to its handler in stmt_method_map_ and returns the block that
// is current after the statement.
// - A node that is not a statement is logged at INFO level, and `block` is returned unchanged.
// - A statement without a handler sets errcode_ = PARSE_NODE_METHOD_UNSUPPORTED and raises.
FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) {
  // GetLocation calls into Python; fetch it once for both the trace guard and the debug log.
  const auto location = GetLocation(node);
  TraceGuard trace_guard(location);
  auto node_type = ast_->GetNodeType(node);

  // Check the node type
  AstMainType nodeType = node_type->main_type();
  if (nodeType != AST_MAIN_TYPE_STMT) {
    MS_LOG(INFO) << "Node type is error : " << nodeType;
    return block;
  }
  // Call the process function (single map lookup instead of count() + operator[]).
  const std::string &node_name = node_type->node_name();
  MS_LOG(DEBUG) << "Ast node is " << node_name << ", location:" << location->ToString();
  auto method_iter = stmt_method_map_.find(node_name);
  if (method_iter == stmt_method_map_.end()) {
    errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
    MS_LOG(EXCEPTION) << "Unsupported statement '" << node_name
                      << "'.\nMore details please refer to syntax support at https://www.mindspore.cn";
  }
  return (this->*(method_iter->second))(block, node);
}
958 
ParseExprNode(const FunctionBlockPtr & block,const py::object & node)959 AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) {
960   MS_LOG(DEBUG) << "Process ast expr.";
961   TraceGuard trace_guard(GetLocation(node));
962   auto node_type = ast_->GetNodeType(node);
963   // Check the node type
964   AstMainType node_main_type = node_type->main_type();
965   if (node_main_type != AST_MAIN_TYPE_EXPR) {
966     errcode_ = PARSE_NODE_TYPE_NO_MATCH;
967     MS_LOG(INTERNAL_EXCEPTION) << "Node type is error : " << node_main_type;
968   }
969   // Call the process function
970   const std::string &node_type_name = node_type->node_name();
971   MS_LOG(DEBUG) << "Ast node is " << node_type_name << ", location:" << GetLocation(node)->ToString();
972   if (expr_method_map_.count(node_type_name) != 0) {
973     auto expr_node = (this->*expr_method_map_[node_type_name])(block, node);
974     MS_LOG(DEBUG) << "Get parsed anf node:" << expr_node->DebugString();
975     return expr_node;
976   } else {
977     errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
978     MS_LOG(EXCEPTION) << "Unsupported expression '" << node_type_name
979                       << "'.\nMore details please refer to syntax support at https://www.mindspore.cn";
980   }
981 }
982 
983 // If self.attr.func is inplace operation, then
984 //   self.attr.func(inputs)
985 //   -->
986 //   self.attr = self.attr.func(inputs)
987 //   setattr(self, "attr", self.attr.func(inputs))
HandleSetAttrClassMemberForInplace(const FunctionBlockPtr & block,const AnfNodePtr & node)988 bool Parser::HandleSetAttrClassMemberForInplace(const FunctionBlockPtr &block, const AnfNodePtr &node) {
989   if (!node->isa<CNode>()) {
990     return false;
991   }
992   auto cnode = node->cast<CNodePtr>();
993   auto call_node = cnode->input(0);
994   if (!IsPrimitiveCNode(call_node, prim::kPrimGetAttr)) {
995     return false;
996   }
997   constexpr int recursive_level = 2;
998   // call_cnode: self.attr.func
999   auto call_cnode = call_node->cast<CNodePtr>();
1000   MS_LOG(DEBUG) << "call cnode: " << call_cnode->DebugString(recursive_level);
1001   const auto &call_cnode_inputs = call_cnode->inputs();
1002   constexpr size_t attr_node_index = 1;
1003   constexpr size_t func_str_index = 2;
1004   auto func_str_node = call_cnode_inputs[func_str_index];
1005   if (!IsValueNode<StringImm>(func_str_node)) {
1006     return false;
1007   }
1008   const auto &func_str = GetValue<std::string>(GetValueNode(func_str_node));
1009   std::vector<std::string> inplace_ops{"extend", "pop", "insert", "reverse"};
1010   MS_LOG(DEBUG) << "func str: " << func_str;
1011   if (std::all_of(inplace_ops.begin(), inplace_ops.end(),
1012                   [&func_str](const std::string &ops) { return func_str != ops; })) {
1013     return false;
1014   }
1015   auto attr_node = call_cnode_inputs[attr_node_index];
1016   if (!attr_node->isa<CNode>()) {
1017     return false;
1018   }
1019   // attr_cnode: self.attr
1020   auto attr_cnode = attr_node->cast<CNodePtr>();
1021   MS_LOG(DEBUG) << "attr cnode: " << attr_cnode->DebugString(recursive_level);
1022   const auto &attr_cnode_inputs = attr_cnode->inputs();
1023   constexpr size_t target_index = 1;
1024   constexpr size_t attr_index = 2;
1025   auto target_node = attr_cnode_inputs[target_index];
1026   MS_LOG(DEBUG) << "target node: " << target_node->DebugString();
1027   auto target_attr_node = attr_cnode_inputs[attr_index];
1028   auto symbol_val = GetValueNode<SymbolPtr>(target_attr_node);
1029   if (symbol_val == nullptr) {
1030     return false;
1031   }
1032   const auto &target_symbol_str = symbol_val->symbol();
1033   MakeSetAttrNode(block, target_node, cnode, "self", target_symbol_str);
1034   return true;
1035 }
1036 
1037 // Process the expr statement and expand it
ParseExpr(const FunctionBlockPtr & block,const py::object & node)1038 FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
1039   MS_LOG(DEBUG) << "Process ast Expr";
1040   // Expr only have value, no target
1041   py::tuple expand_info = ast_->CallParseModFunction(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
1042 
1043   // Refer python function expand_expr_statement, expand_info is one of the following:
1044   // True, expr.value, x
1045   // True, expr.value
1046   // False, None, None
1047   //
1048   // Check the expand info result
1049   if (expand_info.empty()) {
1050     MS_LOG(INTERNAL_EXCEPTION) << "Empty expand_info.";
1051   }
1052   auto is_expand = py::cast<bool>(expand_info[0]);
1053   if (is_expand) {
1054     // Process the expr statement
1055     constexpr size_t expect_size = 2;
1056     if (expand_info.size() < expect_size) {
1057       MS_LOG(INTERNAL_EXCEPTION) << "expand_info size:" << expand_info.size() << " less than " << expect_size << ".";
1058     }
1059     py::object value_object = expand_info[1];
1060     // Make a Expr CNode.
1061     AnfNodePtr call_node = ParseExprNode(block, value_object);
1062     if (py::len(expand_info) == expect_size) {
1063       // list_x.pop(a) does not write the return value of pop.
1064       // -->  list_x = list_x.pop(a) need renew the list_x.
1065       if (IsPopOperation(call_node)) {
1066         if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast_->IsClassMemberOfSelf(list_pop_target_obj_)) {
1067           // self.list_x = [xx, xx]
1068           // self.list_x.pop()
1069           MS_LOG(DEBUG) << "The variables whose type is not parameter do not support pop operation.";
1070         } else {
1071           auto func_graph = block->func_graph();
1072           MS_EXCEPTION_IF_NULL(func_graph);
1073           auto new_list = func_graph->NewCNodeInOrder(
1074             {NewValueNode(prim::kPrimTupleGetItem), call_node, NewValueNode(SizeToLong(0))});
1075           WriteAssignVars(block, list_pop_target_obj_, new_list);
1076           block->AddIsolatedNode(call_node);
1077           return block;
1078         }
1079       }
1080       // Expression that not assigned to any variable.
1081       // This is usually a call with side effects.
1082       // e.g.: print(x)
1083       // We save it as an isolated node.
1084       auto &no_return_node = call_node;
1085       MS_LOG(INFO) << "Isolated node found(NoReturn), no_return_node: " << no_return_node->DebugString()
1086                    << ", block: " << block << "/"
1087                    << (block->func_graph() ? block->func_graph()->ToString() : "FG(Null)")
1088                    << ", Line: " << trace::GetDebugInfoStr(no_return_node->debug_info(), "", kSourceLineTipDiscard);
1089       block->AddIsolatedNode(no_return_node);
1090     } else {
1091       // Expand the assign statement,
1092       // e.g.: x.append(y)  -> x = x.append(y)
1093       py::object target_node = expand_info[2];
1094       // Check whether the target_node is class member recursively.
1095       // e.g.: self.a1.a1.update()
1096       if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast_->IsClassMemberRecursive(target_node)) {
1097         // self.x = [xx, xx]
1098         // self.x.append()
1099         const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
1100         if (!allow_fallback_runtime || !HandleSetAttrClassMemberForInplace(block, call_node)) {
1101           block->AddIsolatedNode(call_node);
1102         }
1103       } else {
1104         WriteAssignVars(block, target_node, call_node);
1105       }
1106     }
1107   }
1108   return block;
1109 }
1110 
GetLocation(const py::object & node) const1111 LocationPtr Parser::GetLocation(const py::object &node) const {
1112   MS_EXCEPTION_IF_NULL(ast_);
1113   py::list res = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
1114   constexpr size_t list_size = 6;
1115   if (res.size() < list_size) {
1116     MS_LOG(INTERNAL_EXCEPTION) << "List size should not be less than 5.";
1117   }
1118   constexpr size_t file_name_index = 0;
1119   constexpr size_t line_index = 1;
1120   constexpr size_t column_index = 2;
1121   constexpr size_t line_end_index = 3;
1122   constexpr size_t column_end_index = 4;
1123   constexpr size_t expr_src_index = 5;
1124   constexpr size_t comments_index = 6;
1125   // Deal with the comments.
1126   std::vector<std::string> comments_str_list;
1127   const auto comments_list = res[comments_index].cast<py::list>();
1128   for (size_t i = 0; i < comments_list.size(); ++i) {
1129     (void)comments_str_list.emplace_back(comments_list[i].cast<std::string>());
1130   }
1131   if (!comments_str_list.empty()) {
1132     MS_LOG(DEBUG) << "@jit comments: " << comments_str_list;
1133   }
1134   // Refer to Location::Location() for each member of res: line, column, line_end, column_end, expr_src.
1135   auto location = std::make_shared<Location>(res[file_name_index].cast<std::string>(), res[line_index].cast<int64_t>(),
1136                                              res[column_index].cast<int64_t>(), res[line_end_index].cast<int64_t>(),
1137                                              res[column_end_index].cast<int64_t>(),
1138                                              res[expr_src_index].cast<std::string>(), std::move(comments_str_list));
1139   MS_LOG(DEBUG) << "node: " << py::str(node) << ",\n" << location->DebugString();
1140   return location;
1141 }
1142 
1143 // NOTICE: Must add a TraceGuard before call it.
MakeFunctionBlock()1144 FunctionBlockPtr Parser::MakeFunctionBlock() {
1145   FunctionBlockPtr block = std::make_shared<FunctionBlock>(*this);
1146   // In order to keep effect order in the sub-graphs which generated by control flow.
1147   // We copy the flags from the top graph to the sub-graphs.
1148   if (func_graph_ && !func_graph_->attrs().empty()) {
1149     for (const auto &attr : func_graph_->attrs()) {
1150       // The flag FUNC_GRAPH_OUTPUT_NO_RECOMPUTE should be only set in the top graph.
1151       if (attr.first != FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) {
1152         block->func_graph()->set_attr(attr.first, attr.second);
1153       }
1154     }
1155   }
1156   func_block_list_.push_back(block);
1157   return block;
1158 }
1159 
MakeFunctionBlock(const TraceInfoPtr & trace_info)1160 FunctionBlockPtr Parser::MakeFunctionBlock(const TraceInfoPtr &trace_info) {
1161   TraceGuard trace_guard(trace_info);
1162   FunctionBlockPtr block = MakeFunctionBlock();
1163   return block;
1164 }
1165 
MakeConditionBlocks(const FunctionBlockPtr & pre_block,const FunctionBlockPtr & true_block,const FunctionBlockPtr & false_block) const1166 void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
1167                                  const FunctionBlockPtr &false_block) const {
1168   MS_EXCEPTION_IF_NULL(true_block);
1169   MS_EXCEPTION_IF_NULL(false_block);
1170   true_block->AddPrevBlock(pre_block);
1171   true_block->Mature();
1172 
1173   false_block->AddPrevBlock(pre_block);
1174   false_block->Mature();
1175 
1176   true_block->UpdateGlobalPyParam(pre_block->global_py_params());
1177   false_block->UpdateGlobalPyParam(pre_block->global_py_params());
1178 }
1179 
// Parses a `return <value>` statement. It builds a Return CNode for the parsed value, installs it
// as the return of the block's graph, and marks the block as containing a return statement.
FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast return";
  MS_EXCEPTION_IF_NULL(block);
  const py::object returned_value = python_adapter::GetPyObjAttr(node, "value");
  const AnfNodePtr returned_node = ParseExprNode(block, returned_value);
  const auto block_graph = block->func_graph();
  const CNodePtr return_node = block_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), returned_node});
  block_graph->set_return(return_node);
  MS_LOG(DEBUG) << "Inside the block has return statement, block: " << block->ToString();
  block->set_is_return_statement_inside();
  return block;
}
1194 
1195 // Process binary operators,eg: `a + b`, `a | b`, etc.
ParseBinOp(const FunctionBlockPtr & block,const py::object & node)1196 AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
1197   MS_LOG(DEBUG) << "Process ast BinOP";
1198 
1199   MS_EXCEPTION_IF_NULL(block);
1200   py::object left = python_adapter::GetPyObjAttr(node, "left");
1201   py::object right = python_adapter::GetPyObjAttr(node, "right");
1202   py::object op = python_adapter::GetPyObjAttr(node, "op");
1203   // Create left and right ANF node
1204   AnfNodePtr left_node = ParseExprNode(block, left);
1205   if (left_node == nullptr) {
1206     MS_LOG(INTERNAL_EXCEPTION) << "DoBinOp process left node failed: " << errcode();
1207   }
1208   AnfNodePtr right_node = ParseExprNode(block, right);
1209   if (right_node == nullptr) {
1210     MS_LOG(INTERNAL_EXCEPTION) << "DoBinOp process right node failed:" << errcode();
1211   }
1212   // Resolve the op
1213   const auto &ns = block->GetAstOpNameSpace(op);
1214   auto op_node = block->MakeResolveAstOpNameSpace(ns);
1215 
1216   // Create apply node
1217   MS_EXCEPTION_IF_NULL(block->func_graph());
1218   AnfNodePtr new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
1219   // Handling % symbol in formatted string values by JIT Fallback.
1220   // The string AnfNode may be created by ParseJoinedStr or ParseStr.
1221   // For example, string % var, f"The string is: %s." % str  or "The number is: %d." % num
1222   constexpr size_t symbol_index = 1;
1223   SymbolPtr symbol = std::make_shared<Symbol>(ns[symbol_index].cast<std::string>());
1224   // Only support the pattern (string % xxx) by fallback.
1225   if (symbol != nullptr && symbol->symbol() == "mod") {
1226     if (IsPrimitiveCNode(left_node, prim::kPrimJoinedStr)) {
1227       // left_node created by ParseJoinedStr
1228       auto inputs = left_node->cast<CNodePtr>()->inputs();
1229       if (inputs.size() <= 1) {
1230         MS_LOG(INTERNAL_EXCEPTION) << "Unexpected maketuple node:" << left_node->DebugString();
1231       }
1232       auto str_node = inputs[1];
1233       if (IsValueNode<StringImm>(str_node)) {
1234         new_node->set_interpret(true);
1235         auto new_interpret_node = HandleInterpret(block, new_node, node);
1236         return new_interpret_node;
1237       }
1238     } else if (IsValueNode<StringImm>(left_node)) {
1239       // left_node created by ParseStr
1240       new_node->set_interpret(true);
1241       auto new_interpret_node = HandleInterpret(block, new_node, node);
1242       return new_interpret_node;
1243     }
1244   }
1245   return new_node;
1246 }
1247 
ParseName(const FunctionBlockPtr & block,const py::object & node)1248 AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
1249   MS_LOG(DEBUG) << "Process ast Name";
1250   auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
1251   MS_LOG(DEBUG) << "The Name id is " << name_id;
1252   MS_EXCEPTION_IF_NULL(block);
1253   // The Tensor object will be parsed into an Interpret node. For example, Tensor(0).astype("int32")
1254   if (block->IsGlobalVar(name_id) || name_id == "Tensor") {
1255     MS_LOG(DEBUG) << "name_id: " << name_id;
1256     AnfNodePtr res = block->MakeResolveSymbol(name_id);
1257     block->CheckUndefinedSymbol(name_id, res);
1258     return res;
1259   }
1260 
1261   AnfNodePtr res = block->ReadVariable(name_id);
1262   block->CheckUndefinedSymbol(name_id, res);
1263   return res;
1264 }
1265 
ParseNone(const FunctionBlockPtr &,const py::object &)1266 AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
1267   MS_LOG(DEBUG) << "Process ast NoneType";
1268   return NewValueNode(kNone);
1269 }
1270 
ParseEllipsis(const FunctionBlockPtr &,const py::object &)1271 AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
1272   MS_LOG(DEBUG) << "Process ast Ellipsis";
1273   return NewValueNode(kEllipsis);
1274 }
1275 
ParseNum(const FunctionBlockPtr &,const py::object & node)1276 AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
1277   MS_LOG(DEBUG) << "Process ast Num";
1278   py::object obj = python_adapter::GetPyObjAttr(node, "n");
1279   if (py::isinstance<py::int_>(obj)) {
1280     MS_LOG(INFO) << "The Num is int64_t:" << (std::string)py::str(obj);
1281     auto data = py::cast<int64_t>(obj);
1282     return NewValueNode(data);
1283   } else if (py::isinstance<py::float_>(obj)) {
1284     MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
1285     auto data = py::cast<float>(obj);
1286     auto res = NewValueNode(data);
1287     auto fp32_val = res->value()->cast<FP32ImmPtr>();
1288     if (fp32_val != nullptr) {
1289       MS_LOG(DEBUG) << "Set float64 value to FP32Imm.";
1290       fp32_val->set_prim_value(py::cast<double>(obj));
1291     }
1292     return res;
1293   } else {
1294     // no else actually
1295     errcode_ = PARSE_NODE_TYPE_UNKNOWN;
1296     MS_EXCEPTION(TypeError) << "Only support 'Number' type of 'int` and 'float', but got type: " << obj.get_type()
1297                             << " Value:" << py::str(obj);
1298   }
1299 }
1300 
ParseStr(const FunctionBlockPtr &,const py::object & node)1301 AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
1302   MS_LOG(DEBUG) << "Process ast Str";
1303   auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
1304   return NewValueNode(str_s);
1305 }
1306 
ParseConstant(const FunctionBlockPtr &,const py::object & node)1307 AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &node) {
1308   MS_LOG(DEBUG) << "Process ast Constant";
1309   py::object obj = python_adapter::GetPyObjAttr(node, "value");
1310   if (py::isinstance<py::bool_>(obj)) {
1311     MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj);
1312     return NewValueNode(py::cast<bool>(obj));
1313   } else if (py::isinstance<py::int_>(obj)) {
1314     MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj);
1315     return NewValueNode(py::cast<int64_t>(obj));
1316   } else if (py::isinstance<py::float_>(obj)) {
1317     MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj);
1318     auto data = py::cast<float>(obj);
1319     auto res = NewValueNode(data);
1320     auto fp32_val = res->value()->cast<FP32ImmPtr>();
1321     if (fp32_val != nullptr) {
1322       MS_LOG(DEBUG) << "Set float64 value to FP32Imm.";
1323       fp32_val->set_prim_value(py::cast<double>(obj));
1324     }
1325     return res;
1326   } else if (py::isinstance<py::str>(obj)) {
1327     MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj);
1328     return NewValueNode(py::cast<std::string>(obj));
1329   } else if (py::isinstance<py::none>(obj)) {
1330     MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj);
1331     return NewValueNode(kNone);
1332   } else if (py::isinstance<py::ellipsis>(obj)) {
1333     MS_LOG(INFO) << "The Constance is ellipsis:" << (std::string)py::str(obj);
1334     return NewValueNode(kEllipsis);
1335   } else {
1336     // no else actually
1337     MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << (std::string)py::str(obj);
1338   }
1339 }
1340 
ParseNameConstant(const FunctionBlockPtr &,const py::object & node)1341 AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
1342   MS_LOG(DEBUG) << "Process ast NameConstant";
1343   py::object obj = python_adapter::GetPyObjAttr(node, "value");
1344   if (py::isinstance<py::bool_>(obj)) {
1345     MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
1346     auto data = py::cast<bool>(obj);
1347     return NewValueNode(data);
1348   } else if (py::isinstance<py::none>(obj)) {
1349     MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
1350     return NewValueNode(kNone);
1351   }
1352   // no else actually
1353   errcode_ = PARSE_NODE_TYPE_UNKNOWN;
1354   MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
1355 }
1356 
GenerateMakeTuple(const FunctionBlockPtr & block,const std::vector<AnfNodePtr> & element_nodes)1357 AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
1358   MS_EXCEPTION_IF_NULL(block);
1359   AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
1360   std::vector<AnfNodePtr> make_tuple_nodes;
1361   make_tuple_nodes.push_back(make_tuple_op);
1362   (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes),
1363                        [](AnfNodePtr arg) -> AnfNodePtr { return arg; });
1364   MS_EXCEPTION_IF_NULL(block->func_graph());
1365   return block->func_graph()->NewCNodeInOrder(std::move(make_tuple_nodes));
1366 }
1367 
ParseSuper(const FunctionBlockPtr & block,const py::list & args)1368 AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
1369   MS_EXCEPTION_IF_NULL(block);
1370   py::object father_class;
1371   const size_t expect_args_size = 2;
1372   if (args.empty()) {
1373     father_class = py::none();
1374   } else if (args.size() == expect_args_size) {
1375     father_class = args[0];
1376     auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[1])));
1377     if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
1378       MS_EXCEPTION(ArgumentError) << "Argument 2 of 'super()' must be 'self', but got '"
1379                                   << py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) << "'.";
1380     }
1381   } else {
1382     MS_EXCEPTION(ArgumentError) << "Arguments number of 'super()' should be 0 or 2, but got " << args.size() << ".";
1383   }
1384   py::object target_class_instance = ast_->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast_->obj());
1385   py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
1386   NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
1387   SymbolPtr symbol = std::make_shared<Symbol>("namespace");
1388   MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
1389   return block->MakeResolve(name_space, symbol);
1390 }
1391 
HandleStrInError(const FunctionBlockPtr & block,const py::list & args,std::vector<AnfNodePtr> * str_nodes)1392 void Parser::HandleStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes) {
1393   for (size_t i = 0; i < args.size(); ++i) {
1394     AnfNodePtr node = ParseExprNode(block, args[i]);
1395     (void)str_nodes->emplace_back(node);
1396   }
1397 }
1398 
HandleException(const FunctionBlockPtr & block,const py::list & args,const std::string & name)1399 std::vector<AnfNodePtr> Parser::HandleException(const FunctionBlockPtr &block, const py::list &args,
1400                                                 const std::string &name) {
1401   auto exception_type_node = NewValueNode(name);
1402   std::vector<AnfNodePtr> node_inputs = {exception_type_node};
1403   HandleStrInError(block, args, &node_inputs);
1404   return node_inputs;
1405 }
1406 
ParseRaiseCall(const FunctionBlockPtr & block,const py::object & node)1407 std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, const py::object &node) {
1408   MS_LOG(DEBUG) << "Process ast Call, the current node is raise.";
1409   // Process function call
1410   py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
1411   // Process raise ValueError
1412   if (py::isinstance<py::none>(function_ast_node)) {
1413     auto name = python_adapter::GetPyObjAttr(node, "id");
1414     auto name_id = py::cast<std::string>(name);
1415     if (block->CheckHasVariable(name_id)) {
1416       auto error_node = block->ReadVariable(name_id);
1417       block->CheckUndefinedSymbol(name_id, error_node);
1418       return {NewValueNode(name_id), error_node};
1419     } else if (exception_types_map.find(name_id) != exception_types_map.end()) {
1420       auto str_value = std::make_shared<StringImm>("None");
1421       return {NewValueNode(name_id), NewValueNode(str_value)};
1422     } else {
1423       MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id
1424                         << ". Raise only support some Python standard exception types: "
1425                         << SupportedExceptionsToString();
1426     }
1427   }
1428 
1429   py::list args = python_adapter::GetPyObjAttr(node, "args");
1430 
1431   auto arg_type =
1432     AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
1433   if (arg_type == AST_SUB_TYPE_NAME) {
1434     auto name = python_adapter::GetPyObjAttr(function_ast_node, "id");
1435     auto name_id = py::cast<std::string>(name);
1436     MS_LOG(DEBUG) << "The name of call node is: " << name_id;
1437     auto node_list = HandleException(block, args, name_id);
1438     if (block->CheckHasVariable(name_id)) {
1439       auto error_node = block->ReadVariable(name_id);
1440       block->CheckUndefinedSymbol(name_id, error_node);
1441       (void)node_list.emplace_back(error_node);
1442       return node_list;
1443     } else if (exception_types_map.find(name_id) != exception_types_map.end()) {
1444       auto str_value = std::make_shared<StringImm>("None");
1445       (void)node_list.emplace_back(NewValueNode(str_value));
1446       return node_list;
1447     } else {
1448       MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id
1449                         << ". Raise only support some Python standard exception types: "
1450                         << SupportedExceptionsToString();
1451     }
1452   }
1453   return {};
1454 }
1455 
CompareIs(const FunctionBlockPtr &,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1456 bool Parser::CompareIs(const FunctionBlockPtr &, const py::object &left_obj, const py::object &comparator_obj,
1457                        bool *bool_res) const {
1458   auto comparator_type_name = ast_->GetNodeType(comparator_obj)->node_name();
1459   // The type_name is "Constant" in py3.9.
1460   if (comparator_type_name != "NameConstant" && comparator_type_name != "Constant") {
1461     return false;
1462   }
1463   // xxx is None, the comparator must be a NameConstant.
1464   py::object name_constant_value = python_adapter::GetPyObjAttr(comparator_obj, "value");
1465   MS_LOG(DEBUG) << "name_constant_value: " << py::str(name_constant_value);
1466 
1467   // Compare with None.
1468   if (py::isinstance<py::none>(name_constant_value)) {
1469     *bool_res = py::isinstance<py::none>(left_obj);
1470     return true;
1471   }
1472   // To add more NameConstants.
1473   return false;
1474 }
1475 
CompareIsNot(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1476 bool Parser::CompareIsNot(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
1477                           bool *bool_res) const {
1478   if (!CompareIs(block, left_obj, comparator_obj, bool_res)) {
1479     return false;
1480   }
1481   *bool_res = !(*bool_res);
1482   return true;
1483 }
1484 
CompareEqual(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1485 bool Parser::CompareEqual(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
1486                           bool *bool_res) const {
1487   auto left_obj_type_name = ast_->GetNodeType(left_obj)->node_name();
1488   if (left_obj_type_name == "Tensor" || left_obj_type_name == "Parameter") {
1489     return false;
1490   }
1491   auto comparator_type_name = ast_->GetNodeType(comparator_obj)->node_name();
1492   MS_LOG(DEBUG) << "comparator_type_name: " << comparator_type_name;
1493   if (comparator_type_name == "Num") {
1494     py::object num_value = python_adapter::GetPyObjAttr(comparator_obj, "n");
1495     MS_LOG(DEBUG) << "num_value: " << py::str(num_value);
1496     if (!py::isinstance<py::int_>(num_value) && !py::isinstance<py::float_>(num_value)) {
1497       return false;
1498     }
1499     *bool_res = left_obj.equal(num_value);
1500     return true;
1501   }
1502   if (comparator_type_name == "Str") {
1503     if (!py::isinstance<py::str>(left_obj)) {
1504       *bool_res = false;
1505       return true;
1506     }
1507     py::object str_value = python_adapter::GetPyObjAttr(comparator_obj, "s");
1508     auto left_obj_str = left_obj.cast<std::string>();
1509     *bool_res = (left_obj_str == str_value.cast<std::string>());
1510     return true;
1511   }
1512   if (comparator_type_name == "NameConstant") {
1513     py::object name_constant_value = python_adapter::GetPyObjAttr(comparator_obj, "value");
1514     MS_LOG(DEBUG) << "name_constant_value: " << py::str(name_constant_value);
1515     if (!py::isinstance<py::none>(name_constant_value)) {
1516       return false;
1517     }
1518     *bool_res = py::isinstance<py::none>(left_obj);
1519     return true;
1520   }
1521   if (comparator_type_name == "Attribute") {
1522     bool is_constant;
1523     auto attr_cond = GetPyObjForAstAttr(block, comparator_obj, &is_constant);
1524     if (!is_constant) {
1525       return false;
1526     }
1527     *bool_res = left_obj.equal(attr_cond);
1528     return true;
1529   }
1530   // The type_name is "Constant" in py3.9.
1531   if (comparator_type_name == "Constant") {
1532     py::object constant_value = python_adapter::GetPyObjAttr(comparator_obj, "value");
1533     MS_LOG(DEBUG) << "constant_value: " << py::str(constant_value);
1534     if (!py::isinstance<py::int_>(constant_value) && !py::isinstance<py::float_>(constant_value) &&
1535         !py::isinstance<py::str>(constant_value)) {
1536       return false;
1537     }
1538     *bool_res = left_obj.equal(constant_value);
1539     return true;
1540   }
1541   return false;
1542 }
1543 
CompareNotEqual(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1544 bool Parser::CompareNotEqual(const FunctionBlockPtr &block, const py::object &left_obj,
1545                              const py::object &comparator_obj, bool *bool_res) const {
1546   if (!CompareEqual(block, left_obj, comparator_obj, bool_res)) {
1547     return false;
1548   }
1549   *bool_res = !(*bool_res);
1550   return true;
1551 }
1552 
CompareGreater(const FunctionBlockPtr &,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1553 bool Parser::CompareGreater(const FunctionBlockPtr &, const py::object &left_obj, const py::object &comparator_obj,
1554                             bool *bool_res) const {
1555   auto comparator_type_name = ast_->GetNodeType(comparator_obj)->node_name();
1556   if (comparator_type_name != "Num" || (!py::isinstance<py::int_>(left_obj) && !py::isinstance<py::float_>(left_obj))) {
1557     return false;
1558   }
1559   py::object num_value = python_adapter::GetPyObjAttr(comparator_obj, "n");
1560   MS_LOG(DEBUG) << "num_value: " << py::str(num_value);
1561 
1562   if (!py::isinstance<py::int_>(num_value) && !py::isinstance<py::float_>(num_value)) {
1563     return false;
1564   }
1565   *bool_res = (left_obj > num_value);
1566   return true;
1567 }
1568 
CompareGreaterEqual(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1569 bool Parser::CompareGreaterEqual(const FunctionBlockPtr &block, const py::object &left_obj,
1570                                  const py::object &comparator_obj, bool *bool_res) const {
1571   bool greater = false;
1572   bool equal = false;
1573   if (!CompareGreater(block, left_obj, comparator_obj, &greater) ||
1574       !CompareEqual(block, left_obj, comparator_obj, &equal)) {
1575     return false;
1576   }
1577   if (greater || equal) {
1578     *bool_res = true;
1579   } else {
1580     *bool_res = false;
1581   }
1582   return true;
1583 }
1584 
CompareLess(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1585 bool Parser::CompareLess(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
1586                          bool *bool_res) const {
1587   bool greater = false;
1588   bool equal = false;
1589   if (!CompareGreater(block, left_obj, comparator_obj, &greater) ||
1590       !CompareEqual(block, left_obj, comparator_obj, &equal)) {
1591     return false;
1592   }
1593   if (greater || equal) {
1594     *bool_res = false;
1595   } else {
1596     *bool_res = true;
1597   }
1598   return true;
1599 }
1600 
CompareLessEqual(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1601 bool Parser::CompareLessEqual(const FunctionBlockPtr &block, const py::object &left_obj,
1602                               const py::object &comparator_obj, bool *bool_res) const {
1603   bool greater = false;
1604   if (!CompareGreater(block, left_obj, comparator_obj, &greater)) {
1605     return false;
1606   }
1607   if (greater) {
1608     *bool_res = false;
1609   } else {
1610     *bool_res = true;
1611   }
1612   return true;
1613 }
1614 
GetParameterValue(const AnfNodePtr & parameter) const1615 ValuePtr Parser::GetParameterValue(const AnfNodePtr &parameter) const {
1616   if (args_value_list_.empty()) {
1617     return nullptr;
1618   }
1619   const auto &parameters = func_graph_->parameters();
1620   for (size_t i = 0; i < parameters.size(); ++i) {
1621     if (parameters.at(i) == parameter && i < args_value_list_.size()) {
1622       return args_value_list_[i];
1623     }
1624   }
1625   return nullptr;
1626 }
1627 
// Try to fold an ast.Compare node to a bool at parse time.
// Only compares with exactly one operator and one comparator are handled, and the operator's name (item 2 of the
// namespace tuple returned for it) must have an entry in compare_method_map_. The left operand must be either:
//   - an attribute that GetPyObjForAstAttr reports as constant, or
//   - a name read from `block` that is a ValueNode, or a Parameter whose recorded argument value
//     (GetParameterValue) exists and contains no ValueAny.
// On success the mapped Compare* method writes *bool_res and its return value is returned; otherwise false.
bool Parser::GetBoolObjForAstCompare(const FunctionBlockPtr &block, const py::object &node, bool *bool_res) const {
  MS_EXCEPTION_IF_NULL(bool_res);
  MS_EXCEPTION_IF_NULL(block);
  // Chained comparisons (more than one op) are not folded.
  py::list ops = python_adapter::GetPyObjAttr(node, "ops");
  if (ops.size() != 1) {
    return false;
  }
  py::object op = ops[0];
  // The op's namespace info must be a 3-tuple; item 2 is used as the key into compare_method_map_.
  py::tuple namespace_var = ast()->CallParseModFunction(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
  constexpr size_t namespace_size = 3;
  if (namespace_var.size() != namespace_size) {
    MS_LOG(INTERNAL_EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
  }
  constexpr size_t op_str_index = 2;
  std::string op_str = py::str(namespace_var[op_str_index]);
  MS_LOG(DEBUG) << "op: " << py::str(op) << ", " << op_str;
  auto func_iter = compare_method_map_.find(op_str);
  if (func_iter == compare_method_map_.end()) {
    return false;
  }

  // Evaluate the left operand to a Python object, bailing out whenever it is not known at parse time.
  py::object left = python_adapter::GetPyObjAttr(node, "left");
  py::object left_obj;
  auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, left)));
  if (arg_type == AST_SUB_TYPE_ATTRIBUTE) {
    bool is_constant;
    left_obj = GetPyObjForAstAttr(block, left, &is_constant);
    if (!is_constant) {
      return false;
    }
  } else {
    MS_LOG(DEBUG) << "Not attribute, attr_ast_node: " << py::str(left);
    // Left nodes without a string 'id' (i.e. not a plain name) are not folded.
    py::object id = python_adapter::GetPyObjAttr(left, "id");
    if (!py::isinstance<py::str>(id)) {
      return false;
    }

    auto anf_node = block->ReadVariable(id.cast<std::string>());
    if (anf_node == nullptr) {
      return false;
    }
    if (anf_node->isa<ValueNode>()) {
      MS_LOG(DEBUG) << "left value node: " << anf_node->DebugString();
      left_obj = ValueToPyData(anf_node->cast_ptr<ValueNode>()->value());
    } else if (anf_node->isa<Parameter>()) {
      MS_LOG(DEBUG) << "left parameter node: " << anf_node->DebugString();
      // A parameter is only usable when its argument value is recorded and fully known (no ValueAny).
      auto value = GetParameterValue(anf_node);
      if (value == nullptr || value->ContainsValueAny()) {
        return false;
      }
      left_obj = ValueToPyData(value);
    } else {
      return false;
    }
  }
  MS_LOG(DEBUG) << "left_obj: " << py::str(left_obj);

  py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
  if (comparators.size() != 1) {
    return false;
  }
  // Dispatch to the Compare* member function registered for this operator.
  return (this->*(func_iter->second))(block, left_obj, comparators[0], bool_res);
}
1691 
// Try to evaluate an ast.Attribute node of the form `<name>.<attr>` to a Python object at parse time.
// On success sets *is_constant to true and returns the attribute value (fetched via PYTHON_MOD_GET_ATTR_FROM_OBJ);
// otherwise sets *is_constant to false and returns None. Handled forms:
//   - self.<attr>: looked up on ast_->obj(), unless `attr` is recorded for "self" in setattr_nodes_map_.
//   - <name>.<attr>: <name> must not read as a Parameter or a MixedPrecisionCast node in `block`, and must resolve
//     through PYTHON_PARSE_GET_NAMESPACE_SYMBOL to a 4-element tuple whose item 2 is the owning object.
py::object Parser::GetPyObjForAstAttr(const FunctionBlockPtr &block, const py::object &attr_ast_node,
                                      bool *is_constant) const {
  MS_EXCEPTION_IF_NULL(is_constant);
  auto attr_value = python_adapter::GetPyObjAttr(attr_ast_node, "value");
  auto attr_value_type =
    AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, attr_value)));
  // Only `<Name>.<attr>` is evaluated; deeper chains or call results are not.
  if (attr_value_type != AST_SUB_TYPE_NAME) {
    MS_LOG(DEBUG) << "attr_value: " << py::str(attr_value);
    *is_constant = false;
    return py::none();
  }

  auto value_name = py::cast<std::string>(python_adapter::GetPyObjAttr(attr_value, "id"));
  auto attr_name = py::cast<std::string>(python_adapter::GetPyObjAttr(attr_ast_node, "attr"));
  MS_LOG(DEBUG) << "attr name: " << value_name << "." << attr_name;
  py::object py_obj_attr_value = py::none();
  if (value_name != "self") {
    MS_EXCEPTION_IF_NULL(block);
    // Attributes of graph parameters and of mixed-precision casts are not treated as constant.
    auto node = block->ReadVariable(value_name);
    if (node != nullptr && (node->isa<Parameter>() || IsPrimitiveCNode(node, prim::kPrimMixedPrecisionCast))) {
      *is_constant = false;
      return py::none();
    }
    py::tuple attr_namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value_name);
    constexpr size_t global_info_size = 4;
    // Handle nested function def.
    if (attr_namespace_info.size() == global_info_size) {
      constexpr size_t value_index = 2;
      py_obj_attr_value = attr_namespace_info[value_index];
    }
  } else {
    // A self attribute recorded in setattr_nodes_map_ (presumably assigned in the function) is not constant.
    auto iter = setattr_nodes_map_.find(value_name);
    if (iter != setattr_nodes_map_.end()) {
      if (iter->second.find(attr_name) != iter->second.end()) {
        MS_LOG(DEBUG) << "The self." << attr_name << " has been modified.";
        *is_constant = false;
        return py::none();
      }
    }
    py_obj_attr_value = ast_->obj();
  }
  if (py::isinstance<py::none>(py_obj_attr_value) || !py::hasattr(py_obj_attr_value, py::str(attr_name))) {
    MS_LOG(DEBUG) << "Not found object for attribute, attr_ast_node: " << py::str(attr_ast_node);
    *is_constant = false;
    return py::none();
  }
  *is_constant = true;
  return python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_ATTR_FROM_OBJ, py_obj_attr_value,
                                     py::str(attr_name));
}
1740 
// Process function call, eg : f1(x, y) ...
// Parses an ast.Call node into a call node in `block`'s graph.
//   - `super(...)` is delegated to ParseSuper.
//   - If the parsed callee carries kClassTensorObject user data and the call location has no comments, a Tensor
//     construction is redirected to `__ms_tensor_func__`, and an adapter-tensor construction to the function
//     returned by `get_adapter_convert_function` (when it is not None).
//   - When the callee name resolves to a 4-element namespace tuple, `print` with non-internal interpret inputs is
//     converted to a PyExecute node, and a call whose syntax flag is not SYNTAX_SUPPORTED is made interpretive.
//   - Adapter-tensor calls that were not redirected are tagged with fallback::kAdapterTensor user data.
AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast Call";
  // Process function call
  py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
  py::list args = python_adapter::GetPyObjAttr(node, "args");

  // name_id stays empty when the callee is not a plain name (e.g. an attribute call).
  std::string name_id = "";
  auto arg_type =
    AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
  if (arg_type == AST_SUB_TYPE_NAME) {
    name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
    MS_LOG(DEBUG) << "The name of call node is: " << name_id;
    if (name_id == "super") {
      return ParseSuper(block, args);
    }
  }
  MS_LOG(DEBUG) << "Process ast Call, name_id: " << name_id;
  auto call_function_node = ParseExprNode(block, function_ast_node);

  // Function call arguments should be passed in as groups and unpacked later using unpack call
  ArgsContext args_context = ArgsContext();
  ParseArgsInCall(block, args, &args_context);
  ParseKeywordsInCall(block, node, &args_context);

  // If the expression is to create Tensor(including adapter tensor) without jit annotation,
  // using functional api to create corresponding tensor since the functional api has jit annotation.
  auto class_tensor_object = call_function_node->user_data<py::object>(kClassTensorObject);
  ClassInstanceType class_tensor_type = CLASS_INSTANCE_TYPE_INVALID;
  if (class_tensor_object != nullptr) {
    auto call_location = GetLocation(node);
    MS_EXCEPTION_IF_NULL(call_location);
    // NOTE(review): comments attached to the call location appear to opt out of the redirection — confirm.
    const auto &comments = call_location->comments();
    if (comments.empty()) {
      class_tensor_type = ClassInstanceType(
        ast_->CallParserObjMethod(PYTHON_PARSE_GET_CLASS_TENSOR_TYPE, *class_tensor_object).cast<int32_t>());
      AnfNodePtr new_call_function_node = nullptr;
      if (class_tensor_type == CLASS_INSTANCE_TYPE_TENSOR) {
        constexpr auto tensor_func_str = "__ms_tensor_func__";
        new_call_function_node = block->MakeResolveSymbol(tensor_func_str);
      } else if (class_tensor_type == CLASS_INSTANCE_TYPE_ADAPTER_TENSOR) {
        constexpr auto adapter_convert_function = "get_adapter_convert_function";
        py::object generate_func = ast_->CallParserObjMethod(adapter_convert_function, *class_tensor_object);
        if (!py::isinstance<py::none>(generate_func)) {
          new_call_function_node = NewValueNode(ParsePythonCode(generate_func));
        }
      }
      if (new_call_function_node != nullptr) {
        MS_LOG(INFO) << "Convert Tensor call node " << call_function_node->DebugString()
                     << " to functional tensor call node " << new_call_function_node->DebugString();
        return GenerateAnfNodeForCall(block, new_call_function_node, args_context);
      }
    }
  }

  auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, args_context);
  MS_EXCEPTION_IF_NULL(call_cnode);
  MS_LOG(DEBUG) << "call_cnode: " << call_cnode->DebugString()
                << ", call_function_node: " << call_function_node->DebugString();

  // Process bulitin function, for example, sum(np.array(xx))
  // Item 3 of a 4-element namespace tuple is the syntax-support flag of the callee.
  py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, name_id);
  constexpr size_t global_info_size = 4;
  if (namespace_info.size() == global_info_size) {
    constexpr size_t flag_index = 3;
    auto syntax_support = namespace_info[flag_index].cast<int32_t>();
    // For print, the inputs to the function determine whether the call_cnode is
    // a graph node or the interpret node. If the inputs contain interpret node (not Tensor), the call_cnode will
    // be interpretive executived. Otherwise, call_cnode will be a graph node.
    if (name_id == "print" && args_context.has_interpret_without_internal) {
      call_cnode->set_interpret(true);
      // Ensure the order of print
      call_cnode = fallback::ConvertCNodeToPyExecuteForPrim(call_cnode->cast<CNodePtr>(), name_id);
      return call_cnode;
    } else if (syntax_support != SYNTAX_SUPPORTED) {
      call_cnode->set_interpret(true);
      call_cnode = HandleInterpret(block, call_cnode, node);
      // For the unsupported type function, if the input to the function contains tensor, the return value of
      // the function should be graph node too.
      if (args_context.has_interpret_internal) {
        call_cnode->set_interpret_internal_type(true);
      }
    }
  }
  if (class_tensor_type == CLASS_INSTANCE_TYPE_ADAPTER_TENSOR) {
    MS_LOG(DEBUG) << "Current adapter tensor node: " << call_cnode->DebugString();
    call_cnode->set_user_data<bool>(fallback::kAdapterTensor, std::make_shared<bool>(true));
  }
  return call_cnode;
}
1831 
MakeUnpackCall(const FuncGraphPtr & func_graph,const AnfNodePtr & call_function_node,const std::vector<AnfNodePtr> & packed_arguments)1832 CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_node,
1833                         const std::vector<AnfNodePtr> &packed_arguments) {
1834   MS_EXCEPTION_IF_NULL(func_graph);
1835   std::vector<AnfNodePtr> unpack_call_nodes;
1836   auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
1837   unpack_call_nodes.push_back(unpack_call_op);
1838   unpack_call_nodes.push_back(call_function_node);
1839   (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
1840                        [](AnfNodePtr node) -> AnfNodePtr { return node; });
1841   CNodePtr unpack_call = func_graph->NewCNodeInOrder(std::move(unpack_call_nodes));
1842   return unpack_call;
1843 }
1844 
// Build the call node for `call_function_node` from the parsed arguments in `args_context`.
// With starred/keyword arguments (need_unpack) an UnpackCall node is built over the packed groups; otherwise a
// plain call CNode {callee, args...} is created. An argument-less call on a PyInterpret node returns the
// callee itself, since calling an Interpret node is invalid.
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node,
                                          const ArgsContext &args_context) const {
  // If there is keyword arguments or starred, using an unpack_call op to unpack the argument
  MS_EXCEPTION_IF_NULL(block);
  if (args_context.need_unpack) {
    return MakeUnpackCall(block->func_graph(), call_function_node, args_context.packed_arguments);
  }
  // else there is no keyword arguments and starred, parsed as normal arguments without unpack
  const auto &group_arguments = args_context.group_arguments;
  if (group_arguments.empty() && IsPrimitiveCNode(call_function_node, prim::kPrimPyInterpret)) {
    // call Interpret node is invalid. Do not new call Interpret node.
    // %1 = Interpret_node
    // %2 = %1()
    return call_function_node;
  }
  std::vector<AnfNodePtr> func_call_nodes;
  func_call_nodes.reserve(group_arguments.size() + 1);
  func_call_nodes.push_back(call_function_node);
  // Range insert copies the argument nodes directly (no per-element by-value lambda round trip).
  (void)func_call_nodes.insert(func_call_nodes.end(), group_arguments.begin(), group_arguments.end());
  MS_EXCEPTION_IF_NULL(block->func_graph());
  return block->func_graph()->NewCNodeInOrder(std::move(func_call_nodes));
}
1868 
ParseArgsInCall(const FunctionBlockPtr & block,const py::list & args,ArgsContext * args_context)1869 void Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, ArgsContext *args_context) {
1870   MS_LOG(DEBUG) << "Process ast args in call";
1871   MS_EXCEPTION_IF_NULL(args_context);
1872   for (size_t i = 0; i < args.size(); i++) {
1873     auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[i])));
1874     if (arg_node == AST_SUB_TYPE_STARRED) {
1875       if (!args_context->group_arguments.empty()) {
1876         args_context->packed_arguments.push_back(GenerateMakeTuple(block, args_context->group_arguments));
1877       }
1878       args_context->packed_arguments.push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value")));
1879       args_context->group_arguments.clear();
1880       args_context->need_unpack = true;
1881     } else {
1882       MS_LOG(DEBUG) << "args[" << i << "]: " << py::str(args[i]);
1883       AnfNodePtr node = ParseExprNode(block, args[i]);
1884       auto internal = node->interpret_internal_type();
1885       auto interpret_without_internal =
1886         ((node->interpret() || IsPrimitiveCNode(node, prim::kPrimPyInterpret)) && !internal);
1887       if (internal) {
1888         args_context->has_interpret_internal = true;
1889       } else if (interpret_without_internal) {
1890         args_context->has_interpret_without_internal = true;
1891       }
1892       args_context->group_arguments.push_back(node);
1893     }
1894   }
1895   if (!args_context->group_arguments.empty()) {
1896     args_context->packed_arguments.push_back(GenerateMakeTuple(block, args_context->group_arguments));
1897   }
1898 }
1899 
ParseKeywordsInCall(const FunctionBlockPtr & block,const py::object & node,ArgsContext * args_context)1900 void Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, ArgsContext *args_context) {
1901   MS_LOG(DEBUG) << "Process ast key words in call";
1902   py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
1903   if (!keywords.empty()) {
1904     MS_EXCEPTION_IF_NULL(block);
1905     args_context->need_unpack = true;
1906     std::vector<AnfNodePtr> keys;
1907     std::vector<AnfNodePtr> values;
1908     for (size_t index = 0; index < keywords.size(); index++) {
1909       auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
1910       auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
1911       if (py::isinstance<py::none>(kw_key)) {
1912         args_context->packed_arguments.push_back(ParseExprNode(block, kw_value));
1913       } else {
1914         auto kw_key_c = kw_key.cast<std::string>();
1915         keys.push_back(NewValueNode(kw_key_c));
1916         auto ret_node = ParseExprNode(block, kw_value);
1917         values.push_back(ret_node);
1918       }
1919     }
1920     if (!keys.empty()) {
1921       auto keys_tuple = GenerateMakeTuple(block, keys);
1922       auto values_tuple = GenerateMakeTuple(block, values);
1923       auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
1924       std::vector<AnfNodePtr> make_dict_nodes = {make_dict_op, keys_tuple, values_tuple};
1925       MS_EXCEPTION_IF_NULL(block->func_graph());
1926       args_context->packed_arguments.push_back(block->func_graph()->NewCNodeInOrder(std::move(make_dict_nodes)));
1927     }
1928   }
1929 }
1930 
ProcessAttributeWithClassMember(const FunctionBlockPtr & block,const py::object & node) const1931 AnfNodePtr Parser::ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) const {
1932   MS_EXCEPTION_IF_NULL(block);
1933   std::string var_name = "self.";
1934   std::string attr_name = node.attr("attr").cast<std::string>();
1935   (void)var_name.append(attr_name);
1936   MS_LOG(DEBUG) << "var_name: " << var_name;
1937   auto attr_obj = ast()->obj().attr(attr_name.c_str());
1938   bool check_need_resolve = py::hasattr(ast()->obj(), attr_name.c_str()) &&
1939                             (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
1940                              py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
1941                              py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj));
1942   if (check_need_resolve) {
1943     AnfNodePtr res = block->MakeResolveSymbol(var_name);
1944     block->CheckUndefinedSymbol(var_name, res);
1945     return res;
1946   }
1947   auto var_node = block->ReadVariable(var_name);
1948   block->CheckUndefinedSymbol(var_name, var_node);
1949   // Process numpy array, eg: self.x = np.array([1, 2])
1950   if (py::hasattr(ast()->obj(), attr_name.c_str()) && data_converter::IsNumpyArrayInstance(attr_obj)) {
1951     var_node->set_interpret(true);
1952   }
1953   return var_node;
1954 }
1955 
ParseMsTensor(const FunctionBlockPtr & block,const py::object & node,const py::object & value_body,const AnfNodePtr & value_node)1956 AnfNodePtr Parser::ParseMsTensor(const FunctionBlockPtr &block, const py::object &node, const py::object &value_body,
1957                                  const AnfNodePtr &value_node) {
1958   if (py::hasattr(value_body, "id")) {
1959     std::string module_name = py::cast<std::string>(python_adapter::GetPyObjAttr(value_body, "id"));
1960     py::dict global_dict = const_cast<py::dict &>(block->global_py_params());
1961     if (global_dict.contains(module_name)) {
1962       py::object module_obj = global_dict[py::str(module_name)];
1963       std::string module_str = py::cast<std::string>(py::str(module_obj));
1964       // The module of Tensor imported from MsAdapter could be:
1965       // module 'msadapter' or module 'msadapter.pytorch' and so on.
1966       if (module_str.find("module 'mindspore'") != std::string::npos ||
1967           module_str.find("module 'mindtorch") != std::string::npos ||
1968           module_str.find("module 'msadapter") != std::string::npos) {
1969         std::string script_text = py::cast<std::string>(ast()->GetAstNodeText(node));
1970         AnfNodePtr interpret_node = MakeInterpretNode(block, value_node, script_text);
1971         interpret_node->set_interpret(true);
1972         interpret_node->set_interpret_internal_type(true);
1973         if ((module_str.find("module 'mindtorch") != std::string::npos ||
1974              module_str.find("module 'msadapter") != std::string::npos) &&
1975             py::hasattr(module_obj, "Tensor")) {
1976           py::object tensor_obj = py::getattr(module_obj, "Tensor");
1977           interpret_node->set_user_data<py::object>(kClassTensorObject, std::make_shared<py::object>(tensor_obj));
1978         }
1979         return interpret_node;
1980       }
1981     }
1982   }
1983   return nullptr;
1984 }
1985 
ParseNull(const FunctionBlockPtr & block,const py::object & value_body) const1986 AnfNodePtr Parser::ParseNull(const FunctionBlockPtr &block, const py::object &value_body) const {
1987   if (py::hasattr(value_body, "id")) {
1988     std::string module_name = py::cast<std::string>(python_adapter::GetPyObjAttr(value_body, "id"));
1989     py::dict global_dict = const_cast<py::dict &>(block->global_py_params());
1990     if (global_dict.contains(module_name)) {
1991       py::object module_obj = global_dict[py::str(module_name)];
1992       std::string module_str = py::cast<std::string>(py::str(module_obj));
1993       if (module_str.find("module 'mindspore.common.dtype'") != std::string::npos) {
1994         return NewValueNode(std::make_shared<TypeNull>());
1995       }
1996     }
1997   }
1998   return nullptr;
1999 }
2000 
GetGetAttrVectotFromMap(const std::string & obj_name,const std::string & attr_name)2001 std::vector<AnfNodePtr> Parser::GetGetAttrVectotFromMap(const std::string &obj_name, const std::string &attr_name) {
2002   std::vector<AnfNodePtr> getattr_nodes;
2003   auto iter = getattr_nodes_map_.find(obj_name);
2004   if (iter != getattr_nodes_map_.end()) {
2005     auto attr_iter = iter->second.find(attr_name);
2006     if (attr_iter != iter->second.end()) {
2007       getattr_nodes = attr_iter->second;
2008     }
2009   }
2010   return getattr_nodes;
2011 }
2012 
GetSetAttrFromMap(const std::string & obj_name,const std::string & attr_name)2013 AnfNodePtr Parser::GetSetAttrFromMap(const std::string &obj_name, const std::string &attr_name) {
2014   auto iter = setattr_nodes_map_.find(obj_name);
2015   if (iter != setattr_nodes_map_.end()) {
2016     auto attr_iter = iter->second.find(attr_name);
2017     if (attr_iter != iter->second.end()) {
2018       return attr_iter->second;
2019     }
2020   }
2021   return nullptr;
2022 }
2023 
MakeGetAttrWithInterpret(const std::string & obj_name,const std::string & attr_name,const py::object & getattr_obj,const FuncGraphPtr & cur_fg)2024 AnfNodePtr Parser::MakeGetAttrWithInterpret(const std::string &obj_name, const std::string &attr_name,
2025                                             const py::object &getattr_obj, const FuncGraphPtr &cur_fg) {
2026   AnfNodePtr setattr_node = GetSetAttrFromMap(obj_name, attr_name);
2027   AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
2028   AnfNodePtr attr_node = NewValueNode(attr_name);
2029   AnfNodePtr ret_node = nullptr;
2030   if (setattr_node != nullptr) {
2031     const auto &interpreted_obj = std::make_shared<InterpretedObject>(getattr_obj);
2032     AnfNodePtr value_node = NewValueNode(interpreted_obj);
2033     auto prev_setattr_fg = setattr_node->func_graph();
2034     MS_EXCEPTION_IF_NULL(prev_setattr_fg);
2035     if (prev_setattr_fg != cur_fg) {
2036       ret_node = cur_fg->NewCNodeInOrder({op_node, value_node, attr_node});
2037     } else {
2038       // Only add to new setattr node input if two nodes is in the same graph.
2039       ret_node = cur_fg->NewCNodeInOrder({op_node, value_node, attr_node, setattr_node});
2040     }
2041     ret_node->set_user_data<bool>(fallback::kObjectAttrChange, std::make_shared<bool>(true));
2042   }
2043   return ret_node;
2044 }
2045 
// Turn an access to the @property `attr_str` of property_net_obj into a call of the property's getter.
// The getter (`type(property_net_obj).<attr_str>.fget`) is parsed into a func graph, and a CNode calling it with
// `node` as the single argument is created in fg.
AnfNodePtr TransPropertyToFunc(const FuncGraphPtr &fg, const AnfNodePtr &node, const py::object &property_net_obj,
                               std::string attr_str) {
  MS_EXCEPTION_IF_NULL(fg);
  py::object property_func = py::none();
  try {
    property_func = property_net_obj.attr("__class__").attr(py::str(attr_str));
  } catch (const std::exception &e) {
    // Continuing with None would only fail later on `.fget` with an unrelated error, so stop here with the cause.
    MS_LOG(INTERNAL_EXCEPTION) << property_net_obj << " has no attribute " << attr_str << ". Error: " << e.what();
  }
  py::object property_func_fget = property_func.attr(py::str("fget"));
  auto inner_fg = ParsePythonCode(property_func_fget);
  MS_EXCEPTION_IF_NULL(inner_fg);
  std::vector<AnfNodePtr> new_inputs = {NewValueNode(inner_fg), node};
  AnfNodePtr call_func_node = fg->NewCNodeInOrder(new_inputs);
  MS_LOG(DEBUG) << "call_func_node:" << call_func_node->DebugString();
  return call_func_node;
}
2062 
2063 // Process call attributes of class type define, eg: x.y()
ParseAttribute(const FunctionBlockPtr & block,const py::object & node)2064 AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
2065   MS_LOG(DEBUG) << "Process ast Attribute";
2066   auto cur_fg = block->func_graph();
2067   MS_EXCEPTION_IF_NULL(cur_fg);
2068 
2069   // Process the get attr
2070   // Use the Primitive replace the operation resolve node (getattr),
2071   // because the getattr will eventually be converted to Primitive node
2072   AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
2073 
2074   // Process the node attr
2075   auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
2076   AnfNodePtr attr_node = NewValueNode(attr_str);
2077 
2078   // Process the attr body
2079   py::object value_body = python_adapter::GetPyObjAttr(node, "value");
2080   MS_LOG(DEBUG) << "node: " << node << ", attr: " << attr_str << ", value: " << value_body;
2081 
2082   // if getting class value 'self', eg: self.xx, use self object.
2083   std::string obj_name;
2084   py::object getattr_obj;
2085   const bool &is_self = ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast()->IsClassMemberOfSelf(node);
2086   if (is_self) {
2087     // Check if the current Attribute is decorated by @property.
2088     py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
2089     bool is_property =
2090       (python_adapter::CallPyModFn(mod, PYTHON_PARSE_CHECK_ATTR_IS_PROPERTY, ast()->obj(), attr_str)).cast<bool>();
2091     if (is_property) {
2092       py::object property_net_obj = ast()->obj();
2093       AnfNodePtr value_node = ParseExprNode(block, value_body);
2094       return TransPropertyToFunc(cur_fg, value_node, property_net_obj, attr_str);
2095     }
2096     obj_name = "self";
2097     getattr_obj = ast()->obj();
2098     AnfNodePtr ret_node;
2099     AnfNodePtr getattr_node = MakeGetAttrWithInterpret(obj_name, attr_str, getattr_obj, cur_fg);
2100     // If setattr before, should make the getattr call into PyExecute also.
2101     if (getattr_node != nullptr) {
2102       ret_node = getattr_node;
2103       // if processing class value 'self', but did not find setattr before getattr, convert getattr later
2104     } else {
2105       ret_node = ProcessAttributeWithClassMember(block, node);
2106       (void)getattr_nodes_map_["self"][attr_str].emplace_back(ret_node);
2107     }
2108     ret_node->set_user_data<py::object>("__getattr__", std::make_shared<py::object>(getattr_obj));
2109     return ret_node;
2110   }
2111   // If not self.xx, process the obj, eg: obj.xx
2112   AnfNodePtr value_node = ParseExprNode(block, value_body);
2113   if (value_node == nullptr) {
2114     MS_LOG(INTERNAL_EXCEPTION) << "Parse attribute failed";
2115   }
2116   // Process xxx.Tensor() and xxx is mindspore.
2117   if (attr_str == "Tensor") {
2118     auto res = ParseMsTensor(block, node, value_body, value_node);
2119     if (res != nullptr) {
2120       return res;
2121     }
2122   }
2123   // For stype._null, return TypeNull value node directly.
2124   if (attr_str == "_null") {
2125     auto res = ParseNull(block, value_body);
2126     if (res != nullptr) {
2127       return res;
2128     }
2129   }
2130   // Create the apply node
2131   AnfNodePtr attr_cnode = cur_fg->NewCNodeInOrder({op_node, value_node, attr_node});
2132 
2133   auto value_id_str = GetLocation(value_body)->expr_src();
2134   auto iter = setattr_nodes_map_.find(value_id_str);
2135   if (iter != setattr_nodes_map_.end() && iter->second.find(attr_str) != iter->second.end()) {
2136     attr_cnode->set_user_data<bool>(fallback::kObjectAttrChange, std::make_shared<bool>(true));
2137   }
2138 
2139   // Directly resolve the symbol.
2140   if (IsValueNode<parse::NameSpace>(value_node)) {
2141     auto name_space = GetValueNode<parse::NameSpacePtr>(value_node);
2142     MS_EXCEPTION_IF_NULL(name_space);
2143     auto symbol = std::make_shared<parse::Symbol>(attr_str);
2144     attr_cnode = block->DoResolve(attr_cnode, name_space, symbol);
2145   }
2146 
2147   if (attr_str == "pop") {
2148     list_pop_target_obj_ = value_body;
2149   }
2150   if (py::hasattr(value_body, "id")) {
2151     // Check the value is side effect operate from third-party module. eg: np.load(xx) or ts.save(xxx)
2152     auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(value_body, "id"));
2153     MS_LOG(DEBUG) << "The Name id is " << name_id;
2154     bool is_third_party_side_effect =
2155       ast_->CallParserObjMethod(PYTHON_PARSE_CHECK_THIRD_PARTY_LIBRARY_SIDE_EFFECT, name_id, attr_str).cast<bool>();
2156     if (is_third_party_side_effect) {
2157       auto pyexecute_node = fallback::ConvertCNodeToPyExecuteForPrim(attr_cnode->cast<CNodePtr>(), "getattr");
2158       MS_LOG(DEBUG) << "pyexecute_node:" << pyexecute_node->DebugString();
2159       return pyexecute_node;
2160     }
2161   }
2162   // if getting other object, eg: obj.xx, find object from namespace by name
2163   obj_name = GetLocation(value_body)->expr_src();
2164   try {
2165     py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, obj_name);
2166     constexpr size_t value_index = 2;
2167     getattr_obj = namespace_info[value_index];
2168   } catch (const std::exception &e) {
2169     MS_LOG(DEBUG) << obj_name << " is not supported in JIT Fallback. Original steps are processing instead.";
2170     getattr_obj = py::none();
2171   }
2172   const bool got_obj = !py::isinstance<py::none>(getattr_obj);
2173   if (got_obj) {
2174     AnfNodePtr getattr_node = MakeGetAttrWithInterpret(obj_name, attr_str, getattr_obj, cur_fg);
2175     // If setattr before, should make the getattr call into PyExecute also.
2176     if (getattr_node != nullptr) {
2177       attr_cnode = getattr_node;
2178     } else {
2179       // if getting attr from other obj, but did not find setattr before getattr, convert getattr later
2180       (void)getattr_nodes_map_[GetLocation(value_body)->expr_src()][attr_str].emplace_back(attr_cnode);
2181     }
2182     attr_cnode->set_user_data<py::object>("__getattr__", std::make_shared<py::object>(getattr_obj));
2183   }
2184   return attr_cnode;
2185 }
2186 
2187 // Process comparison expression : a == b. a > b  etc.
ParseCompare(const FunctionBlockPtr & block,const py::object & node)2188 AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) {
2189   MS_LOG(DEBUG) << "Process ast Compare";
2190 
2191   py::list ops = python_adapter::GetPyObjAttr(node, "ops");
2192   py::object left = python_adapter::GetPyObjAttr(node, "left");
2193   py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
2194   if (ops.size() == 0) {
2195     MS_LOG(INTERNAL_EXCEPTION) << "Parse ast Compare failed, found no ops.";
2196   }
2197   if (comparators.size() == 0) {
2198     MS_LOG(INTERNAL_EXCEPTION) << "Parse ast Compare failed, found no comparators.";
2199   }
2200   if (ops.size() != comparators.size()) {
2201     MS_LOG(INTERNAL_EXCEPTION) << "Parse ast Compare failed, length of ops and comparators not equal, len of ops: "
2202                                << ops.size() << " and length of comparators: " << comparators.size();
2203   }
2204 
2205   auto first_left = left;
2206   auto first_right = comparators[0];
2207   auto first_op = ops[0];
2208   auto first_compare_node = ParseSingleCompare(block, first_left, first_right, first_op);
2209   if (ops.size() == 1) {
2210     // For single compare, such as x < y.
2211     return first_compare_node;
2212   }
2213 
2214   // For multiple compare, such as x < y <= z,
2215   // convert it to x < y and y <= z.
2216   std::vector<AnfNodePtr> compare_nodes = {first_compare_node};
2217   for (size_t i = 1; i < ops.size(); ++i) {
2218     auto cur_left = comparators[i - 1];
2219     auto cur_right = comparators[i];
2220     auto cur_op = ops[i];
2221     auto cur_compare_node = ParseSingleCompare(block, cur_left, cur_right, cur_op);
2222     (void)compare_nodes.emplace_back(cur_compare_node);
2223   }
2224 
2225   AnfNodePtr ret_node = compare_nodes[0];
2226   for (size_t i = 1; i < compare_nodes.size(); ++i) {
2227     ret_node = ConnectSingleCompare(block, ret_node, compare_nodes[i]);
2228   }
2229 
2230   return ret_node;
2231 }
2232 
ParseSingleCompare(const FunctionBlockPtr & block,const py::object & left,const py::object & right,const py::object & compare_op)2233 AnfNodePtr Parser::ParseSingleCompare(const FunctionBlockPtr &block, const py::object &left, const py::object &right,
2234                                       const py::object &compare_op) {
2235   MS_LOG(DEBUG) << "Process ast Compare with single comparators";
2236 
2237   AnfNodePtr left_node = ParseExprNode(block, left);
2238   AnfNodePtr right_node = ParseExprNode(block, right);
2239 
2240   MS_EXCEPTION_IF_NULL(block);
2241   const auto &ns = block->GetAstOpNameSpace(compare_op);
2242   auto op_node = block->MakeResolveAstOpNameSpace(ns);
2243 
2244   MS_EXCEPTION_IF_NULL(block->func_graph());
2245   return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
2246 }
2247 
// Join two comparison results with 'and' semantics, used for chained comparisons.
// Builds `switch(cond(left_node), true_branch, false_branch)()` in `block`. The true branch outputs right_node and
// the false branch outputs left_node.
AnfNodePtr Parser::ConnectSingleCompare(const FunctionBlockPtr &block, const AnfNodePtr &left_node,
                                        const AnfNodePtr &right_node) {
  // Connect two compare result with 'and'.
  MS_LOG(DEBUG) << "Connect single compare node.";

  MS_EXCEPTION_IF_NULL(block);
  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  FunctionBlockPtr true_block = nullptr;
  FunctionBlockPtr false_block = nullptr;
  auto block_fg = block->func_graph();
  MS_EXCEPTION_IF_NULL(block_fg);
  {
    TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block_fg->debug_info()));
    true_block = MakeFunctionBlock();
  }
  {
    TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block_fg->debug_info()));
    false_block = MakeFunctionBlock();
  }
  MS_EXCEPTION_IF_NULL(true_block);
  MS_EXCEPTION_IF_NULL(false_block);
  MakeConditionBlocks(block, true_block, false_block);
  MS_EXCEPTION_IF_NULL(true_block->func_graph());
  MS_EXCEPTION_IF_NULL(false_block->func_graph());
  true_block->func_graph()->set_output(right_node);
  TraceGuard trace_guard(std::make_shared<TraceCopy>(left_node->debug_info()));
  false_block->func_graph()->set_output(left_node);

  AnfNodePtr cond_node = block->ForceToCondNode(left_node);

  auto switch_app =
    block_fg->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
                               NewValueNode(false_block->func_graph())});

  // Call the graph selected by the switch.
  std::vector<AnfNodePtr> call_graph_nodes{switch_app};
  auto switch_app_call = block_fg->NewCNodeInOrder(std::move(call_graph_nodes));
  return switch_app_call;
}
2284 
ProcessBoolOpValueList(const FunctionBlockPtr & block,const py::list & value_list,AstSubType mode)2285 AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
2286   // If there is only one bool op now
2287   MS_EXCEPTION_IF_NULL(block);
2288   if (value_list.empty()) {
2289     MS_LOG(INTERNAL_EXCEPTION) << "value list is empty.";
2290   }
2291   if (value_list.size() == 1) {
2292     AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
2293     return first_node;
2294   } else {
2295     py::object first = value_list[0];
2296     py::list rest;
2297     for (size_t i = 1; i < value_list.size(); i++) {
2298       rest.append(value_list[i]);
2299     }
2300     FunctionBlockPtr true_block = nullptr;
2301     FunctionBlockPtr false_block = nullptr;
2302     auto block_fg = block->func_graph();
2303     {
2304       TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block_fg->debug_info()));
2305       true_block = MakeFunctionBlock();
2306     }
2307     {
2308       TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block_fg->debug_info()));
2309       false_block = MakeFunctionBlock();
2310     }
2311     MakeConditionBlocks(block, true_block, false_block);
2312     FunctionBlockPtr b1;
2313     FunctionBlockPtr b2;
2314 
2315     // If it is and, we need to process the rest nodes;
2316     // If it is or, we continue to next
2317     if (mode == AST_SUB_TYPE_AND) {
2318       b1 = true_block;
2319       b2 = false_block;
2320     } else if (mode == AST_SUB_TYPE_OR) {
2321       b2 = true_block;
2322       b1 = false_block;
2323     } else {
2324       MS_LOG(ERROR) << "Not supported mode: " << mode;
2325       return nullptr;
2326     }
2327     AnfNodePtr test_node = ParseExprNode(block, first);
2328     AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
2329     MS_EXCEPTION_IF_NULL(b1->func_graph());
2330     MS_EXCEPTION_IF_NULL(b2->func_graph());
2331     b1->func_graph()->set_output(rest_node);
2332     TraceGuard trace_guard(GetLocation(value_list[1]));
2333     b2->func_graph()->set_output(test_node);
2334 
2335     AnfNodePtr cond_node = block->ForceToCondNode(test_node);
2336     auto switch_app =
2337       block_fg->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
2338                                  NewValueNode(false_block->func_graph())});
2339 
2340     std::vector<AnfNodePtr> call_graph_nodes{switch_app};
2341     auto switch_app_call = block_fg->NewCNodeInOrder(std::move(call_graph_nodes));
2342     return switch_app_call;
2343   }
2344 }
2345 
2346 // Process comparison expression : a and b. a or b .
ParseBoolOp(const FunctionBlockPtr & block,const py::object & node)2347 AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
2348   MS_LOG(DEBUG) << "Process ast BoolOp";
2349   py::object op_node = python_adapter::GetPyObjAttr(node, "op");
2350   AstSubType op_type = ast_->GetOpType(op_node);
2351   if (op_type == AST_SUB_TYPE_UNKNOWN) {
2352     MS_LOG(INTERNAL_EXCEPTION) << "ProcessBoolOp, got unknown op type";
2353   }
2354   py::list op_values = python_adapter::GetPyObjAttr(node, "values");
2355   return ProcessBoolOpValueList(block, op_values, op_type);
2356 }
2357 
2358 // Process a function def
ParseFunctionDef(const FunctionBlockPtr & block,const py::object & node)2359 FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
2360   MS_LOG(DEBUG) << "Process ast FunctionDef";
2361   FunctionBlockPtr function_block = ParseDefFunction(node, block);
2362   MS_EXCEPTION_IF_NULL(function_block);
2363 
2364   // Get function name
2365   py::str name = python_adapter::GetPyObjAttr(node, "name");
2366   std::string function_name = name;
2367   ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph());
2368   block->WriteVariable(function_name, valuenode_graph);
2369   return block;
2370 }
2371 
2372 // Process a lambda expression . like lambda x,y: x + y
ParseLambda(const FunctionBlockPtr & block,const py::object & node)2373 AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
2374   MS_LOG(DEBUG) << "Process ast Lambda";
2375   FunctionBlockPtr function_block = ParseLambdaFunction(node, block);
2376   MS_EXCEPTION_IF_NULL(function_block);
2377 
2378   auto block_fg = function_block->func_graph();
2379   ValueNodePtr const_graph = NewValueNode(block_fg);
2380   return const_graph;
2381 }
2382 
2383 // a = *[1, 2], (3, 4)
2384 // StarredUnpackMerge(assign_node1, assign_node2, starred_flags_node, is_tuple)
2385 // StarredUnpackMerge(StarredUnpack(*[1, 2]), (3, 4), (1, 0), 1)
2386 // --> StarredUnpackMerge((1, 2), (3, 4), (1, 0), 1)
2387 // --> (1, 2, (3, 4))
ParseTupleOrListWithStarred(const FunctionBlockPtr & block,const py::object & node,bool is_tuple,const std::vector<AnfNodePtr> & starred_flags)2388 AnfNodePtr Parser::ParseTupleOrListWithStarred(const FunctionBlockPtr &block, const py::object &node, bool is_tuple,
2389                                                const std::vector<AnfNodePtr> &starred_flags) {
2390   auto prim = std::make_shared<prim::StarredUnpackMerge>(NAMED_METAGRAPH_STARRED_UNPACK_MERGE);
2391   std::vector<AnfNodePtr> unpack_merge_inputs{NewValueNode(prim)};
2392   auto starred_flags_node = block->func_graph()->NewCNodeInOrder(starred_flags);
2393   py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
2394   for (size_t i = 0; i < elts.size(); i++) {
2395     AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
2396     auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elts[i])));
2397     if (elt_type == AST_SUB_TYPE_STARRED) {
2398       auto starred_unpack_prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
2399       CNodePtr unpack_node = block->func_graph()->NewCNodeInOrder({NewValueNode(starred_unpack_prim), node_ptr});
2400       (void)unpack_merge_inputs.emplace_back(unpack_node);
2401     } else {
2402       (void)unpack_merge_inputs.emplace_back(node_ptr);
2403     }
2404   }
2405   (void)unpack_merge_inputs.emplace_back(starred_flags_node);
2406   if (is_tuple) {
2407     auto is_tuple_node = NewValueNode(static_cast<int64_t>(1));
2408     (void)unpack_merge_inputs.emplace_back(is_tuple_node);
2409   } else {
2410     auto is_tuple_node = NewValueNode(static_cast<int64_t>(0));
2411     (void)unpack_merge_inputs.emplace_back(is_tuple_node);
2412   }
2413 
2414   CNodePtr unpack_merge_node = block->func_graph()->NewCNodeInOrder(unpack_merge_inputs);
2415   return unpack_merge_node;
2416 }
2417 
ParseTupleOrList(const FunctionBlockPtr & block,const py::object & node,bool is_tuple)2418 AnfNodePtr Parser::ParseTupleOrList(const FunctionBlockPtr &block, const py::object &node, bool is_tuple) {
2419   MS_EXCEPTION_IF_NULL(block);
2420   py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
2421   if (elts.empty()) {
2422     if (is_tuple) {
2423       auto empty_tuple = std::vector<ValuePtr>();
2424       return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
2425     }
2426     auto empty_list = std::vector<ValuePtr>();
2427     return NewValueNode(std::make_shared<ValueList>(empty_list));
2428   }
2429 
2430   bool exist_starred_expression = false;
2431   std::vector<AnfNodePtr> starred_flags{NewValueNode(prim::kPrimMakeTuple)};
2432   for (size_t i = 0; i < elts.size(); i++) {
2433     auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elts[i])));
2434     if (elt_type == AST_SUB_TYPE_STARRED) {
2435       exist_starred_expression = true;
2436       starred_flags.push_back(NewValueNode(static_cast<int64_t>(1)));
2437     } else {
2438       starred_flags.push_back(NewValueNode(static_cast<int64_t>(0)));
2439     }
2440   }
2441 
2442   if (!exist_starred_expression) {
2443     std::vector<AnfNodePtr> sequence_vec;
2444     AnfNodePtr sequence_op = nullptr;
2445     if (is_tuple) {
2446       sequence_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
2447     } else {
2448       sequence_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
2449     }
2450     (void)sequence_vec.emplace_back(sequence_op);
2451     for (size_t i = 0; i < elts.size(); i++) {
2452       AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
2453       (void)sequence_vec.emplace_back(node_ptr);
2454     }
2455     MS_EXCEPTION_IF_NULL(block->func_graph());
2456     CNodePtr sequence_app = block->func_graph()->NewCNodeInOrder(std::move(sequence_vec));
2457     return sequence_app;
2458   }
2459   return ParseTupleOrListWithStarred(block, node, is_tuple, starred_flags);
2460 }
2461 
2462 // Process a tuple
ParseTuple(const FunctionBlockPtr & block,const py::object & node)2463 AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
2464   MS_LOG(DEBUG) << "Process ast Tuple";
2465   return ParseTupleOrList(block, node, true);
2466 }
2467 
2468 // Process a list
ParseList(const FunctionBlockPtr & block,const py::object & node)2469 AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
2470   MS_LOG(DEBUG) << "Process ast List";
2471   return ParseTupleOrList(block, node, false);
2472 }
2473 
// Look up the Python object bound to a name (handles nested function defs as well).
// `value_node` holds the name as a Python string; the parser's namespace-symbol query is run for it and element 2 of
// the returned tuple is handed back unchanged. Callers treat a None result as "no Python object found".
py::object Parser::GetValuePythonObject(const py::object &value_node) {
  constexpr size_t kResolvedObjectIndex = 2;
  const auto symbol_name = py::cast<std::string>(value_node);
  py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, symbol_name);
  // A None element is returned as-is, which is the same object as py::none().
  return namespace_info[kResolvedObjectIndex];
}
2485 
2486 // Process a subscript, such as x[y] , node expressed as value[slice]
ParseSubscript(const FunctionBlockPtr & block,const py::object & node)2487 AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) {
2488   MS_LOG(DEBUG) << "Process ast Subscript";
2489   MS_EXCEPTION_IF_NULL(block);
2490   AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
2491   py::object value_node = python_adapter::GetPyObjAttr(node, "value");
2492   py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
2493   AnfNodePtr value = ParseExprNode(block, value_node);
2494   AnfNodePtr slice = ParseExprNode(block, slice_node);
2495   MS_EXCEPTION_IF_NULL(block->func_graph());
2496   auto value_id = python_adapter::GetPyObjAttr(value_node, "id");
2497   AnfNodePtr getitem_node;
2498   py::object value_obj = py::none();
2499   auto str_getitem = std::make_shared<StringImm>("__getitem__");
2500   AnfNodePtr new_node;
2501   // value[slice]. The value of subscript must not be built-in functions
2502   // if the id of value object has the same name as built-in function, should not get the value_obj.
2503   bool value_id_is_builtins =
2504     py::cast<bool>(ast_->CallParserObjMethod(PYTHON_PARSE_IS_BUILTIN_FUNCTION_NAME, py::str(value_id)));
2505   if (!py::isinstance<py::none>(value_id) && !value_id_is_builtins) {
2506     value_obj = GetValuePythonObject(value_id);
2507   }
2508   bool is_adapter = false;
2509   if (!py::isinstance<py::none>(value_obj)) {
2510     if (py::hasattr(value_obj, "adapter_flag")) {
2511       is_adapter = py::cast<bool>(py::getattr(value_obj, "adapter_flag"));
2512     }
2513   }
2514   if (!py::isinstance<py::none>(value_obj) && !is_adapter) {
2515     getitem_node =
2516       block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), value, NewValueNode(str_getitem)});
2517     new_node = block->func_graph()->NewCNodeInOrder({getitem_node, slice});
2518     getitem_node->set_user_data<py::object>("__getitem__", std::make_shared<py::object>(value_obj));
2519   } else {
2520     new_node = block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
2521   }
2522 
2523   return new_node;
2524 }
2525 
2526 // Process a slice, get the slice value
ParseSlice(const FunctionBlockPtr & block,const py::object & node)2527 AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) {
2528   MS_LOG(DEBUG) << "Process ast Slice";
2529   MS_EXCEPTION_IF_NULL(block);
2530   AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE);
2531   py::object start = python_adapter::GetPyObjAttr(node, "lower");
2532   py::object stop = python_adapter::GetPyObjAttr(node, "upper");
2533   py::object step = python_adapter::GetPyObjAttr(node, "step");
2534   AnfNodePtr start_node = ParseExprNode(block, start);
2535   AnfNodePtr stop_node = ParseExprNode(block, stop);
2536   AnfNodePtr step_node = ParseExprNode(block, step);
2537   MS_EXCEPTION_IF_NULL(block->func_graph());
2538   return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node});
2539 }
2540 
2541 // Process a extslice
ParseExtSlice(const FunctionBlockPtr & block,const py::object & node)2542 AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) {
2543   MS_LOG(DEBUG) << "Process ast ExtSlice";
2544   MS_EXCEPTION_IF_NULL(block);
2545   AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
2546   py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims");
2547 
2548   std::vector<AnfNodePtr> node_vec;
2549   (void)node_vec.emplace_back(make_tuple_op);
2550   for (size_t i = 0; i < slice_tuple.size(); i++) {
2551     AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]);
2552     (void)node_vec.emplace_back(node_ptr);
2553   }
2554   MS_EXCEPTION_IF_NULL(block->func_graph());
2555   CNodePtr tuple_conde = block->func_graph()->NewCNodeInOrder(std::move(node_vec));
2556   return tuple_conde;
2557 }
2558 
2559 // Process a index, get the index number
ParseIndex(const FunctionBlockPtr & block,const py::object & node)2560 AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) {
2561   MS_LOG(DEBUG) << "Process ast Index";
2562   py::object value_node = python_adapter::GetPyObjAttr(node, "value");
2563   return ParseExprNode(block, value_node);
2564 }
2565 
2566 // Process a UnaryOp, +a, -b
ParseUnaryOp(const FunctionBlockPtr & block,const py::object & node)2567 AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
2568   MS_LOG(DEBUG) << "Process ast UnaryOp";
2569   py::object op = python_adapter::GetPyObjAttr(node, "op");
2570 
2571   MS_EXCEPTION_IF_NULL(block);
2572   // Resolve the op
2573   const auto &ns = block->GetAstOpNameSpace(op);
2574   auto op_node = block->MakeResolveAstOpNameSpace(ns);
2575 
2576   py::object operand = python_adapter::GetPyObjAttr(node, "operand");
2577   AnfNodePtr operand_node = ParseExprNode(block, operand);
2578   MS_EXCEPTION_IF_NULL(block->func_graph());
2579   return block->func_graph()->NewCNodeInOrder({op_node, operand_node});
2580 }
2581 
2582 // Process a dict ast node expression
ParseDictByKeysAndValues(const FunctionBlockPtr & block,const std::vector<AnfNodePtr> & key_nodes,const std::vector<AnfNodePtr> & value_nodes)2583 AnfNodePtr Parser::ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
2584                                             const std::vector<AnfNodePtr> &value_nodes) {
2585   auto keys_tuple = GenerateMakeTuple(block, key_nodes);
2586   auto values_tuple = GenerateMakeTuple(block, value_nodes);
2587   MS_EXCEPTION_IF_NULL(block);
2588   auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
2589   MS_EXCEPTION_IF_NULL(block->func_graph());
2590   return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
2591 }
2592 
GetRealKeysValuesFromName(const FunctionBlockPtr & block,const py::object & node)2593 std::pair<AnfNodePtr, AnfNodePtr> Parser::GetRealKeysValuesFromName(const FunctionBlockPtr &block,
2594                                                                     const py::object &node) {
2595   MS_EXCEPTION_IF_NULL(block);
2596   auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
2597   AnfNodePtr dict = block->ReadVariable(name_id);
2598   auto keys = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimDictGetKeys), dict});
2599   auto values = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimDictGetValues), dict});
2600   // Using the MakeTuple node, pass the need_unpack tag from the AnfNode to the abstract
2601   auto tuple_keys = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), keys});
2602   auto tuple_values = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), values});
2603   constexpr auto need_unpack = "need_unpack";
2604   tuple_keys->set_user_data<bool>(need_unpack, std::make_shared<bool>(true));
2605   tuple_values->set_user_data<bool>(need_unpack, std::make_shared<bool>(true));
2606   return {tuple_keys, tuple_values};
2607 }
2608 
GetRealKeysValues(const FunctionBlockPtr & block,const py::object & node)2609 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> Parser::GetRealKeysValues(const FunctionBlockPtr &block,
2610                                                                                       const py::object &node) {
2611   py::list keys = node.attr("keys");
2612   py::list values = node.attr("values");
2613   if (keys.size() != values.size()) {
2614     MS_LOG(INTERNAL_EXCEPTION) << "The keys' size is not equal to the values' size.";
2615   }
2616   std::vector<AnfNodePtr> inner_key_nodes;
2617   std::vector<AnfNodePtr> inner_value_nodes;
2618   for (size_t index = 0; index < keys.size(); ++index) {
2619     auto inner_key_node_type = ast_->GetNodeType(keys[index]);
2620     const std::string &inner_key_node_type_name = inner_key_node_type->node_name();
2621     // The key does not exist, mean the value is a dict which need unpack.
2622     if (inner_key_node_type_name == "NoneType") {
2623       auto unpack_dict = values[index];
2624       auto inner_value_node_type =
2625         AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, unpack_dict)));
2626       if (inner_value_node_type == AST_SUB_TYPE_DICT) {
2627         auto [unpack_keys, unpack_values] = GetRealKeysValues(block, unpack_dict);
2628         for (size_t i = 0; i < unpack_keys.size(); ++i) {
2629           inner_key_nodes.push_back(unpack_keys[i]);
2630           inner_value_nodes.push_back(unpack_values[i]);
2631         }
2632       } else if (inner_value_node_type == AST_SUB_TYPE_NAME) {
2633         auto [unpack_key, unpack_value] = GetRealKeysValuesFromName(block, unpack_dict);
2634         inner_key_nodes.push_back(unpack_key);
2635         inner_value_nodes.push_back(unpack_value);
2636       } else {
2637         MS_LOG(INTERNAL_EXCEPTION) << "The input of dict which need unpack must be dict, but got "
2638                                    << inner_value_node_type;
2639       }
2640     } else {
2641       AnfNodePtr key_node = ParseExprNode(block, keys[index]);
2642       inner_key_nodes.push_back(key_node);
2643       AnfNodePtr value_node = ParseExprNode(block, values[index]);
2644       inner_value_nodes.push_back(value_node);
2645     }
2646   }
2647   return {inner_key_nodes, inner_value_nodes};
2648 }
2649 
ParseDict(const FunctionBlockPtr & block,const py::object & node)2650 AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
2651   MS_LOG(DEBUG) << "Process ast Dict";
2652   auto [key_nodes, value_nodes] = GetRealKeysValues(block, node);
2653   return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
2654 }
2655 
2656 // Process a augment assign such as a += b or mat[stride_slice] += b.
ParseAugAssign(const FunctionBlockPtr & block,const py::object & node)2657 FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
2658   MS_LOG(DEBUG) << "Process ast AugAssign";
2659   MS_EXCEPTION_IF_NULL(block);
2660   MS_EXCEPTION_IF_NULL(ast_);
2661 
2662   py::object target_object = python_adapter::GetPyObjAttr(node, "target");
2663   py::object op_object = python_adapter::GetPyObjAttr(node, "op");
2664   py::object value_object = python_adapter::GetPyObjAttr(node, "value");
2665   AnfNodePtr target_node = nullptr;
2666 
2667   const auto &ns = block->GetAstOpNameSpace(op_object);
2668   auto op_node = block->MakeResolveAstOpNameSpace(ns);
2669 
2670   AnfNodePtr value_node = ParseExprNode(block, value_object);
2671   {
2672     TraceGuard trace_guard(GetLocation(target_object));
2673     target_node = ParseExprNode(block, target_object);
2674   }
2675 
2676   if (target_node == nullptr) {
2677     MS_LOG(INTERNAL_EXCEPTION) << "Can not get target node ";
2678   }
2679   MS_EXCEPTION_IF_NULL(block->func_graph());
2680   AnfNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
2681 
2682   // b += list_x.pop(a)
2683   // -->  list_x, b = list_x, b + list_x.pop(a) need renew the list_x.
2684   if (IsPopOperation(value_node)) {
2685     ProcessPopOperationInAugAssign(block, value_node, target_node, op_node, target_object);
2686     return block;
2687   }
2688 
2689   WriteAssignVars(block, target_object, augassign_app);
2690   return block;
2691 }
2692 
2693 // Process global declaration such as 'global x';
ParseGlobal(const FunctionBlockPtr & block,const py::object & node)2694 FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
2695   MS_LOG(DEBUG) << "Process ast Global";
2696   MS_EXCEPTION_IF_NULL(block);
2697   py::list vars = python_adapter::GetPyObjAttr(node, "names");
2698   for (auto &item : vars) {
2699     block->AddGlobalVar(py::cast<std::string>(item));
2700   }
2701   return block;
2702 }
2703 
CheckControlFlowAlterationInIf(std::pair<FunctionBlockPtr,FunctionBlockPtr> * branch_graphs_pair,const FunctionBlockPtr & branch_block,const FunctionBlockPtr & branch_end,const FunctionBlockPtr & after_block,const FunctionBlockPtr & block) const2704 void Parser::CheckControlFlowAlterationInIf(std::pair<FunctionBlockPtr, FunctionBlockPtr> *branch_graphs_pair,
2705                                             const FunctionBlockPtr &branch_block, const FunctionBlockPtr &branch_end,
2706                                             const FunctionBlockPtr &after_block, const FunctionBlockPtr &block) const {
2707   if (branch_block->is_return_statement_inside()) {
2708     MS_LOG(DEBUG)
2709       << "Inside the branch block has return statement, ignore for transformation to parallel-if call, branch block:"
2710       << branch_block->ToString() << ", block: " << block->ToString();
2711     block->set_is_return_statement_inside();
2712     return;
2713   }
2714   if (branch_block->is_break_continue_statement_inside()) {
2715     MS_LOG(DEBUG) << "Inside the branch block has break or continue statement, ignore for transformation to "
2716                      "parallel-if call, branch block: "
2717                   << branch_block->ToString() << ", branch end: " << branch_end->ToString()
2718                   << ", block: " << block->ToString();
2719     MS_LOG(DEBUG) << "Propagate flag of break or continue statement from branch block to block, branch block:"
2720                   << branch_block->ToString() << ", block: " << block->ToString();
2721     block->set_break_continue_statement_inside();
2722   } else if (branch_end->func_graph()->get_return() != nullptr) {
2723     // Currently, this can only happen with raise statement inside. As try/expect is not supported now,
2724     // and contional for raise will be evaluated in Compile time. If raise condition is met, it will
2725     // cause compile fail, so no need to propagate the flag back.
2726     MS_LOG(DEBUG) << "Ignore the block as branch_end will not call after_block, branch_block: "
2727                   << branch_block->ToString() << ", branch_end: " << branch_end->ToString()
2728                   << ", after_block: " << after_block->ToString();
2729   } else {
2730     branch_graphs_pair->second = branch_end;
2731   }
2732 }
2733 
2734 // Check constant bool constant attr, such as:
2735 //   if self.has_bias
CheckAttributeConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2736 bool Parser::CheckAttributeConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2737                                         bool *is_true_cond) const {
2738   bool is_constant;
2739   auto attr_cond = GetPyObjForAstAttr(block, test_node, &is_constant);
2740   if (!is_constant) {
2741     return false;
2742   }
2743   if (!py::isinstance<py::bool_>(attr_cond)) {
2744     return false;
2745   }
2746   *is_true_cond = py::cast<bool>(attr_cond);
2747   return true;
2748 }
2749 
2750 // Check constant local var, such as:
2751 //   if has_bias
CheckNameConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2752 bool Parser::CheckNameConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2753                                    bool *is_true_cond) const {
2754   auto id = python_adapter::GetPyObjAttr(test_node, "id");
2755   if (!py::isinstance<py::str>(id)) {
2756     return false;
2757   }
2758   auto anf_node = block->ReadVariable(id.cast<std::string>());
2759   if (anf_node == nullptr) {
2760     return false;
2761   }
2762   MS_LOG(DEBUG) << "CheckNameConstantCond anf_node: " << anf_node->DebugString();
2763   ValuePtr value = nullptr;
2764   if (anf_node->isa<ValueNode>()) {
2765     value = anf_node->cast<ValueNodePtr>()->value();
2766   } else if (anf_node->isa<Parameter>()) {
2767     value = GetParameterValue(anf_node);
2768     if (value == nullptr || value->ContainsValueAny()) {
2769       return false;
2770     }
2771     MS_LOG(DEBUG) << "Found constant value: " << value->ToString() << " for anf_node: " << anf_node;
2772   }
2773   if (value == nullptr || !value->isa<BoolImm>()) {
2774     return false;
2775   }
2776   MS_LOG(DEBUG) << "CheckNameConstantCond value: " << value->ToString();
2777   *is_true_cond = GetValue<bool>(value);
2778   return true;
2779 }
2780 
2781 // Check constant unary op result, such as:
2782 //   if not self.has_bias
CheckUnaryOpConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2783 bool Parser::CheckUnaryOpConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2784                                       bool *is_true_cond) const {
2785   auto op = python_adapter::GetPyObjAttr(test_node, "op");
2786   auto op_node_type = ast()->GetNodeType(op);
2787   const auto &op_node_type_name = op_node_type->node_name();
2788   MS_LOG(DEBUG) << "op_node_type_name: " << op_node_type_name;
2789   if (op_node_type_name != "Not") {
2790     return false;
2791   }
2792   auto operand = python_adapter::GetPyObjAttr(test_node, "operand");
2793   auto check_constant_cond = CheckConstantCondition(block, operand, is_true_cond);
2794   if (!check_constant_cond) {
2795     return false;
2796   }
2797   *is_true_cond = !(*is_true_cond);
2798   return true;
2799 }
2800 
2801 // Check constant compare result, such as:
2802 //   if self.has_bias is None
CheckCompareConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2803 bool Parser::CheckCompareConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2804                                       bool *is_true_cond) const {
2805   return GetBoolObjForAstCompare(block, test_node, is_true_cond);
2806 }
2807 
2808 // Check constant bool op result, such as:
2809 //   if self.has_bias is None and self.beta == 1
CheckBoolOpConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2810 bool Parser::CheckBoolOpConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2811                                      bool *is_true_cond) const {
2812   auto op = python_adapter::GetPyObjAttr(test_node, "op");
2813   auto op_node_type = ast()->GetNodeType(op);
2814   const auto &op_node_type_name = op_node_type->node_name();
2815   MS_LOG(DEBUG) << "op_node_type_name: " << op_node_type_name;
2816   py::list values = python_adapter::GetPyObjAttr(test_node, "values");
2817   bool determined = false;
2818   for (size_t i = 0; i < values.size(); ++i) {
2819     bool sub_is_true_cond;
2820     auto check_constant_cond = CheckConstantCondition(block, values[i], &sub_is_true_cond);
2821     if (!check_constant_cond) {
2822       return false;
2823     }
2824     if (op_node_type_name == "Or" && sub_is_true_cond) {
2825       determined = true;
2826       break;
2827     } else if (op_node_type_name == "And" && !sub_is_true_cond) {
2828       determined = true;
2829       break;
2830     }
2831   }
2832   if (op_node_type_name == "Or") {
2833     *is_true_cond = determined;
2834   } else if (op_node_type_name == "And") {
2835     *is_true_cond = !determined;
2836   }
2837   return true;
2838 }
2839 
GetConstantConditionFromComment(const FunctionBlockPtr & block,const py::object & if_node,bool * is_true_cond) const2840 bool Parser::GetConstantConditionFromComment(const FunctionBlockPtr &block, const py::object &if_node,
2841                                              bool *is_true_cond) const {
2842   auto location = GetLocation(if_node);
2843   MS_EXCEPTION_IF_NULL(location);
2844   const auto &comments = location->comments();
2845   if (comments.empty()) {
2846     return false;
2847   }
2848   const auto &comment = comments.back();
2849   MS_LOG(DEBUG) << "The comment of if statement: " << comment << ", block: " << block->ToString();
2850   std::regex regex("^#\\s*@jit.cond:\\s*([A-Za-z]+)$");
2851   std::smatch matched_results;
2852   if (!std::regex_match(comment, matched_results, regex)) {
2853     return false;
2854   }
2855   constexpr auto container_match_count = 2;
2856   if (matched_results.size() != container_match_count) {
2857     return false;
2858   }
2859   const auto &cond_str = matched_results[1].str();
2860   MS_LOG(DEBUG) << "The cond string of comment is " << cond_str;
2861   if (cond_str != "True" && cond_str != "False") {
2862     return false;
2863   }
2864   *is_true_cond = (cond_str == "True");
2865   return true;
2866 }
2867 
2868 // Return true if it's constant condition and the condition value returned by is_true_cond, otherwise return false.
CheckConstantCondition(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond,const py::object & if_node) const2869 bool Parser::CheckConstantCondition(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond,
2870                                     const py::object &if_node) const {
2871   static const auto boost_parse = common::GetCompileConfig("BOOST_PARSE");
2872   if (boost_parse == "0") {
2873     return false;
2874   }
2875   MS_EXCEPTION_IF_NULL(block);
2876   MS_EXCEPTION_IF_NULL(is_true_cond);
2877   // Try to get the constant condition from the comment "@jit.cond: True/False".
2878   if (if_node != py::none() && GetConstantConditionFromComment(block, if_node, is_true_cond)) {
2879     return true;
2880   }
2881   auto node_type = ast()->GetNodeType(test_node);
2882   const std::string &node_type_name = node_type->node_name();
2883   MS_LOG(DEBUG) << "node_type_name: " << node_type_name;
2884 
2885   auto func_iter = condition_method_map_.find(node_type_name);
2886   if (func_iter == condition_method_map_.end()) {
2887     return false;
2888   }
2889   auto check_constant = (this->*(func_iter->second))(block, test_node, is_true_cond);
2890   if (check_constant) {
2891     MS_LOG(DEBUG) << "Has constant condition, is_true_cond: " << *is_true_cond;
2892     return true;
2893   }
2894   MS_LOG(DEBUG) << "Has no constant condition, test_node: " << py::str(test_node);
2895   return false;
2896 }
2897 
2898 // Process a if statement
ParseIf(const FunctionBlockPtr & block,const py::object & node)2899 FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) {
2900   MS_LOG(DEBUG) << "Process ast If";
2901   MS_EXCEPTION_IF_NULL(block);
2902   py::object test_node = python_adapter::GetPyObjAttr(node, "test");
2903   bool is_true_cond = false;
2904   bool is_bool_const_cond = CheckConstantCondition(block, test_node, &is_true_cond, node);
2905 
2906   // Make condition node.
2907   AnfNodePtr bool_node = nullptr;
2908   if (!is_bool_const_cond) {
2909     AnfNodePtr condition_node = ParseExprNode(block, test_node);
2910     bool_node = block->ForceToCondNode(condition_node);
2911   }
2912 
2913   FunctionBlockPtr true_block = nullptr;
2914   FunctionBlockPtr false_block = nullptr;
2915   FunctionBlockPtr after_block = nullptr;
2916   auto block_fg = block->func_graph();
2917   MS_EXCEPTION_IF_NULL(block_fg);
2918   if (!is_bool_const_cond || is_true_cond) {
2919     TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block_fg->debug_info()));
2920     true_block = MakeFunctionBlock();
2921     MS_LOG(DEBUG) << "Make true branch, " << true_block->ToString();
2922   }
2923   if (!is_bool_const_cond || !is_true_cond) {
2924     TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block_fg->debug_info()));
2925     false_block = MakeFunctionBlock();
2926     MS_LOG(DEBUG) << "Make false branch, " << false_block->ToString();
2927   }
2928 
2929   if (!is_bool_const_cond) {
2930     MakeConditionBlocks(block, true_block, false_block);
2931   } else if (is_true_cond) {
2932     MS_LOG(DEBUG) << "Connect true branch, " << true_block->ToString();
2933     block->Jump(true_block, {});
2934     true_block->Mature();
2935     true_block->UpdateGlobalPyParam(block->global_py_params());
2936   } else {  // !is_true_cond
2937     MS_LOG(DEBUG) << "Connect false branch, " << false_block->ToString();
2938     block->Jump(false_block, {});
2939     false_block->Mature();
2940     false_block->UpdateGlobalPyParam(block->global_py_params());
2941   }
2942 
2943   {
2944     TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block_fg->debug_info()));
2945     after_block = MakeFunctionBlock();
2946   }
2947 
2948   if (!is_bool_const_cond && MsContext::GetInstance()->backend_policy() != "ge") {
2949     // For backends excludes 'ge', it can handle multi graph call, use this flag to
2950     // generate call not inline `after_block` graph to reduce if by if switch expansion.
2951     MS_EXCEPTION_IF_NULL(after_block->func_graph());
2952     after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
2953   }
2954 
2955   // Process the if-true branch
2956   std::pair<FunctionBlockPtr, FunctionBlockPtr> true_branch_graphs;
2957   if (!is_bool_const_cond || is_true_cond) {
2958     py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
2959     FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
2960     std::string true_branch_name = "true branch";
2961     true_end->set_block_name(true_branch_name);
2962     MS_EXCEPTION_IF_NULL(true_end->func_graph());
2963     CheckControlFlowAlterationInIf(&true_branch_graphs, true_block, true_end, after_block, block);
2964     // If the return_ is set, it has its own continuation block
2965     if (true_end->func_graph()->get_return() == nullptr) {
2966       TraceGuard trace_guard_true(GetLocation(bodyNode));
2967       true_end->Jump(after_block, {});
2968       MS_LOG(DEBUG) << "The true_end block jump to after, true_block: " << true_block->ToString()
2969                     << ", true_end: " << true_end->ToString() << ", after: " << after_block->ToString();
2970       after_block->UpdateGlobalPyParam(true_end->global_py_params());
2971     }
2972   }
2973 
2974   // Process the orelse branch
2975   std::pair<FunctionBlockPtr, FunctionBlockPtr> false_branch_graphs;
2976   if (!is_bool_const_cond || !is_true_cond) {
2977     py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
2978     FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
2979     std::string false_branch_name = "false branch";
2980     false_end->set_block_name(false_branch_name);
2981     MS_EXCEPTION_IF_NULL(false_end->func_graph());
2982     CheckControlFlowAlterationInIf(&false_branch_graphs, false_block, false_end, after_block, block);
2983     // If the return_ is set, it has its own continuation block
2984     if (false_end->func_graph()->get_return() == nullptr) {
2985       if (py::len_hint(orelseNode) != 0) {
2986         TraceGuard trace_guard_false(GetLocation(orelseNode));
2987         false_end->Jump(after_block, {});
2988       } else {
2989         false_end->Jump(after_block, {});
2990       }
2991       MS_LOG(DEBUG) << "The false_end block jump to after, false_block: " << false_block->ToString()
2992                     << ", false_end: " << false_end->ToString() << ", after: " << after_block->ToString();
2993       after_block->UpdateGlobalPyParam(false_end->global_py_params());
2994     }
2995   }
2996 
2997   if (!is_bool_const_cond) {
2998     MS_EXCEPTION_IF_NULL(bool_node);
2999     auto switch_app = block->ConditionalJump(bool_node, true_block, false_block);
3000 
3001     // Record the former, middle, latter graphs info.
3002     static const auto transform_tail_call_to_parallel_call = (common::GetCompileConfig("IF_PARALLEL_CALL") != "0");
3003     if (transform_tail_call_to_parallel_call && true_branch_graphs.second != nullptr &&
3004         false_branch_graphs.second != nullptr) {
3005       true_branch_graphs.first = block;
3006       false_branch_graphs.first = block;
3007       MS_LOG(DEBUG) << "Record tail call {former: " << block->func_graph()->ToString()
3008                     << ", true middle: " << true_branch_graphs.second->func_graph()->ToString()
3009                     << ", false middle: " << false_branch_graphs.second->func_graph()->ToString() << "}";
3010       std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> branch_graphs_vec{true_branch_graphs,
3011                                                                                    false_branch_graphs};
3012       (void)parallel_call_graphs_.emplace_back(branch_graphs_vec);
3013     }
3014 
3015     static const auto transform_for_half_unroll_call = (common::GetCompileConfig("FOR_HALF_UNROLL") == "1");
3016     if (transform_for_half_unroll_call) {
3017       // Lift the if branches in for statement.
3018       (void)if_branch_calls_.emplace_back(std::make_tuple(switch_app, true_block, false_block));
3019     }
3020   }
3021 
3022   if (after_block->prev_blocks().empty()) {
3023     MS_LOG(DEBUG) << "After block's previous block is null";
3024     after_block->SetAsDeadBlock();
3025   }
3026   after_block->Mature();
3027   return after_block;
3028 }
3029 
CheckReturnInLoop(const FunctionBlockPtr & block,const FunctionBlockPtr & body_block) const3030 void Parser::CheckReturnInLoop(const FunctionBlockPtr &block, const FunctionBlockPtr &body_block) const {
3031   MS_EXCEPTION_IF_NULL(block);
3032   MS_EXCEPTION_IF_NULL(body_block);
3033   // Propagate flag of return statement in body_block back.
3034   if (body_block->is_return_statement_inside()) {
3035     MS_LOG(DEBUG) << "Propagate flag of return statement in body_block back, body_block: " << body_block->ToString()
3036                   << ", block: " << block->ToString();
3037     block->set_is_return_statement_inside();
3038   }
3039 }
3040 
ParseWhile(const FunctionBlockPtr & block,const py::object & node)3041 FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
3042   MS_LOG(DEBUG) << "Process ast While";
3043   MS_EXCEPTION_IF_NULL(block);
3044   std::string while_block_name = "while";
3045   block->set_block_name(while_block_name);
3046   FunctionBlockPtr header_block = nullptr;
3047   FunctionBlockPtr body_block = nullptr;
3048   FunctionBlockPtr after_block = nullptr;
3049   MS_EXCEPTION_IF_NULL(block->func_graph());
3050   {
3051     TraceGuard guard(std::make_shared<TraceWhileHeader>(block->func_graph()->debug_info()));
3052     header_block = MakeFunctionBlock();
3053     auto func_graph = header_block->func_graph();
3054     MS_EXCEPTION_IF_NULL(func_graph);
3055     func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
3056   }
3057   {
3058     TraceGuard guard(std::make_shared<TraceWhileBody>(block->func_graph()->debug_info()));
3059     body_block = MakeFunctionBlock();
3060   }
3061   {
3062     TraceGuard guard(std::make_shared<TraceWhileAfter>(block->func_graph()->debug_info()));
3063     after_block = MakeFunctionBlock();
3064   }
3065 
3066   body_block->AddPrevBlock(header_block);
3067   after_block->AddPrevBlock(header_block);
3068   block->Jump(header_block, {});
3069 
3070   py::object test_node = python_adapter::GetPyObjAttr(node, "test");
3071   header_block->UpdateGlobalPyParam(block->global_py_params());
3072   body_block->UpdateGlobalPyParam(block->global_py_params());
3073   after_block->UpdateGlobalPyParam(block->global_py_params());
3074   AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
3075   AnfNodePtr while_condition_node = nullptr;
3076   {
3077     TraceGuard trace_guard(std::make_shared<TraceForceWhileCond>(condition_node->debug_info()));
3078     while_condition_node = header_block->ForceToCondNode(condition_node, true);
3079   }
3080   (void)header_block->ConditionalJump(while_condition_node, body_block, after_block);
3081 
3082   body_block->Mature();
3083   // Parse loop body statements with loop context.
3084   LoopContext loop_context{&loops_, header_block, nullptr};
3085   py::object body_node = python_adapter::GetPyObjAttr(node, "body");
3086   FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
3087   MS_EXCEPTION_IF_NULL(after_body->func_graph());
3088   if (after_body->func_graph()->get_return() == nullptr) {
3089     after_body->Jump(header_block, {});
3090   }
3091   header_block->Mature();
3092   after_block->Mature();
3093   py::object orelse_obj = python_adapter::GetPyObjAttr(node, "orelse");
3094   if (py::len_hint(orelse_obj) != 0) {
3095     TraceGuard trace_guard(GetLocation(orelse_obj));
3096     MS_LOG(EXCEPTION) << "The 'while...else...' statement is not supported now.";
3097   }
3098   auto &end_block = loop_context.EndBlock();
3099   // end_block exists if we encounter 'break' in loop body.
3100   if (end_block) {
3101     after_block->Jump(end_block, {});
3102     end_block->Mature();
3103     CheckReturnInLoop(block, body_block);
3104     return end_block;
3105   }
3106   // No 'break', no end_block.
3107   CheckReturnInLoop(block, body_block);
3108   return after_block;
3109 }
3110 
// Parse a Python 'for' statement.
//
// 'for...else...' is rejected. Otherwise the loop is built by ParseForRepeat when the compile config
// FOR_HALF_UNROLL equals "1", and by ParseForUnroll otherwise. The config value is read once and cached in a
// function-local static.
//
// @param block The block in which the 'for' statement appears.
// @param node  The Python AST 'For' node.
// @return The block from which parsing continues after the loop.
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  // `block` is dereferenced below before either callee gets a chance to check it.
  MS_EXCEPTION_IF_NULL(block);
  // Check for-else
  py::object orelse_obj = python_adapter::GetPyObjAttr(node, "orelse");
  if (py::len_hint(orelse_obj) != 0) {
    TraceGuard trace_guard(GetLocation(orelse_obj));
    MS_LOG(EXCEPTION) << "The 'for...else...' statement is not supported now.";
  }
  std::string for_block_name = "for";
  block->set_block_name(for_block_name);
  static const auto transform_for_half_unroll_call = (common::GetCompileConfig("FOR_HALF_UNROLL") == "1");
  if (transform_for_half_unroll_call) {
    return ParseForRepeat(block, node);
  }
  return ParseForUnroll(block, node);
}
3126 
// Implement unroll for statement with tuple/getitem.
//
// Builds the loop from index-based element access, in effect:
//   it = iter(xs); n = len(it)                   (both created in `block`'s graph)
//   block:      jump header(0)
//   header(i):  Switch(less(i, n), body, after)
//   body:       x = it[i]; i_next = i + 1; <loop body statements>; jump header(i_next)
// `iter`, `len` and `getitem` are resolved operations; `less` and `add` are the Python ops from
// mindspore.ops.composite.multitype_ops.less_impl / add_impl.
// If the body contains 'break', the after block jumps to the loop context's end block, which is returned instead.
//
// @param block The block in which the 'for' statement appears.
// @param node  The Python AST 'For' node.
// @return The block from which parsing continues after the loop.
FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For by loop variable";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_len = block->MakeResolveOperation(NAMED_PRIMITIVE_LEN);
  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);

  // Get variable name of 'x' in statement 'for x in xs'
  py::object target_node = python_adapter::GetPyObjAttr(node, "target");

  // Create statement 'len(xs)'
  py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
  MS_LOG(DEBUG) << "Parse Recursive Iter, iter_obj: " << py::str(iter_obj);
  AnfNodePtr origin_iter_node = ParseExprNode(block, iter_obj);
  MS_EXCEPTION_IF_NULL(origin_iter_node);
  // iter(xs) and len(iter(xs)) live in the outer graph; header/body graphs reference them.
  auto iter_node = block->func_graph()->NewCNodeInOrder({op_iter, origin_iter_node});
  CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  FunctionBlockPtr header_block =
    MakeFunctionBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(header_block);
  // Create loop variable 'i'
  ParameterPtr loop_var = header_block->func_graph()->add_parameter();

  // Loop condition 'i < len(xs)', evaluated in the header graph.
  std::string less_module_name = "mindspore.ops.composite.multitype_ops.less_impl";
  ValuePtr less_op = prim::GetPythonOps("less", less_module_name);
  CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(less_op), loop_var, scalar_len});

  // Generate the body of the for statement
  FunctionBlockPtr body_block = MakeFunctionBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(body_block);
  body_block->AddPrevBlock(header_block);
  // Create 'x = xs[i]'
  auto body_func_graph = body_block->func_graph();
  MS_EXCEPTION_IF_NULL(body_func_graph);
  auto target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
  header_block->UpdateGlobalPyParam(block->global_py_params());
  body_block->UpdateGlobalPyParam(block->global_py_params());
  WriteAssignVars(body_block, target_node, target_var);

  // Create 'i = i + 1'
  std::string add_module_name = "mindspore.ops.composite.multitype_ops.add_impl";
  ValuePtr add_op = prim::GetPythonOps("add", add_module_name);
  auto add_one = NewValueNode(static_cast<int64_t>(1));
  CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({NewValueNode(add_op), loop_var, add_one});

  body_block->WriteVariable(loop_var->name(), loop_var_inc);

  // Link the variable name with the target: the loop parameter's debug info traces back to the increment node.
  auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
  loop_var->debug_info()->set_trace_info(it_info);

  FunctionBlockPtr after_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
    after_block = MakeFunctionBlock();
  }
  MS_EXCEPTION_IF_NULL(after_block);
  after_block->AddPrevBlock(header_block);
  // Enter the loop with i = 0.
  block->Jump(header_block, {NewValueNode(static_cast<int64_t>(0))});
  body_block->Mature();
  after_block->UpdateGlobalPyParam(block->global_py_params());

  (void)header_block->ConditionalJump(cond_node, body_block, after_block);

  // Parse loop body statements with loop context.
  LoopContext loop_context{&loops_, header_block, loop_var_inc};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  after_body_block->UpdateGlobalPyParam(block->global_py_params());
  // Loop back with i + 1 unless the body already returned.
  if (after_body_block->func_graph()->get_return() == nullptr) {
    after_body_block->Jump(header_block, {loop_var_inc});
  }

  header_block->Mature();
  after_block->Mature();
  auto &end_block = loop_context.EndBlock();
  if (end_block) {
    // end_block exists if we encounter 'break' in loop body.
    after_block->Jump(end_block, {});
    end_block->Mature();
    CheckReturnInLoop(block, body_block);
    return end_block;
  }
  // No 'break', no end_block.
  CheckReturnInLoop(block, body_block);
  return after_block;
}
3215 
3216 // Implement for statement with repeat calling sub graph.
ParseForRepeat(const FunctionBlockPtr & block,const py::object & node)3217 FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py::object &node) {
3218   MS_LOG(DEBUG) << "Process ast For by loop variable";
3219   MS_EXCEPTION_IF_NULL(block);
3220   FunctionBlockPtr header_block =
3221     MakeFunctionBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
3222   MS_EXCEPTION_IF_NULL(header_block);
3223 
3224   // Create statement 'len(xs)'
3225   py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
3226   AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
3227   MS_EXCEPTION_IF_NULL(iter_node);
3228   // Generate node for loop count and convert it to tensor, to make the loop not unroll
3229   ParameterPtr header_iter_param = header_block->func_graph()->add_parameter();
3230   AnfNodePtr header_len = header_block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
3231   header_block->CheckUndefinedSymbol(NAMED_PRIMITIVE_LEN, header_len);
3232   CNodePtr scalar_len = header_block->func_graph()->NewCNodeInOrder({header_len, header_iter_param});
3233 
3234   // Create loop variable 'i'
3235   ParameterPtr loop_var = header_block->func_graph()->add_parameter();
3236   // Create loop condition 'i < len(xs)'
3237   std::string less_module_name = "mindspore.ops.composite.multitype_ops.less_impl";
3238   ValuePtr less_op = prim::GetPythonOps("less", less_module_name);
3239   CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(less_op), loop_var, scalar_len});
3240 
3241   // Generate the body of the for statement
3242   FunctionBlockPtr body_block = MakeFunctionBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
3243   MS_EXCEPTION_IF_NULL(body_block);
3244   body_block->AddPrevBlock(header_block);
3245   // Create 'x = xs[i]'
3246   auto body_func_graph = body_block->func_graph();
3247   AnfNodePtr body_getitem = body_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
3248   CNodePtr target_var = body_func_graph->NewCNodeInOrder({body_getitem, header_iter_param, loop_var});
3249 
3250   header_block->UpdateGlobalPyParam(block->global_py_params());
3251   body_block->UpdateGlobalPyParam(block->global_py_params());
3252 
3253   // Get variable name of 'x' in statement 'for x in xs'
3254   py::object target_node = python_adapter::GetPyObjAttr(node, "target");
3255   WriteAssignVars(body_block, target_node, target_var);
3256 
3257   // Create 'i = i + 1'
3258   std::string add_module_name = "mindspore.ops.composite.multitype_ops.add_impl";
3259   ValuePtr add_op = prim::GetPythonOps("add", add_module_name);
3260   CNodePtr loop_var_inc =
3261     body_func_graph->NewCNodeInOrder({NewValueNode(add_op), loop_var, NewValueNode(static_cast<int64_t>(1))});
3262   body_block->WriteVariable(loop_var->name(), loop_var_inc);
3263 
3264   // Link the variable name with the target
3265   auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
3266   loop_var->debug_info()->set_trace_info(it_info);
3267 
3268   FunctionBlockPtr after_block = nullptr;
3269   {
3270     TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
3271     after_block = MakeFunctionBlock();
3272   }
3273   MS_EXCEPTION_IF_NULL(after_block);
3274   after_block->AddPrevBlock(header_block);
3275   block->Jump(header_block, {iter_node, NewValueNode(static_cast<int64_t>(0))});
3276   body_block->Mature();
3277   after_block->UpdateGlobalPyParam(block->global_py_params());
3278   (void)header_block->ConditionalJump(cond_node, body_block, after_block);
3279 
3280   // Generate the body of the for statement
3281   FunctionBlockPtr rolled_body_block =
3282     MakeFunctionBlock(std::make_shared<TraceForRolledBody>(body_block->func_graph()->debug_info()));
3283   MS_EXCEPTION_IF_NULL(rolled_body_block);
3284 
3285   rolled_body_block->Mature();
3286   body_block->Jump(rolled_body_block, {});
3287   auto rolled_body_call = dyn_cast<CNode>(body_block->func_graph()->output());
3288   rolled_body_block->UpdateGlobalPyParam(block->global_py_params());
3289 
3290   // Parse loop body statements with loop context.
3291   LoopContext loop_context{&loops_, header_block, loop_var_inc};
3292   py::object body_node = python_adapter::GetPyObjAttr(node, "body");
3293   FunctionBlockPtr after_body_block = ParseStatements(rolled_body_block, body_node);
3294   after_body_block->UpdateGlobalPyParam(block->global_py_params());
3295   MS_LOG(DEBUG) << "Finish rolled block, after_body_block: " << after_body_block->ToString()
3296                 << ", rolled_body_block: " << rolled_body_block->ToString();
3297   if (after_body_block->func_graph()->get_return() == nullptr) {
3298     after_body_block->Jump(header_block, {header_iter_param, loop_var_inc});
3299   }
3300 
3301   // Record the former/middle/latter graphs for later transforming.
3302   static const auto transform_for_half_unroll_call = (common::GetCompileConfig("FOR_HALF_UNROLL") == "1");
3303   if (transform_for_half_unroll_call) {
3304     std::pair<FunctionBlockPtr, FunctionBlockPtr> loop_graphs;
3305     loop_graphs.first = body_block;
3306     loop_graphs.second = after_body_block;
3307     std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> loop_graphs_vec{loop_graphs};
3308     (void)parallel_call_graphs_.emplace_back(loop_graphs_vec);
3309     MS_LOG(DEBUG) << "Record tail call graphs, loop: {former: " << loop_graphs.first->func_graph()->ToString()
3310                   << ", middle: " << loop_graphs.second->func_graph()->ToString() << "}";
3311     // Record the rolled body function, for later lifting operation.
3312     if (rolled_body_call != nullptr) {
3313       (void)rolled_body_calls_.emplace_back(std::make_pair(rolled_body_call, rolled_body_block));
3314       constexpr int recursive_level = 2;
3315       MS_LOG(DEBUG) << "Record rolled body call: {CNode: " << rolled_body_call->DebugString(recursive_level)
3316                     << ", rolled_graph: " << rolled_body_block->ToString() << "}";
3317     }
3318     auto rolled_body_func_graph = rolled_body_block->func_graph();
3319     rolled_body_func_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
3320     rolled_body_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
3321   }
3322 
3323   header_block->Mature();
3324   after_block->Mature();
3325   auto &end_block = loop_context.EndBlock();
3326   if (end_block) {
3327     // end_block exists if we encounter 'break' in loop body.
3328     after_block->Jump(end_block, {});
3329     end_block->Mature();
3330     return end_block;
3331   }
3332   // No 'break', no end_block.
3333   return after_block;
3334 }
3335 
ParseIfExp(const FunctionBlockPtr & block,const py::object & node)3336 AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
3337   MS_LOG(DEBUG) << "Process ast IfExp";
3338   MS_EXCEPTION_IF_NULL(block);
3339   py::object test_node = python_adapter::GetPyObjAttr(node, "test");
3340   AnfNodePtr condition_node = ParseExprNode(block, test_node);
3341 
3342   AnfNodePtr bool_node = block->ForceToCondNode(condition_node);
3343   FunctionBlockPtr true_block = nullptr;
3344   FunctionBlockPtr false_block = nullptr;
3345   MS_EXCEPTION_IF_NULL(block->func_graph());
3346   {
3347     TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
3348     true_block = MakeFunctionBlock();
3349   }
3350   {
3351     TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
3352     false_block = MakeFunctionBlock();
3353   }
3354 
3355   MakeConditionBlocks(block, true_block, false_block);
3356 
3357   // Process the if-true branch
3358   py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
3359   MS_EXCEPTION_IF_NULL(true_block->func_graph());
3360   MS_EXCEPTION_IF_NULL(true_block->func_graph()->debug_info());
3361   true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode));
3362   AnfNodePtr true_node = ParseExprNode(true_block, bodyNode);
3363 
3364   // Process the orelse branch
3365   py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
3366   MS_EXCEPTION_IF_NULL(false_block->func_graph());
3367   MS_EXCEPTION_IF_NULL(false_block->func_graph()->debug_info());
3368   false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode));
3369   AnfNodePtr false_node = ParseExprNode(false_block, orelseNode);
3370 
3371   true_block->func_graph()->set_output(true_node);
3372   false_block->func_graph()->set_output(false_node);
3373 
3374   // Use the Primitive replace the operation resolve node (switch),
3375   // because the switch will eventually be converted to Primitive node
3376   CNodePtr switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), bool_node,
3377                                                               NewValueNode(true_block->func_graph()),
3378                                                               NewValueNode(false_block->func_graph())});
3379   std::vector<AnfNodePtr> call_graph_nodes{switch_app};
3380   CNodePtr switch_app_call = block->func_graph()->NewCNodeInOrder(std::move(call_graph_nodes));
3381   return switch_app_call;
3382 }
3383 
// Build the graphs for a single-generator list comprehension and return its entry block.
//
// Structure:
//   block:                 it = iter(xs)                          (xs parsed in `block`)
//   top_block:             jump header(it, [])
//   header(iter, list):    Switch(hasnext(iter), body, after)
//   body:                  n = next(iter); target = n[0];
//                          jump header(n[1], <list updated by ParseListCompIfs>)
//   after:                 return list
//
// @param block          The block in which the comprehension appears.
// @param node           The comprehension AST node; passed on to ParseListCompIfs (which reads its `elt`).
// @param generator_node The single `comprehension` node providing `target`, `iter` and `ifs`.
// @return top_block; the caller invokes its graph to obtain the resulting list.
FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
                                           const py::object &generator_node) {
  // Create the entry (top) block of the comprehension; the loop header is created separately below.
  MS_EXCEPTION_IF_NULL(block->func_graph());
  FunctionBlockPtr top_block = MakeFunctionBlock(std::make_shared<TraceListComp>(block->func_graph()->debug_info()));
  top_block->AddPrevBlock(block);
  // Handle iter attribute: build iter(xs) in the enclosing block's graph.
  py::object iter_node = python_adapter::GetPyObjAttr(generator_node, "iter");
  AnfNodePtr origin_iter_anf_node = ParseExprNode(block, iter_node);
  MS_EXCEPTION_IF_NULL(origin_iter_anf_node);
  AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  MS_EXCEPTION_IF_NULL(op_iter);
  AnfNodePtr iter_anf_node = block->func_graph()->NewCNodeInOrder({op_iter, origin_iter_anf_node});
  MS_EXCEPTION_IF_NULL(iter_anf_node);

  // Create header graph.
  FunctionBlockPtr list_header_block =
    MakeFunctionBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));

  // Create hasNext apply on the header's first parameter, named "iter".
  AnfNodePtr op_hasnext = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  MS_EXCEPTION_IF_NULL(list_header_block->func_graph());
  ParameterPtr iter_param = list_header_block->func_graph()->add_parameter();
  constexpr auto iter_param_name = "iter";
  iter_param->set_name(iter_param_name);
  MS_EXCEPTION_IF_NULL(iter_param->debug_info());
  iter_param->debug_info()->set_name(iter_param_name);
  CNodePtr cond_apply = list_header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});

  // Second header parameter "list" accumulates the result; the top block calls the header with (iter, []).
  ParameterPtr list_param = list_header_block->func_graph()->add_parameter();
  constexpr auto list_param_name = "list";
  list_param->set_name(list_param_name);
  MS_EXCEPTION_IF_NULL(list_param->debug_info());
  list_param->debug_info()->set_name(list_param_name);
  auto empty_list = std::vector<ValuePtr>();
  AnfNodePtr empty_list_node = NewValueNode(std::make_shared<ValueList>(empty_list));
  top_block->Jump(list_header_block, {iter_anf_node, empty_list_node});

  // Create body graph: next(iter) yields a pair of (item, advanced iterator).
  FunctionBlockPtr list_body_block =
    MakeFunctionBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  list_body_block->AddPrevBlock(list_header_block);
  AnfNodePtr op_next = top_block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  MS_EXCEPTION_IF_NULL(list_body_block->func_graph());
  CNodePtr next_apply = list_body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
  AnfNodePtr op_getitem = top_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  CNodePtr item_apply =
    list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(0))});
  CNodePtr new_iter =
    list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(1))});

  // Save the `target` in a variable.
  py::object gen_target_node = python_adapter::GetPyObjAttr(generator_node, "target");
  WriteAssignVars(list_body_block, gen_target_node, item_apply);

  // Apply `ifs`/`elt` to get the next list value, then loop back to the header.
  auto ifs_new_list = ParseListCompIfs(list_body_block, list_param, node, generator_node);
  list_body_block->Jump(list_header_block, {new_iter, ifs_new_list});

  // Create after graph.
  FunctionBlockPtr list_after_block =
    MakeFunctionBlock(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  list_after_block->AddPrevBlock(list_header_block);
  // Return the list in after graph.
  MS_EXCEPTION_IF_NULL(list_after_block->func_graph());
  list_after_block->func_graph()->set_output(list_param);

  // Run the branches: body while hasnext(iter), otherwise after.
  (void)list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block);

  top_block->Mature();
  list_header_block->Mature();
  list_body_block->Mature();
  list_after_block->Mature();
  return top_block;
}
3460 
ParseListCompIfs(const FunctionBlockPtr & list_body_block,const ParameterPtr & list_param,const py::object & node,const py::object & generator_node)3461 AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
3462                                     const py::object &node, const py::object &generator_node) {
3463   // Handle ifs attribute.
3464   py::list ifs_node = python_adapter::GetPyObjAttr(generator_node, "ifs");
3465   AnfNodePtr ifs_bool_node;
3466   if (ifs_node.empty()) {
3467     ifs_bool_node = NewValueNode(true);
3468   } else {
3469     ifs_bool_node = ProcessBoolOpValueList(list_body_block, ifs_node, AST_SUB_TYPE_AND);
3470   }
3471 
3472   // Create if-true graph.
3473   FunctionBlockPtr if_true_block =
3474     MakeFunctionBlock(std::make_shared<TraceIfStmtTrueBranch>(list_body_block->func_graph()->debug_info()));
3475   if_true_block->AddPrevBlock(list_body_block);
3476   // Handle elt attribute in body block.
3477   py::object elt_obj = python_adapter::GetPyObjAttr(node, "elt");
3478   AnfNodePtr elt_node = ParseExprNode(list_body_block, elt_obj);
3479   // Append the element.
3480   std::vector<AnfNodePtr> list_vec;
3481   AnfNodePtr make_list_op = list_body_block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
3482   (void)list_vec.emplace_back(make_list_op);
3483   (void)list_vec.emplace_back(elt_node);
3484   CNodePtr list_app = list_body_block->func_graph()->NewCNodeInOrder(std::move(list_vec));
3485   std::string add_module_name = "mindspore.ops.composite.multitype_ops.add_impl";
3486   ValuePtr add_op = prim::GetPythonOps("add", add_module_name);
3487   CNodePtr new_list = list_body_block->func_graph()->NewCNodeInOrder({NewValueNode(add_op), list_param, list_app});
3488   // Return new list in true branch graph.
3489   if_true_block->func_graph()->set_output(new_list);
3490 
3491   // Create if-false graph.
3492   FunctionBlockPtr if_false_block =
3493     MakeFunctionBlock(std::make_shared<TraceIfStmtFalseBranch>(list_body_block->func_graph()->debug_info()));
3494   if_false_block->AddPrevBlock(list_body_block);
3495   // Return original list in false branch graph.
3496   MS_EXCEPTION_IF_NULL(if_false_block->func_graph());
3497   if_false_block->func_graph()->set_output(list_param);
3498 
3499   // We don't want to create a header graph, where to get and wrap the result of Switch().
3500   // So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky.
3501   (void)list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
3502   // Output is Switch() result, i.e. updated list.
3503   auto switch_apply_node = list_body_block->func_graph()->output();
3504   auto ifs_new_list = switch_apply_node;
3505   // Since we call ConditionalJump() above, to reset the Return as null before call Jump().
3506   list_body_block->func_graph()->set_return(nullptr);
3507   if_true_block->Mature();
3508   if_false_block->Mature();
3509   return ifs_new_list;
3510 }
3511 
3512 // A ListComp contains: `elt` and `generators`.
3513 // `generators` contains: `target`, `iter` and `ifs`.
3514 // For example:
3515 // [x * x for x in range(0, 10) if x % 2 == 0]
3516 // It is compiled to be following statement:
3517 // list = []
3518 // for x in range(0, 10):
3519 //    if x % 2 == 0:
3520 //        list.append(x * x)
3521 // return list
ParseListComp(const FunctionBlockPtr & block,const py::object & node)3522 AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object &node) {
3523   MS_LOG(DEBUG) << "Process ast ListComp";
3524   MS_EXCEPTION_IF_NULL(block);
3525 
3526   // Handle generators attribute.
3527   py::list generators_node = python_adapter::GetPyObjAttr(node, "generators");
3528   if (generators_node.size() != 1) {
3529     MS_EXCEPTION(TypeError) << "The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp, but got "
3530                             << generators_node.size() << " comprehensions.";
3531   }
3532   py::object generator_node = generators_node[0];
3533   auto generator_node_type = ast_->GetNodeType(generator_node);
3534   auto generator_node_name = generator_node_type->node_name();
3535   constexpr auto comprehension_name = "comprehension";
3536   if (generator_node_name != comprehension_name) {
3537     MS_LOG(INTERNAL_EXCEPTION) << "Generator node name should be " << comprehension_name << ", but got "
3538                                << generator_node_name;
3539   }
3540 
3541   // Parse ListComp's `iter` and add `elt` in it.
3542   auto top_block = ParseListCompIter(block, node, generator_node);
3543 
3544   // Call the top graph and return the list.
3545   auto call_function_node = NewValueNode(top_block->func_graph());
3546   std::vector<AnfNodePtr> func_call_nodes;
3547   func_call_nodes.push_back(call_function_node);
3548   MS_EXCEPTION_IF_NULL(block->func_graph());
3549   AnfNodePtr output = block->func_graph()->NewCNodeInOrder(std::move(func_call_nodes));
3550   return output;
3551 }
3552 
ParseDictCompIter(const FunctionBlockPtr & block,const py::object & node,const py::object & generator_node)3553 FunctionBlockPtr Parser::ParseDictCompIter(const FunctionBlockPtr &block, const py::object &node,
3554                                            const py::object &generator_node) {
3555   // Create a header block.
3556   MS_EXCEPTION_IF_NULL(block->func_graph());
3557   FunctionBlockPtr top_block = MakeFunctionBlock(std::make_shared<TraceDictComp>(block->func_graph()->debug_info()));
3558   top_block->AddPrevBlock(block);
3559   // Handle iter attribute.
3560   py::object iter_node = python_adapter::GetPyObjAttr(generator_node, "iter");
3561   AnfNodePtr origin_iter_anf_node = ParseExprNode(block, iter_node);
3562   MS_EXCEPTION_IF_NULL(origin_iter_anf_node);
3563   AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
3564   MS_EXCEPTION_IF_NULL(op_iter);
3565   AnfNodePtr iter_anf_node = block->func_graph()->NewCNodeInOrder({op_iter, origin_iter_anf_node});
3566   MS_EXCEPTION_IF_NULL(iter_anf_node);
3567 
3568   // Create header graph.
3569   FunctionBlockPtr dict_header_block =
3570     MakeFunctionBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
3571   AnfNodePtr op_hasnext = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
3572   MS_EXCEPTION_IF_NULL(dict_header_block->func_graph());
3573   ParameterPtr iter_param = dict_header_block->func_graph()->add_parameter();
3574   constexpr auto iter_param_name = "iter";
3575   iter_param->set_name(iter_param_name);
3576   MS_EXCEPTION_IF_NULL(iter_param->debug_info());
3577   iter_param->debug_info()->set_name(iter_param_name);
3578   CNodePtr cond_apply = dict_header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
3579 
3580   // Call the header graph with iter.
3581   ParameterPtr dict_param = dict_header_block->func_graph()->add_parameter();
3582   constexpr auto dict_param_name = "dict";
3583   dict_param->set_name(dict_param_name);
3584   MS_EXCEPTION_IF_NULL(dict_param->debug_info());
3585   dict_param->debug_info()->set_name(dict_param_name);
3586   auto empty_key = std::vector<ValuePtr>();
3587   AnfNodePtr empty_key_node = NewValueNode(std::make_shared<ValueTuple>(empty_key));
3588   auto empty_value = std::vector<ValuePtr>();
3589   AnfNodePtr empty_value_node = NewValueNode(std::make_shared<ValueTuple>(empty_value));
3590   auto make_dict_op = top_block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
3591   auto empty_dict_node = top_block->func_graph()->NewCNodeInOrder({make_dict_op, empty_key_node, empty_value_node});
3592   top_block->Jump(dict_header_block, {iter_anf_node, empty_dict_node});
3593 
3594   // Create body graph.
3595   FunctionBlockPtr dict_body_block =
3596     MakeFunctionBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
3597   dict_body_block->AddPrevBlock(dict_header_block);
3598   AnfNodePtr op_next = top_block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
3599   MS_EXCEPTION_IF_NULL(dict_body_block->func_graph());
3600   CNodePtr next_apply = dict_body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
3601   AnfNodePtr op_getitem = top_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
3602   CNodePtr item_apply =
3603     dict_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(0))});
3604   CNodePtr new_iter =
3605     dict_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(1))});
3606 
3607   // Save the `target` in a variable.
3608   py::object gen_target_node = python_adapter::GetPyObjAttr(generator_node, "target");
3609   WriteAssignVars(dict_body_block, gen_target_node, item_apply);
3610 
3611   auto ifs_new_dic = ParseDictCompIfs(dict_body_block, dict_param, node, generator_node);
3612   dict_body_block->Jump(dict_header_block, {new_iter, ifs_new_dic});
3613 
3614   // Create after graph.
3615   FunctionBlockPtr dict_after_block =
3616     MakeFunctionBlock(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
3617   dict_after_block->AddPrevBlock(dict_header_block);
3618   // Return the dict in after graph.
3619   MS_EXCEPTION_IF_NULL(dict_after_block->func_graph());
3620   dict_after_block->func_graph()->set_output(dict_param);
3621 
3622   // Run the branches.
3623   (void)dict_header_block->ConditionalJump(cond_apply, dict_body_block, dict_after_block);
3624 
3625   top_block->Mature();
3626   dict_header_block->Mature();
3627   dict_body_block->Mature();
3628   dict_after_block->Mature();
3629   return top_block;
3630 }
3631 
ParseDictCompIfs(const FunctionBlockPtr & dict_body_block,const ParameterPtr & dict_param,const py::object & node,const py::object & generator_node)3632 AnfNodePtr Parser::ParseDictCompIfs(const FunctionBlockPtr &dict_body_block, const ParameterPtr &dict_param,
3633                                     const py::object &node, const py::object &generator_node) {
3634   // Handle ifs attribute.
3635   py::list ifs_node = python_adapter::GetPyObjAttr(generator_node, "ifs");
3636   AnfNodePtr ifs_bool_node;
3637   if (ifs_node.empty()) {
3638     ifs_bool_node = NewValueNode(true);
3639   } else {
3640     ifs_bool_node = ProcessBoolOpValueList(dict_body_block, ifs_node, AST_SUB_TYPE_AND);
3641   }
3642 
3643   // Create if-true graph.
3644   MS_EXCEPTION_IF_NULL(dict_body_block->func_graph());
3645   FunctionBlockPtr if_true_block =
3646     MakeFunctionBlock(std::make_shared<TraceIfStmtTrueBranch>(dict_body_block->func_graph()->debug_info()));
3647   if_true_block->AddPrevBlock(dict_body_block);
3648   // Handle key, value attribute in body block.
3649   py::object key_obj = python_adapter::GetPyObjAttr(node, "key");
3650   AnfNodePtr key_node = ParseExprNode(dict_body_block, key_obj);
3651   py::object value_obj = python_adapter::GetPyObjAttr(node, "value");
3652   AnfNodePtr value_node = ParseExprNode(dict_body_block, value_obj);
3653   // update dict.
3654   std::vector<AnfNodePtr> key_vec;
3655   std::vector<AnfNodePtr> value_vec;
3656   std::vector<AnfNodePtr> dict_vec;
3657   AnfNodePtr make_dict_op = dict_body_block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
3658   AnfNodePtr make_tuple_op = dict_body_block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
3659   (void)key_vec.emplace_back(make_tuple_op);
3660   (void)key_vec.emplace_back(key_node);
3661   CNodePtr key_app = dict_body_block->func_graph()->NewCNodeInOrder(std::move(key_vec));
3662   (void)value_vec.emplace_back(make_tuple_op);
3663   (void)value_vec.emplace_back(value_node);
3664   CNodePtr value_app = dict_body_block->func_graph()->NewCNodeInOrder(std::move(value_vec));
3665   (void)dict_vec.emplace_back(make_dict_op);
3666   (void)dict_vec.emplace_back(key_app);
3667   (void)dict_vec.emplace_back(value_app);
3668   CNodePtr dict_app = dict_body_block->func_graph()->NewCNodeInOrder(std::move(dict_vec));
3669   std::string add_module_name = "mindspore.ops.composite.multitype_ops.add_impl";
3670   ValuePtr add_op = prim::GetPythonOps("add", add_module_name);
3671   CNodePtr new_dict = dict_body_block->func_graph()->NewCNodeInOrder({NewValueNode(add_op), dict_param, dict_app});
3672   // Return new dict in true branch graph.
3673   if_true_block->func_graph()->set_output(new_dict);
3674 
3675   // Create if-false graph.
3676   FunctionBlockPtr if_false_block =
3677     MakeFunctionBlock(std::make_shared<TraceIfStmtFalseBranch>(dict_body_block->func_graph()->debug_info()));
3678   if_false_block->AddPrevBlock(dict_body_block);
3679   // Return original dict in false branch graph.
3680   MS_EXCEPTION_IF_NULL(if_false_block->func_graph());
3681   if_false_block->func_graph()->set_output(dict_param);
3682 
3683   // We don't want to create a header graph, where to get and wrap the result of Switch().
3684   // So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky.
3685   (void)dict_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
3686   // Output is Switch() result, i.e. updated dict.
3687   auto switch_apply_node = dict_body_block->func_graph()->output();
3688   auto ifs_new_dict = switch_apply_node;
3689   // Since we call ConditionalJump() above, to reset the Return as null before call Jump().
3690   dict_body_block->func_graph()->set_return(nullptr);
3691   if_true_block->Mature();
3692   if_false_block->Mature();
3693   return ifs_new_dict;
3694 }
3695 
// A DictComp contains: `key`, `value` and `generators`.
// `generators` contains: `target`, `iter` and `ifs`.
// For example:
// {x: y for y, x in some_dict.items() if x > 1}
// It is compiled into the following statements:
// dict = {}
// for y, x in some_dict.items():
//    if x > 1:
//        dict[x] = y
// return dict
ParseDictComp(const FunctionBlockPtr & block,const py::object & node)3706 AnfNodePtr Parser::ParseDictComp(const FunctionBlockPtr &block, const py::object &node) {
3707   MS_LOG(DEBUG) << "Process ast DictComp";
3708   MS_EXCEPTION_IF_NULL(block);
3709 
3710   // Handle generators attribute.
3711   py::list generators_node = python_adapter::GetPyObjAttr(node, "generators");
3712   if (generators_node.size() != 1) {
3713     MS_EXCEPTION(TypeError) << "The 'generators' supports 1 'comprehension' in DictComp/GeneratorExp, but got "
3714                             << generators_node.size() << " comprehensions.";
3715   }
3716   py::object generator_node = generators_node[0];
3717   auto generator_node_type = ast_->GetNodeType(generator_node);
3718   auto generator_node_name = generator_node_type->node_name();
3719   constexpr auto comprehension_name = "comprehension";
3720   if (generator_node_name != comprehension_name) {
3721     MS_LOG(INTERNAL_EXCEPTION) << "Generator node name should be " << comprehension_name << ", but got "
3722                                << generator_node_name;
3723   }
3724 
3725   // Parse DictComp's `iter` and add `key`, `value` in it.
3726   auto top_block = ParseDictCompIter(block, node, generator_node);
3727 
3728   // Call the top graph and return the dict.
3729   auto call_function_node = NewValueNode(top_block->func_graph());
3730   std::vector<AnfNodePtr> func_call_nodes;
3731   func_call_nodes.push_back(call_function_node);
3732   MS_EXCEPTION_IF_NULL(block->func_graph());
3733   AnfNodePtr output = block->func_graph()->NewCNodeInOrder(std::move(func_call_nodes));
3734   return output;
3735 }
3736 
ParseJoinedStr(const FunctionBlockPtr & block,const py::object & node)3737 AnfNodePtr Parser::ParseJoinedStr(const FunctionBlockPtr &block, const py::object &node) {
3738   MS_LOG(DEBUG) << "Process ast JoinedStr.";
3739   TraceGuard trace_guard(GetLocation(node));
3740   MS_EXCEPTION_IF_NULL(block);
3741   const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(node));
3742   py::list py_values = python_adapter::GetPyObjAttr(node, "values");
3743   std::vector<AnfNodePtr> value_nodes{NewValueNode(prim::kPrimJoinedStr)};
3744   bool has_interpret_node = false;
3745   for (size_t i = 0; i < py_values.size(); ++i) {
3746     AnfNodePtr str_value = ParseExprNode(block, py_values[i]);
3747     // If exist interpret node in JoinedStr, all object in py_values will convert to interpret node.
3748     // Need to parse all elements in py_values in order to put them in local param dict.
3749     if (!has_interpret_node && fallback::CheckInterpretInput(str_value)) {
3750       has_interpret_node = true;
3751     }
3752     (void)value_nodes.emplace_back(str_value);
3753   }
3754   // JoinedStr can not get expr_src for their separate element. So, we convert the whole str directly.
3755   if (has_interpret_node) {
3756     return MakeInterpretNode(block, value_nodes[1], script_text);
3757   }
3758   auto func_graph = block->func_graph();
3759   MS_EXCEPTION_IF_NULL(func_graph);
3760   AnfNodePtr output = func_graph->NewCNodeInOrder(std::move(value_nodes));
3761   return output;
3762 }
3763 
ParseFormattedValue(const FunctionBlockPtr & block,const py::object & node)3764 AnfNodePtr Parser::ParseFormattedValue(const FunctionBlockPtr &block, const py::object &node) {
3765   MS_LOG(DEBUG) << "Process ast FormattedValue.";
3766   TraceGuard trace_guard(GetLocation(node));
3767   MS_EXCEPTION_IF_NULL(block);
3768   py::object value_object = python_adapter::GetPyObjAttr(node, "value");
3769   AnfNodePtr value_node = ParseExprNode(block, value_object);
3770   return value_node;
3771 }
3772 
ParseStarred(const FunctionBlockPtr & block,const py::object & node)3773 AnfNodePtr Parser::ParseStarred(const FunctionBlockPtr &block, const py::object &node) {
3774   MS_LOG(DEBUG) << "Process ast Starred.";
3775   TraceGuard trace_guard(GetLocation(node));
3776   MS_EXCEPTION_IF_NULL(block);
3777   py::object value_object = python_adapter::GetPyObjAttr(node, "value");
3778   AnfNodePtr value_node = ParseExprNode(block, value_object);
3779   AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
3780   auto func = block->func_graph();
3781   MS_EXCEPTION_IF_NULL(func);
3782   CNodePtr iterated_node = func->NewCNodeInOrder({op_iter, value_node});
3783   auto prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
3784   CNodePtr unpack_node = func->NewCNodeInOrder({NewValueNode(prim), iterated_node});
3785   return unpack_node;
3786 }
3787 
HandleAssignStarred(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node)3788 void Parser::HandleAssignStarred(const FunctionBlockPtr &block, const py::object &target,
3789                                  const AnfNodePtr &assigned_node) {
3790   MS_EXCEPTION_IF_NULL(block);
3791   MS_EXCEPTION_IF_NULL(assigned_node);
3792   py::object value_object = python_adapter::GetPyObjAttr(target, "value");
3793   py::str name = python_adapter::GetPyObjAttr(value_object, "id");
3794   std::string name_id = name;
3795   MS_EXCEPTION_IF_NULL(assigned_node->debug_info());
3796   assigned_node->debug_info()->set_name(name_id);
3797   MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
3798   block->WriteVariable(name_id, assigned_node);
3799 }
3800 
HandleAssignName(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node) const3801 void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target,
3802                               const AnfNodePtr &assigned_node) const {
3803   MS_EXCEPTION_IF_NULL(block);
3804   MS_EXCEPTION_IF_NULL(assigned_node);
3805   py::str name = python_adapter::GetPyObjAttr(target, "id");
3806   std::string name_id = name;
3807 
3808   MS_EXCEPTION_IF_NULL(assigned_node->debug_info());
3809   assigned_node->debug_info()->set_name(name_id);
3810   // Set the debug name of the constant graph
3811   if (IsValueNode<FuncGraph>(assigned_node)) {
3812     // The value should be graph
3813     auto fg = GetValueNode<FuncGraphPtr>(assigned_node);
3814     MS_EXCEPTION_IF_NULL(fg->debug_info());
3815     if (fg->debug_info()->name().empty()) {
3816       fg->debug_info()->set_name(name_id);
3817     }
3818   }
3819   MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
3820   block->WriteVariable(name_id, assigned_node);
3821 }
3822 
HandleAssignTupleWithStarredExpression(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node,const std::vector<int64_t> & positions)3823 void Parser::HandleAssignTupleWithStarredExpression(const FunctionBlockPtr &block, const py::object &target,
3824                                                     const AnfNodePtr &assigned_node,
3825                                                     const std::vector<int64_t> &positions) {
3826   // Process assigned_node
3827   auto func = block->func_graph();
3828   MS_EXCEPTION_IF_NULL(func);
3829   AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
3830   CNodePtr iterated_node = func->NewCNodeInOrder({op_iter, assigned_node});
3831   auto starred_unpack_prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
3832   CNodePtr unpack_node = func->NewCNodeInOrder({NewValueNode(starred_unpack_prim), iterated_node});
3833 
3834   py::list items = python_adapter::GetPyObjAttr(target, "elts");
3835   for (size_t i = 0; i < items.size(); i++) {
3836     py::object elt = items[i];
3837     auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elt)));
3838     if (elt_type != AST_SUB_TYPE_STARRED) {
3839       std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
3840       ValuePtr op = prim::GetPythonOps("getitem", module_name);
3841       std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(op), unpack_node, NewValueNode(positions[i])};
3842       AnfNodePtr tuple_get_item = func->NewCNodeInOrder(tuple_get_item_inputs);
3843       MS_LOG(DEBUG) << "Assign name: `" << py::str(elt) << "` to node: " << tuple_get_item->DebugString();
3844       WriteAssignVars(block, elt, tuple_get_item);
3845     } else {
3846       auto starred_get_item_prim = std::make_shared<prim::StarredGetItem>(NAMED_METAGRAPH_STARRED_GET_ITEM);
3847       std::vector<AnfNodePtr> starred_get_item_inputs{NewValueNode(starred_get_item_prim), unpack_node,
3848                                                       NewValueNode(positions[i]),
3849                                                       NewValueNode(SizeToLong(items.size()))};
3850       AnfNodePtr starred_get_item = func->NewCNodeInOrder(starred_get_item_inputs);
3851       MS_LOG(DEBUG) << "Assign name: `" << py::str(elt) << "` to node: " << starred_get_item->DebugString();
3852       WriteAssignVars(block, elt, starred_get_item);
3853     }
3854   }
3855 }
3856 
HandleAssignTupleOrList(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node)3857 void Parser::HandleAssignTupleOrList(const FunctionBlockPtr &block, const py::object &target,
3858                                      const AnfNodePtr &assigned_node) {
3859   MS_EXCEPTION_IF_NULL(block);
3860   AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
3861   py::list items = python_adapter::GetPyObjAttr(target, "elts");
3862 
3863   // Record the position with targets.
3864   size_t target_starred_num = 0;
3865   size_t starred_pos = items.size();
3866   std::vector<int64_t> positions(items.size(), 0);
3867   for (size_t i = 0; i < items.size(); i++) {
3868     py::object elt = items[i];
3869     auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elt)));
3870     if (elt_type == AST_SUB_TYPE_STARRED) {
3871       target_starred_num++;
3872       if (target_starred_num > 1) {
3873         MS_LOG(EXCEPTION) << "SyntaxError: " << target_starred_num << " starred expressions in assignment.";
3874       }
3875       starred_pos = i;
3876       positions[i] = i;
3877     } else {
3878       if (i > starred_pos) {
3879         positions[i] = i - items.size();
3880       } else {
3881         positions[i] = i;
3882       }
3883     }
3884   }
3885   auto func = block->func_graph();
3886   MS_EXCEPTION_IF_NULL(func);
3887 
3888   if (target_starred_num == 0) {
3889     for (size_t i = 0; i < items.size(); i++) {
3890       // Use the Primitive replace the operation resolve node (getitem),
3891       // because the getitem will eventually be converted to Primitive node
3892       CNodePtr item_apply = func->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
3893       py::object elt = items[i];
3894       WriteAssignVars(block, elt, item_apply);
3895     }
3896     return;
3897   }
3898 
3899   // Process AssignTuple with starred expression.
3900   // a, *b, c = x    // targets(a, *b, c) = assign(x)
3901   HandleAssignTupleWithStarredExpression(block, target, assigned_node, positions);
3902 }
3903 
IsClassParameterMember(const py::object & target_obj,const AnfNodePtr & target_node) const3904 bool Parser::IsClassParameterMember(const py::object &target_obj, const AnfNodePtr &target_node) const {
3905   auto attr_name = target_obj.attr("attr").cast<std::string>();
3906   if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
3907     return false;
3908   }
3909 
3910   auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
3911   return (py::hasattr(obj, "__parameter__"));
3912 }
3913 
HandleAssignClassParameterMember(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & value_node)3914 bool Parser::HandleAssignClassParameterMember(const FunctionBlockPtr &block, const py::object &target,
3915                                               const AnfNodePtr &value_node) {
3916   MS_EXCEPTION_IF_NULL(block);
3917   // Now only support the self.xx = xxxxx, can't support x.y = xxxx
3918   AnfNodePtr target_node = ParseExprNode(block, target);
3919   if (target_node == nullptr) {
3920     return false;
3921   }
3922 
3923   if (!IsClassParameterMember(target, target_node)) {
3924     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
3925     if (!allow_fallback_runtime) {
3926       auto attr_name = target.attr("attr").cast<std::string>();
3927       std::string var_name = "self." + attr_name;
3928       auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
3929       auto obj_type = obj.attr("__class__").attr("__name__");
3930       MS_EXCEPTION(TypeError) << "In JIT strict mode, if need to modify a member attribute of a class with " << var_name
3931                               << ", the member attribute must be of the Parameter type. But got '"
3932                               << py::str(obj).cast<std::string>() << "' with type '"
3933                               << py::str(obj_type).cast<std::string>()
3934                               << "'. You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' "
3935                               << "to enable the JIT lax mode to support the current syntax.\n\n"
3936                               << trace::GetDebugInfoStr(target_node->debug_info());
3937     }
3938     MS_LOG(DEBUG) << "Erase unused node: " << target_node->DebugString();
3939     block->func_graph()->EraseUnusedNodeInOrder(target_node);
3940     return false;
3941   }
3942   block->SetStateAssign(target_node, value_node);
3943   return true;
3944 }
3945 
MakeSetAttrNode(const FunctionBlockPtr & block,const AnfNodePtr & target_node,const AnfNodePtr & value_node,const std::string & target_id_str,const std::string & attr_str)3946 void Parser::MakeSetAttrNode(const FunctionBlockPtr &block, const AnfNodePtr &target_node, const AnfNodePtr &value_node,
3947                              const std::string &target_id_str, const std::string &attr_str) {
3948   std::vector<AnfNodePtr> setattr_node_inputs{NewValueNode(prim::kPrimSetAttr)};
3949   (void)setattr_node_inputs.emplace_back(target_node);
3950   (void)setattr_node_inputs.emplace_back(NewValueNode(attr_str));
3951   (void)setattr_node_inputs.emplace_back(value_node);
3952   auto fg = block->func_graph();
3953   MS_EXCEPTION_IF_NULL(fg);
3954   auto setattr_node = fg->NewCNodeInOrder(setattr_node_inputs);
3955 
3956   // Update setattr_nodes_map.
3957   auto iter = setattr_nodes_map_.find(target_id_str);
3958   if (iter == setattr_nodes_map_.end()) {
3959     auto attr_map = std::map<std::string, AnfNodePtr>();
3960     (void)attr_map.emplace(std::make_pair(attr_str, setattr_node));
3961     (void)setattr_nodes_map_.emplace(std::make_pair(target_id_str, attr_map));
3962   } else {
3963     // If found setattr node before, set it as new setattr node's input.
3964     auto iter_attr = iter->second.find(attr_str);
3965     if (iter_attr != iter->second.end()) {
3966       auto prev_node = iter_attr->second;
3967       MS_EXCEPTION_IF_NULL(prev_node);
3968       auto prev_node_fg = prev_node->func_graph();
3969       MS_EXCEPTION_IF_NULL(prev_node_fg);
3970       if (prev_node_fg == fg) {
3971         // Only add to new setattr node input if two nodes is in the same graph.
3972         setattr_node->add_input(iter_attr->second);
3973       }
3974     }
3975     // Force update the setattr node to keep the newest one.
3976     (void)iter->second.insert_or_assign(attr_str, setattr_node);
3977   }
3978   block->AddIsolatedNode(setattr_node);
3979 }
3980 
HandleAssignClassMember(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & value_node)3981 void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target,
3982                                      const AnfNodePtr &value_node) {
3983   MS_EXCEPTION_IF_NULL(block);
3984   const py::object target_obj = python_adapter::GetPyObjAttr(target, "value");
3985   TraceGuard trace_guard(GetLocation(target_obj));
3986   std::string target_id_str;
3987   AnfNodePtr target_node = nullptr;
3988   auto node_type = ast()->GetNodeType(target_obj);
3989   const std::string &node_type_name = node_type->node_name();
3990   MS_LOG(DEBUG) << "node_type_name: " << node_type_name << ", target: " << py::str(target);
3991   if (node_type_name == "Attribute") {
3992     // Prepare for setattr with nested getattr target, parse getattr firstly.
3993     target_node = ParseExprNode(block, target_obj);
3994     target_id_str = GetLocation(target_obj)->expr_src();
3995   } else if (node_type_name == "Call") {
3996     // Prepare for setattr with nested 'getattr' call target, parse 'getattr' call firstly.
3997     target_node = ParseExprNode(block, target_obj);
3998     target_id_str = GetLocation(target_obj)->expr_src();
3999   } else if (node_type_name == "Name") {
4000     if (!py::hasattr(target_obj, "id")) {
4001       MS_LOG(INTERNAL_EXCEPTION) << "Wrong ast, target: " << target;
4002     }
4003     const py::object id_obj = python_adapter::GetPyObjAttr(target_obj, "id");
4004     target_id_str = id_obj.cast<std::string>();
4005     if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE && target_id_str == "self") {
4006       const auto &interpreted_obj = std::make_shared<InterpretedObject>(ast()->obj());
4007       target_node = NewValueNode(interpreted_obj);
4008     } else {
4009       py::object setattr_obj;
4010       try {
4011         py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, target_id_str);
4012         constexpr size_t value_index = 2;
4013         setattr_obj = namespace_info[value_index];
4014       } catch (const std::exception &e) {
4015         MS_LOG(DEBUG) << target_id_str << " is not supported in JIT Fallback. Original steps are processing instead.";
4016         setattr_obj = py::none();
4017       }
4018       // convert target node in setattr to convert getattr after it later.
4019       if (!py::isinstance<py::none>(setattr_obj)) {
4020         const auto &interpreted_obj = std::make_shared<InterpretedObject>(setattr_obj);
4021         target_node = NewValueNode(interpreted_obj);
4022       } else {
4023         target_node = ParseExprNode(block, target_obj);
4024       }
4025     }
4026   }
4027   if (target_node == nullptr) {
4028     MS_LOG(EXCEPTION) << "In graph mode, only attribute and name of class members can be assigned. But got "
4029                       << node_type_name << ".";
4030   }
4031   const auto &attr_str = python_adapter::GetPyObjAttr(target, "attr").cast<std::string>();
4032   MS_LOG(DEBUG) << "target node: " << target_node->DebugString() << ", target name: " << target_id_str
4033                 << ", attr: " << attr_str;
4034 
4035   MakeSetAttrNode(block, target_node, value_node, target_id_str, attr_str);
4036 }
4037 
MakeSetitemNode(const FunctionBlockPtr & block,const py::object & value_obj,const py::object & slice_obj,const AnfNodePtr & assigned_node,const AnfNodePtr & value_node)4038 CNodePtr Parser::MakeSetitemNode(const FunctionBlockPtr &block, const py::object &value_obj,
4039                                  const py::object &slice_obj, const AnfNodePtr &assigned_node,
4040                                  const AnfNodePtr &value_node) {
4041   AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
4042   auto value_id = python_adapter::GetPyObjAttr(value_obj, "id");
4043   AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
4044   auto str_setitem = std::make_shared<StringImm>("__setitem__");
4045   if (!py::isinstance<py::none>(value_id)) {
4046     py::object value_obj = GetValuePythonObject(value_id);
4047     if (!py::isinstance<py::none>(value_obj)) {
4048       AnfNodePtr setitem_node =
4049         block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), value_node, NewValueNode(str_setitem)});
4050       setitem_node->set_user_data<py::object>("__setitem__", std::make_shared<py::object>(value_obj));
4051       return block->func_graph()->NewCNodeInOrder({setitem_node, slice_node, assigned_node});
4052     }
4053   }
4054   return block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
4055 }
4056 
HandleAssignSubscript(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node)4057 void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target,
4058                                    const AnfNodePtr &assigned_node) {
4059   MS_EXCEPTION_IF_NULL(block);
4060   py::object value_obj = python_adapter::GetPyObjAttr(target, "value");
4061   py::object slice_obj = python_adapter::GetPyObjAttr(target, "slice");
4062   AnfNodePtr value_node = ParseExprNode(block, value_obj);
4063   MS_EXCEPTION_IF_NULL(block->func_graph());
4064   auto setitem_app = MakeSetitemNode(block, value_obj, slice_obj, assigned_node, value_node);
4065   // Getitem apply should return the sequence data structure itself
4066   std::string var_name;
4067   if (ast_->IsClassMemberOfSelf(value_obj)) {
4068     auto attr_name = value_obj.attr("attr").cast<std::string>();
4069     var_name = "self." + attr_name;
4070     if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
4071       MS_EXCEPTION(TypeError) << "'" << var_name
4072                               << "' should be initialized in the '__init__' function before subscript.\n\n"
4073                               << trace::GetDebugInfoStr(value_node->debug_info());
4074     }
4075     auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
4076     if (py::hasattr(obj, PYTHON_CELL_AS_LIST) || py::hasattr(obj, PYTHON_CELL_AS_DICT)) {
4077       MS_EXCEPTION(TypeError) << "CellList or CellDict object " << py::str(obj).cast<std::string>()
4078                               << " is not support to do assign subscript.";
4079     }
4080     const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
4081     if (!allow_fallback_runtime) {
4082       if (!py::hasattr(obj, "__parameter__")) {
4083         auto obj_type = obj.attr("__class__").attr("__name__");
4084         MS_EXCEPTION(TypeError) << "When JIT_SYNTAX_LEVEL is not set to LAX" << var_name
4085                                 << " should be initialized as a 'Parameter' in the '__init__' function"
4086                                 << " to perform assign subscript, but got: " << py::str(obj).cast<std::string>()
4087                                 << "' with type '" << py::str(obj_type).cast<std::string>() << "'.\n\n"
4088                                 << trace::GetDebugInfoStr(value_node->debug_info());
4089       }
4090     }
4091     block->WriteVariable(var_name, setitem_app);
4092     return;
4093   }
4094   if (AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, value_obj))) ==
4095       AST_SUB_TYPE_SUBSCRIPT) {
4096     if (IsSubscriptReferenceType(value_obj)) {
4097       HandleAssignSubscript(block, value_obj, setitem_app);
4098       return;
4099     }
4100   }
4101   if (py::hasattr(value_obj, "id")) {
4102     var_name = value_obj.attr("id").cast<std::string>();
4103   }
4104   block->WriteVariable(var_name, setitem_app);
4105 }
4106 
WriteAssignVars(const FunctionBlockPtr & block,const py::object & target_object,const AnfNodePtr & value_node)4107 void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object,
4108                              const AnfNodePtr &value_node) {
4109   MS_EXCEPTION_IF_NULL(value_node);
4110   auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
4111   MS_LOG(DEBUG) << "target_object: " << target_object << ", value_node: " << value_node->DebugString()
4112                 << ", ast_type: " << ast_type;
4113   if (ast_type == AST_SUB_TYPE_NAME) {
4114     HandleAssignName(block, target_object, value_node);
4115   } else if (ast_type == AST_SUB_TYPE_TUPLE || ast_type == AST_SUB_TYPE_LIST) {
4116     HandleAssignTupleOrList(block, target_object, value_node);
4117   } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
4118     HandleAssignSubscript(block, target_object, value_node);
4119   } else if (ast_->IsClassMemberOfSelf(target_object)) {
4120     if (HandleAssignClassParameterMember(block, target_object, value_node)) {
4121       return;
4122     }
4123     HandleAssignClassMember(block, target_object, value_node);
4124   } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
4125     HandleAssignClassMember(block, target_object, value_node);
4126   } else if (ast_type == AST_SUB_TYPE_STARRED) {
4127     HandleAssignStarred(block, target_object, value_node);
4128   } else {
4129     TraceGuard trace_guard(GetLocation(target_object));
4130     MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got "
4131                             << target_object.get_type()
4132                             << ".\nMore details please refer to syntax support at https://www.mindspore.cn";
4133   }
4134 }
4135 
IsScriptInParams(const std::string & script_text,const py::dict & global_dict,const std::map<std::string,AnfNodePtr> & local_keys,const FuncGraphPtr & func_graph) const4136 bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
4137                               const std::map<std::string, AnfNodePtr> &local_keys,
4138                               const FuncGraphPtr &func_graph) const {
4139   MS_EXCEPTION_IF_NULL(func_graph);
4140   // Check global parameters.
4141   if (global_dict.contains(script_text)) {
4142     MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in global params.";
4143     return true;
4144   }
4145 
4146   // Check local parameters.
4147   if (local_keys.find(script_text) != local_keys.end()) {
4148     MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in local params.";
4149     return true;
4150   }
4151   return false;
4152 }
4153 
// If `value_node` is flagged for interpretation, it is replaced by an interpret node built from the source
// text of `value_object`. Otherwise `value_node` is returned unchanged.
AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
                                   const py::object &value_object) {
  MS_EXCEPTION_IF_NULL(value_node);
  if (value_node->interpret()) {
    const auto source_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
    return MakeInterpretNode(block, value_node, source_text);
  }
  return value_node;
}
4163 
CheckNeedConvertInterpret(const FunctionBlockPtr & block,const AnfNodePtr & node,const string & script_text) const4164 bool Parser::CheckNeedConvertInterpret(const FunctionBlockPtr &block, const AnfNodePtr &node,
4165                                        const string &script_text) const {
4166   MS_EXCEPTION_IF_NULL(block);
4167   MS_EXCEPTION_IF_NULL(node);
4168   // Check if script_text is in global/local params.
4169   const py::dict &global_dict = block->global_py_params();
4170   auto keys = std::get<0>(block->local_py_params());
4171   if (IsScriptInParams(script_text, global_dict, keys, block->func_graph())) {
4172     return false;
4173   }
4174   return true;
4175 }
4176 
// Counts occurrences of `sub` in `script_text`. The search restarts one character after each hit, so
// overlapping matches are counted.
size_t GetSubStrNum(const std::string &script_text, const std::string &sub) {
  size_t occurrences = 0;
  for (auto hit = script_text.find(sub); hit != std::string::npos; hit = script_text.find(sub, hit + 1)) {
    ++occurrences;
  }
  return occurrences;
}
4186 
// Rewrites a multi-line script into one backslash-continued Python expression.
// Leading spaces are stripped from every line. Each line except the last gets a trailing "+\" unless it
// already ends with '\'. Blank lines (empty or spaces only) are dropped, because a continuation followed by
// a blank line, or a stray "+\" line, would not be valid Python.
// Example: "f'a'\n    f'b'" -> "f'a'+\\\nf'b'".
std::string UpdateString(const std::string &str) {
  std::string result;
  bool has_previous = false;
  bool previous_continues = false;  // whether the previously emitted line ended with '\'
  size_t start = 0;
  while (start <= str.size()) {
    size_t end = str.find('\n', start);
    if (end == std::string::npos) {
      end = str.size();
    }
    // The first non-space character at or after `start`; it belongs to this line only if it precedes `end`.
    const size_t first = str.find_first_not_of(' ', start);
    if (first < end) {
      if (has_previous) {
        result += previous_continues ? "\n" : "+\\\n";
      }
      (void)result.append(str, first, end - first);
      previous_continues = (str[end - 1] == '\\');
      has_previous = true;
    }
    start = end + 1;
  }
  return result;
}
4207 
ProcessIndentationInScript(const string & script_text)4208 string ProcessIndentationInScript(const string &script_text) {
4209   const size_t f_string_num = 2;
4210   size_t num1 = GetSubStrNum(script_text, "f'");
4211   size_t num2 = GetSubStrNum(script_text, "f\"");
4212   if (script_text.find("\n") == string::npos || num1 + num2 < f_string_num) {
4213     return script_text;
4214   }
4215   return UpdateString(script_text);
4216 }
4217 
MakeInterpretNode(const FunctionBlockPtr & block,const AnfNodePtr & value_node,const string & script_text)4218 AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
4219                                      const string &script_text) {
4220   MS_EXCEPTION_IF_NULL(block);
4221   MS_EXCEPTION_IF_NULL(value_node);
4222   if (!CheckNeedConvertInterpret(block, value_node, script_text)) {
4223     return value_node;
4224   }
4225   string new_script_text = ProcessIndentationInScript(script_text);
4226 
4227   // Prepare global parameters.
4228   PyObjectWrapperPtr interpreted_global_dict = std::make_shared<InterpretedObject>(block->global_py_params());
4229   auto global_dict_node = NewValueNode(interpreted_global_dict);
4230   // Prepare local parameters. Select the needed local parameters for script.
4231   auto [keys, values] = block->local_py_params();
4232   std::vector<AnfNodePtr> filter_keys;
4233   std::vector<AnfNodePtr> filter_values;
4234   try {
4235     const py::set &ids = data_converter::GetPythonScriptIdAttrs(py::str(new_script_text));
4236     for (const auto &id : ids) {
4237       const auto &id_str = py::str(id);
4238       const auto &iter = values.find(id_str);
4239       if (iter != values.end()) {
4240         (void)filter_keys.emplace_back(keys[iter->first]);
4241         auto &val_node = iter->second;
4242         // '__py_interpret_local_value_flag__' is used by 'ConvertInterpretedObjForResolve' not to convert PyExecute.
4243         val_node->set_user_data("__py_interpret_local_value_flag__", std::make_shared<bool>(true));
4244         (void)filter_values.emplace_back(val_node);
4245       }
4246     }
4247   } catch (const std::exception &e) {
4248     MS_LOG(INTERNAL_EXCEPTION) << "GetPythonScriptIdAttrs failed, script: " << py::str(new_script_text) << ".\n"
4249                                << e.what();
4250   }
4251   constexpr auto self_text = "self";
4252   if (keys.find(self_text) == keys.end() && new_script_text.find(self_text) != std::string::npos) {
4253     py::object self_namespace = ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast()->obj());
4254     auto self_value = std::make_shared<InterpretedObject>(self_namespace);
4255     (void)filter_keys.emplace_back(NewValueNode(MakeValue(self_text)));
4256     (void)filter_values.emplace_back(NewValueNode(self_value));
4257   }
4258 
4259   auto local_dict_node = ParseDictByKeysAndValues(block, filter_keys, filter_values);
4260   // Update the valued node if it need interpreting.
4261   constexpr int recursive_level = 2;
4262   MS_EXCEPTION_IF_NULL(block->func_graph());
4263   AnfNodePtr interpreted_node = block->MakeInterpret(new_script_text, global_dict_node, local_dict_node, value_node);
4264   MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: `" << new_script_text
4265                << "`,\nvalue_node: " << value_node->DebugString(recursive_level)
4266                << ",\nglobal_dict_node: " << global_dict_node->ToString()
4267                << ",\nlocal_dict_node: " << local_dict_node->DebugString(recursive_level)
4268                << ",\ninterpreted_node: " << interpreted_node->DebugString(recursive_level);
4269 
4270   // Print a hint for user.
4271   auto line_info = trace::GetDebugInfoStr(value_node->debug_info());
4272   MS_LOG(INFO) << "Found unsupported syntax in graph mode, those codes would be fallen back to Python interpreter:"
4273                << "\n\n"
4274                << line_info;
4275   InterpretNodeRecorder::GetInstance().PushPyInterpretNode(value_node);
4276   return interpreted_node;
4277 }
4278 
IsPopOperation(const AnfNodePtr & node) const4279 bool Parser::IsPopOperation(const AnfNodePtr &node) const {
4280   auto cnode = node->cast<CNodePtr>();
4281   if (cnode == nullptr) {
4282     return false;
4283   }
4284   auto attr_node = cnode->input(0);
4285   if (IsPrimitiveCNode(attr_node, prim::kPrimGetAttr)) {
4286     auto attr_cnode = attr_node->cast<CNodePtr>();
4287     MS_EXCEPTION_IF_NULL(attr_cnode);
4288     constexpr size_t attr_cnode_size = 3;
4289     constexpr size_t member_index = 2;
4290     if (attr_cnode->size() < attr_cnode_size) {
4291       MS_LOG(EXCEPTION) << "The attr operate has wrong input.";
4292     }
4293     auto member_node = attr_cnode->input(member_index);
4294     if (IsValueNode<StringImm>(member_node)) {
4295       auto attr_name = GetValue<std::string>(GetValueNode(member_node));
4296       if (attr_name == "pop") {
4297         return true;
4298       }
4299     }
4300   }
4301   return false;
4302 }
4303 
ProcessPopOperation(const FunctionBlockPtr & block,const AnfNodePtr & value_node,const py::object & target_object)4304 void Parser::ProcessPopOperation(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
4305                                  const py::object &target_object) {
4306   auto func_graph = block->func_graph();
4307   MS_EXCEPTION_IF_NULL(func_graph);
4308   auto new_list =
4309     func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), value_node, NewValueNode(SizeToLong(0))});
4310   auto pop_node =
4311     func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), value_node, NewValueNode(SizeToLong(1))});
4312   if (!(ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast_->IsClassMemberOfSelf(list_pop_target_obj_))) {
4313     WriteAssignVars(block, list_pop_target_obj_, new_list);
4314   }
4315   WriteAssignVars(block, target_object, pop_node);
4316 }
4317 
ProcessPopOperationInAugAssign(const FunctionBlockPtr & block,const AnfNodePtr & value_node,const AnfNodePtr & target_node,const AnfNodePtr & op_node,const py::object & target_object)4318 void Parser::ProcessPopOperationInAugAssign(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
4319                                             const AnfNodePtr &target_node, const AnfNodePtr &op_node,
4320                                             const py::object &target_object) {
4321   auto func_graph = block->func_graph();
4322   MS_EXCEPTION_IF_NULL(func_graph);
4323   auto new_list =
4324     func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), value_node, NewValueNode(SizeToLong(0))});
4325   auto pop_node =
4326     func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), value_node, NewValueNode(SizeToLong(1))});
4327   if (!(ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast_->IsClassMemberOfSelf(list_pop_target_obj_))) {
4328     WriteAssignVars(block, list_pop_target_obj_, new_list);
4329   }
4330   AnfNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, pop_node});
4331   WriteAssignVars(block, target_object, augassign_app);
4332 }
4333 
4334 // Process a assign statement, such as a = b,  a, b = tuple(xx, xx)
ParseAssign(const FunctionBlockPtr & block,const py::object & node)4335 FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
4336   MS_LOG(DEBUG) << "Process ast assign";
4337   py::object value_object = python_adapter::GetPyObjAttr(node, "value");
4338   py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
4339 
4340   AnfNodePtr value_node = ParseExprNode(block, value_object);
4341 
4342   py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
4343   size_t count = LongToSize(pcount);
4344   MS_LOG(DEBUG) << "The nodes count is " << count;
4345 
4346   // b = list_x.pop(a)
4347   // -->  list_x, b = list_x.pop(a) need renew the list_x.
4348   if (IsPopOperation(value_node)) {
4349     auto pop_obj = py::cast<py::list>(targets_object)[0];
4350     ProcessPopOperation(block, value_node, pop_obj);
4351     return block;
4352   }
4353   for (size_t i = 0; i < count; i++) {
4354     auto target_node = py::cast<py::list>(targets_object)[i];
4355     WriteAssignVars(block, target_node, value_node);
4356   }
4357 
4358   return block;
4359 }
4360 
4361 // Process a annassign statement, such as a:int = 1
4362 // target may be one of Name, Attribute, Subscript.
ParseAnnAssign(const FunctionBlockPtr & block,const py::object & node)4363 FunctionBlockPtr Parser::ParseAnnAssign(const FunctionBlockPtr &block, const py::object &node) {
4364   MS_LOG(DEBUG) << "Process ast annassign";
4365   py::object value_object = python_adapter::GetPyObjAttr(node, "value");
4366   py::object target_object = python_adapter::GetPyObjAttr(node, "target");
4367   AnfNodePtr value_node = ParseExprNode(block, value_object);
4368   // b: int = list_x.pop(a)
4369   // -->  list_x, b = list_x.pop(a) need renew the list_x.
4370   if (IsPopOperation(value_node)) {
4371     ProcessPopOperation(block, value_node, target_object);
4372     return block;
4373   }
4374   WriteAssignVars(block, target_object, value_node);
4375   return block;
4376 }
4377 
// Parses `break`: `block` jumps (without arguments) to the end block of the innermost loop.
// The end block is created on first use, under a TraceLoopEnd trace. Raises if there is no enclosing loop.
FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
  if (loops_.empty()) {
    // A 'break' outside of any loop context.
    MS_LOG(INTERNAL_EXCEPTION) << "Unexpected 'break'.";
  }
  auto &current_loop = loops_.top();
  if (current_loop.end == nullptr) {
    MS_EXCEPTION_IF_NULL(block->func_graph());
    TraceGuard trace_guard(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
    current_loop.end = MakeFunctionBlock();
  }
  block->set_break_continue_statement_inside();
  MS_LOG(DEBUG) << "Inside the block has break statement, block: " << block->ToString();
  block->Jump(current_loop.end, {});
  return block;
}
4398 
// Parses `continue`: `block` jumps back to the header block of the innermost loop, passing the loop's iterator
// node as the only argument when the loop has one. Raises if there is no enclosing loop.
FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
  if (loops_.empty()) {
    // A 'continue' outside of any loop context.
    MS_LOG(INTERNAL_EXCEPTION) << "Unexpected 'continue'.";
  }
  auto &current_loop = loops_.top();
  std::vector<AnfNodePtr> jump_args;
  if (current_loop.iterator != nullptr) {
    (void)jump_args.emplace_back(current_loop.iterator);
  }
  block->set_break_continue_statement_inside();
  MS_LOG(DEBUG) << "Inside the block has continue statement, block: " << block->ToString();
  block->Jump(current_loop.header, jump_args);
  return block;
}
4416 
// Parses `pass`: it generates no nodes, so parsing simply continues in the same block.
FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) { return block; }
4421 
// Parses a `raise` statement and makes the graph of `block` return Return(Raise(...)):
//   raise                                        -> Raise()
//   raise ExceptionType / ExceptionType(message) -> Raise(<inputs produced by ParseRaiseCall>)
// Both forms install a real Return node as the graph's return node, matching how assert failures are built;
// a bare Raise CNode is not a valid return node.
FunctionBlockPtr Parser::ParseRaise(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process raise statement";
  TraceGuard trace_guard(GetLocation(node));
  MS_EXCEPTION_IF_NULL(block);
  auto func_graph = block->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  py::object exc_ast_node = python_adapter::GetPyObjAttr(node, "exc");
  std::vector<AnfNodePtr> raise_inputs;
  if (py::isinstance<py::none>(exc_ast_node)) {
    // raise
    (void)raise_inputs.emplace_back(NewValueNode(prim::kPrimRaise));
  } else {
    // raise ExceptionType or raise ExceptionType(ExceptionString)
    auto exc_node_inputs = ParseRaiseCall(block, exc_ast_node);
    (void)raise_inputs.emplace_back(NewValueNode(prim::kPrimRaise));
    (void)raise_inputs.insert(raise_inputs.end(), exc_node_inputs.begin(), exc_node_inputs.end());
  }
  CNodePtr raise_node = func_graph->NewCNodeInOrder(raise_inputs);
  CNodePtr return_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), raise_node});
  func_graph->set_return(return_node);
  return block;
}
4444 
// Fills `block` (the failure branch of an `assert`) so that its graph returns
// Return(Raise("AssertionError", [msg,] "None")), where `msg` is the parsed `msg` attribute of the Assert node when
// it is not None. Returns `block`.
FunctionBlockPtr Parser::MakeAssertErrorBlock(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process make AssertError block";
  MS_EXCEPTION_IF_NULL(block);
  const std::string kAssertionError = "AssertionError";
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimRaise), NewValueNode(kAssertionError)};

  py::object msg_node = python_adapter::GetPyObjAttr(node, "msg");
  if (!py::isinstance<py::none>(msg_node)) {
    AnfNodePtr msg = ParseExprNode(block, msg_node);
    (void)inputs.emplace_back(msg);
  }
  // A "None" string constant is always appended as the last Raise input.
  // NOTE(review): its meaning is defined by the Raise consumer, which is not visible here.
  auto str_none = std::make_shared<StringImm>("None");
  (void)inputs.emplace_back(NewValueNode(str_none));

  auto func_graph = block->func_graph();
  // Check the graph before dereferencing it, as the other parse routines do.
  MS_EXCEPTION_IF_NULL(func_graph);
  CNodePtr raise_node = func_graph->NewCNodeInOrder(inputs);
  CNodePtr return_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), raise_node});
  func_graph->set_return(return_node);
  return block;
}
4465 
4466 // assert expression [, arguments]
4467 // =>
4468 // if not expression:
4469 //     raise AssertionError(arguments)
ParseAssert(const FunctionBlockPtr & block,const py::object & node)4470 FunctionBlockPtr Parser::ParseAssert(const FunctionBlockPtr &block, const py::object &node) {
4471   MS_LOG(DEBUG) << "Process ast Assert";
4472   MS_EXCEPTION_IF_NULL(block);
4473   py::object test_node = python_adapter::GetPyObjAttr(node, "test");
4474   AnfNodePtr condition_node = ParseExprNode(block, test_node);
4475 
4476   AnfNodePtr bool_node = block->ForceToCondNode(condition_node);
4477   TraceGuard guard(std::make_shared<TraceAssert>(block->func_graph()->debug_info()));
4478   TraceGuard location_guard(GetLocation(node));
4479   FunctionBlockPtr true_block = MakeFunctionBlock();
4480   FunctionBlockPtr false_block = MakeFunctionBlock();
4481   FunctionBlockPtr after_block = MakeFunctionBlock();
4482   MakeConditionBlocks(block, true_block, false_block);
4483 
4484   true_block->Jump(after_block, {});
4485   false_block = MakeAssertErrorBlock(false_block, node);
4486   (void)block->ConditionalJump(bool_node, true_block, false_block);
4487 
4488   after_block->Mature();
4489   return after_block;
4490 }
4491 
ParseWithitem(const FunctionBlockPtr & block,const py::object & node,const AnfNodePtr & context_expr_node)4492 AnfNodePtr Parser::ParseWithitem(const FunctionBlockPtr &block, const py::object &node,
4493                                  const AnfNodePtr &context_expr_node) {
4494   MS_LOG(DEBUG) << "Process ast Withitem";
4495   MS_EXCEPTION_IF_NULL(block);
4496   // Handle __enter__(self)
4497   std::vector<AnfNodePtr> enter_inputs{NewValueNode(prim::kPrimWithEnter), context_expr_node};
4498   auto func_graph = block->func_graph();
4499   MS_EXCEPTION_IF_NULL(func_graph);
4500   AnfNodePtr enter_node = func_graph->NewCNodeInOrder(enter_inputs);
4501   py::object optional_vars_obj = python_adapter::GetPyObjAttr(node, "optional_vars");
4502   if (!py::isinstance<py::none>(optional_vars_obj)) {
4503     // with Sample() as sample: mean that sample = Sample()
4504     WriteAssignVars(block, optional_vars_obj, enter_node);
4505   }
4506   return enter_node;
4507 }
4508 
4509 // with expression [as variable]:
4510 //      with-block
ParseWith(const FunctionBlockPtr & block,const py::object & node)4511 FunctionBlockPtr Parser::ParseWith(const FunctionBlockPtr &block, const py::object &node) {
4512   MS_LOG(DEBUG) << "Process ast With";
4513   py::list items_objs = python_adapter::GetPyObjAttr(node, "items");
4514   if (items_objs.empty()) {
4515     MS_LOG(INTERNAL_EXCEPTION) << "Unexpected 'with'.";
4516   }
4517   std::stack<AnfNodePtr> context_expr_nodes;
4518   std::stack<AnfNodePtr> entered_nodes;
4519   for (size_t i = 0; i < items_objs.size(); ++i) {
4520     auto items_obj = items_objs[i];
4521     // with Sample() as sample:
4522     // mean context_expr is Sample(), sample is optional_vars
4523     py::object context_expr_obj = python_adapter::GetPyObjAttr(items_obj, "context_expr");
4524     AnfNodePtr context_expr_node = ParseExprNode(block, context_expr_obj);
4525     context_expr_nodes.push(context_expr_node);
4526     auto enter_node = ParseWithitem(block, items_obj, context_expr_node);
4527     entered_nodes.push(enter_node);
4528   }
4529   MS_EXCEPTION_IF_NULL(block);
4530   auto func_graph = block->func_graph();
4531   MS_EXCEPTION_IF_NULL(func_graph);
4532   py::object body_node = python_adapter::GetPyObjAttr(node, "body");
4533   FunctionBlockPtr body_block = ParseStatements(block, body_node);
4534   auto body_func = body_block->func_graph();
4535   MS_EXCEPTION_IF_NULL(body_func);
4536 
4537   while (!context_expr_nodes.empty()) {
4538     auto context_expr_node = context_expr_nodes.top();
4539     auto entered_node = entered_nodes.top();
4540     context_expr_nodes.pop();
4541     entered_nodes.pop();
4542     // Use the depend node to ensure the execution order of enter and exit node.
4543     std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), context_expr_node, entered_node};
4544     context_expr_node = func_graph->NewCNodeInOrder(depend_inputs);
4545     // Handle __exit__(self, type, value, trace)
4546     std::vector<AnfNodePtr> exit_inputs{NewValueNode(prim::kPrimWithExit), context_expr_node};
4547     AnfNodePtr exit_node = func_graph->NewCNodeInOrder(exit_inputs);
4548     block->AddIsolatedNode(exit_node);
4549   }
4550   FunctionBlockPtr after_block = MakeFunctionBlock();
4551   if (body_func->get_return() == nullptr) {
4552     body_block->Jump(after_block, {});
4553   }
4554   after_block->Mature();
4555   return after_block;
4556 }
4557 
PrintPhiArgMaps(const std::map<ParameterPtr,std::set<AnfNodePtr>> & phi_to_args,const std::map<AnfNodePtr,std::set<ParameterPtr>> & arg_to_phis)4558 void Parser::PrintPhiArgMaps(const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_to_args,
4559                              const std::map<AnfNodePtr, std::set<ParameterPtr>> &arg_to_phis) {
4560   if (!IS_OUTPUT_ON(mindspore::kDebug)) {
4561     return;
4562   }
4563   std::ostringstream oss;
4564   oss << "==============================Start=============================="
4565       << "\n";
4566   size_t m = 0;
4567   for (const auto &[phi, args] : phi_to_args) {
4568     MS_EXCEPTION_IF_NULL(phi);
4569     oss << "phi[" << m++ << "]: " << phi->DebugString() << "\n";
4570     size_t n = 0;
4571     for (auto &arg : args) {
4572       MS_EXCEPTION_IF_NULL(arg);
4573       oss << "    args[" << n++ << "]: " << arg->DebugString() << "\n";
4574     }
4575   }
4576 
4577   m = 0;
4578   for (const auto &[arg, phis] : arg_to_phis) {
4579     MS_EXCEPTION_IF_NULL(arg);
4580     oss << "arg[" << m++ << "]: " << arg->DebugString() << "\n";
4581     size_t n = 0;
4582     for (auto &phi : phis) {
4583       MS_EXCEPTION_IF_NULL(phi);
4584       oss << "    phis[" << n++ << "]: " << phi->DebugString() << "\n";
4585     }
4586   }
4587   oss << "===============================End==============================="
4588       << "\n";
4589   MS_LOG(DEBUG) << "\n" << oss.str();
4590 }
4591 
4592 namespace {
UpdatePhiArgMaps(std::map<ParameterPtr,std::set<AnfNodePtr>> * phi_to_args,std::map<AnfNodePtr,std::set<ParameterPtr>> * arg_to_phis)4593 bool UpdatePhiArgMaps(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args,
4594                       std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis) {
4595   MS_EXCEPTION_IF_NULL(phi_to_args);
4596   MS_EXCEPTION_IF_NULL(arg_to_phis);
4597   bool phi_arg_updated = false;
4598   auto copy_phi_to_args = *phi_to_args;
4599   for (const auto &[phi, args] : copy_phi_to_args) {
4600     // The phi node has only one arg can be replaced as arg.
4601     if (args.size() != 1) {
4602       continue;
4603     }
4604     auto phi_arg = *args.begin();
4605     MS_EXCEPTION_IF_NULL(phi_arg);
4606     MS_LOG(DEBUG) << "phi: " << phi->DebugString() << ", get one arg: " << phi_arg->DebugString();
4607     // If this phi is a arg of other phi.
4608     auto arg_to_phi_it = arg_to_phis->find(phi);
4609     if (arg_to_phi_it == arg_to_phis->end()) {
4610       continue;
4611     }
4612     // Use the new phi arg as the arg of other phi's arg. Usually other phi is a deeper subgraph's phi node.
4613     auto other_phis = arg_to_phi_it->second;
4614     MS_LOG(DEBUG) << "Find phi as arg of other phi, other phis num: " << other_phis.size();
4615     // Update all other phis' arg from phi to phi_arg.
4616     for (auto &other_phi : other_phis) {
4617       MS_EXCEPTION_IF_NULL(other_phi);
4618       MS_LOG(DEBUG) << "other phi: " << other_phi->DebugString();
4619       phi_arg_updated = true;
4620       // The phi will not be arg of any other phis.Erase map1.
4621       (void)(*phi_to_args)[other_phi].erase(phi);
4622       // If arg is same to the parameter phi, ignore the arg, keep maps don't have self arg.
4623       if (phi_arg == other_phi) {
4624         MS_LOG(DEBUG) << "Get phi arg of phi self.";
4625         continue;
4626       }
4627       MS_LOG(DEBUG) << "phi arg: " << phi_arg->DebugString()
4628                     << " as new arg of other phi: " << other_phi->DebugString();
4629       // Replace other phi's arg as this phi's arg, instead of phi. (other_phi , phi) -> (other_phi, phi_arg)
4630       (void)(*phi_to_args)[other_phi].insert(phi_arg);
4631       // Add other phi to the phi_arg's phis set. (phi_arg, {phi_x, }) -> (phi_arg, {phi_x, other_phi})
4632       (void)(*arg_to_phis)[phi_arg].insert(other_phi);
4633     }
4634     MS_LOG(DEBUG) << "Remove phi type arg: " << phi;
4635     // The phi will not be arg of any other phis.Erase map2.
4636     (void)(*arg_to_phis).erase(phi);
4637   }
4638   return phi_arg_updated;
4639 }
4640 }  // namespace
4641 
UpdatePhiArgMapsRepeatedly(std::map<ParameterPtr,std::set<AnfNodePtr>> * phi_to_args,std::map<AnfNodePtr,std::set<ParameterPtr>> * arg_to_phis)4642 void Parser::UpdatePhiArgMapsRepeatedly(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args,
4643                                         std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis) {
4644   bool phi_arg_updated = true;
4645   size_t loop_count = 0;
4646   while (phi_arg_updated) {
4647     MS_LOG(DEBUG) << "update loop count: " << loop_count++;
4648     PrintPhiArgMaps(*phi_to_args, *arg_to_phis);
4649     phi_arg_updated = UpdatePhiArgMaps(phi_to_args, arg_to_phis);
4650   }
4651 }
4652 
CreatePhiArgMaps(std::map<ParameterPtr,std::set<AnfNodePtr>> * phi_to_args,std::map<AnfNodePtr,std::set<ParameterPtr>> * arg_to_phis)4653 void Parser::CreatePhiArgMaps(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args,
4654                               std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis) {
4655   MS_EXCEPTION_IF_NULL(phi_to_args);
4656   MS_EXCEPTION_IF_NULL(arg_to_phis);
4657   for (FunctionBlockPtr &block : func_block_list_) {
4658     MS_EXCEPTION_IF_NULL(block);
4659     for (const auto &[phi, args] : block->phi_args()) {
4660       // Filtered args exclude the arg pointer equals to phi pointer.
4661       for (const auto &arg : args) {
4662         if (phi == arg) {
4663           continue;
4664         }
4665         (void)(*phi_to_args)[phi].insert(arg);
4666         (void)(*arg_to_phis)[arg].insert(phi);
4667       }
4668     }
4669   }
4670 }
4671 
CollectRemovablePhiArgs(const std::map<ParameterPtr,std::set<AnfNodePtr>> & phi_to_args)4672 std::shared_ptr<std::map<ParameterPtr, AnfNodePtr>> Parser::CollectRemovablePhiArgs(
4673   const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_to_args) {
4674   auto need_remove_phi_args = std::make_shared<std::map<ParameterPtr, AnfNodePtr>>();
4675   for (const auto &[phi, args] : phi_to_args) {
4676     if (args.empty()) {
4677       // phi's arg is phi self.
4678       (*need_remove_phi_args)[phi] = nullptr;
4679       continue;
4680     }
4681     if (args.size() == 1) {
4682       (*need_remove_phi_args)[phi] = *(args.begin());
4683     }
4684   }
4685   if (IS_OUTPUT_ON(mindspore::kDebug)) {
4686     size_t m = 0;
4687     std::ostringstream oss;
4688     oss << "=====================Need removed phis and args====================="
4689         << "\n";
4690     for (const auto &[phi, arg] : *need_remove_phi_args) {
4691       MS_EXCEPTION_IF_NULL(phi);
4692       oss << "phi[" << m << "]: " << phi->DebugString() << "\n";
4693       oss << "arg[" << m++ << "]: " << arg->DebugString() << "\n";
4694     }
4695     MS_LOG(DEBUG) << "\n" << oss.str();
4696   }
4697   return need_remove_phi_args;
4698 }
4699 
CalRemovablePhis()4700 std::shared_ptr<std::map<ParameterPtr, AnfNodePtr>> Parser::CalRemovablePhis() {
4701   std::map<ParameterPtr, std::set<AnfNodePtr>> phi_to_args;
4702   std::map<AnfNodePtr, std::set<ParameterPtr>> arg_to_phis;
4703   CreatePhiArgMaps(&phi_to_args, &arg_to_phis);
4704   // Update phi arg maps by phi arg map relations, some phi can be replaced as arg.
4705   UpdatePhiArgMapsRepeatedly(&phi_to_args, &arg_to_phis);
4706   // Collect all one arg phis.
4707   return CollectRemovablePhiArgs(phi_to_args);
4708 }
4709 
ReplacePhiAsArg(const std::map<ParameterPtr,AnfNodePtr> & removable_phis,const FuncGraphManagerPtr & manager)4710 void ReplacePhiAsArg(const std::map<ParameterPtr, AnfNodePtr> &removable_phis, const FuncGraphManagerPtr &manager) {
4711   MS_LOG(DEBUG) << "Removable phi size: " << removable_phis.size();
4712   for (const auto &[phi, arg] : removable_phis) {
4713     MS_LOG(DEBUG) << "Removable phi: " << phi->DebugString()
4714                   << ", arg: " << (arg == nullptr ? "null" : arg->DebugString());
4715     if (arg != nullptr) {
4716       (void)manager->Replace(phi, arg);
4717     }
4718   }
4719 }
4720 
4721 // Remove the removable phi parameter and get the corresponding index.
RemovePhiParametersAndGetRemoveIndex(const FunctionBlockPtr & block,const std::map<ParameterPtr,AnfNodePtr> & removable_phis)4722 HashSet<size_t> RemovePhiParametersAndGetRemoveIndex(const FunctionBlockPtr &block,
4723                                                      const std::map<ParameterPtr, AnfNodePtr> &removable_phis) {
4724   MS_EXCEPTION_IF_NULL(block);
4725   auto func_graph = block->func_graph();
4726   MS_EXCEPTION_IF_NULL(func_graph);
4727   MS_LOG(DEBUG) << "Check removable parameters of block: " << block->ToString();
4728   const auto &parameters = func_graph->parameters();
4729   std::vector<AnfNodePtr> new_parameters;
4730   // Remove the unnecessary phi parameters.
4731   HashSet<size_t> need_removed_indexes;
4732   for (size_t i = 0; i < parameters.size(); ++i) {
4733     auto parameter_i = parameters[i];
4734     MS_EXCEPTION_IF_NULL(parameter_i);
4735     if (removable_phis.find(parameter_i->cast<ParameterPtr>()) == removable_phis.end()) {
4736       new_parameters.push_back(parameter_i);
4737       continue;
4738     }
4739     // Record all removed indexes.
4740     (void)need_removed_indexes.insert(i);
4741   }
4742   MS_LOG(DEBUG) << "parameters.size(): " << parameters.size()
4743                 << ", need_removed_indexes.size(): " << need_removed_indexes.size();
4744   // Only if need_removed_indexes not empty, parameters need be updated.
4745   if (!need_removed_indexes.empty()) {
4746     func_graph->set_parameters(new_parameters);
4747   }
4748   return need_removed_indexes;
4749 }
4750 
4751 // If phi parameter is removable, then the corresponding arg should be removed.
RemoveJumpNodeArgs(const FunctionBlockPtr & block,const HashSet<size_t> & need_removed_indexes,const FuncGraphManagerPtr & manager)4752 void RemoveJumpNodeArgs(const FunctionBlockPtr &block, const HashSet<size_t> &need_removed_indexes,
4753                         const FuncGraphManagerPtr &manager) {
4754   MS_EXCEPTION_IF_NULL(block);
4755   if (need_removed_indexes.empty()) {
4756     return;
4757   }
4758   for (const auto &prev_block : block->prev_blocks()) {
4759     MS_EXCEPTION_IF_NULL(prev_block);
4760     const auto &jump_node = prev_block->GetJumpNode(block.get());
4761     // Switch call has no jump node.
4762     if (jump_node == nullptr) {
4763       continue;
4764     }
4765     std::vector<AnfNodePtr> new_inputs = {jump_node->input(0)};
4766     for (size_t arg_index = 0; arg_index < jump_node->size() - 1; ++arg_index) {
4767       if (need_removed_indexes.find(arg_index) == need_removed_indexes.end()) {
4768         new_inputs.push_back(jump_node->input(arg_index + 1));
4769       }
4770     }
4771     MS_EXCEPTION_IF_NULL(prev_block->func_graph());
4772     const auto &new_jump_node = prev_block->func_graph()->NewCNodeInOrder(new_inputs);
4773     MS_LOG(DEBUG) << "Replace old jump node: " << jump_node->DebugString()
4774                   << " as new jump node: " << new_jump_node->DebugString()
4775                   << ", jump node block: " << prev_block->ToString();
4776     (void)manager->Replace(jump_node, new_jump_node);
4777   }
4778 }
4779 
RemoveUnnecessaryPhis(const FuncGraphManagerPtr & manager)4780 void Parser::RemoveUnnecessaryPhis(const FuncGraphManagerPtr &manager) {
4781   // Merge all removable phis to one map;
4782   const auto &removable_phis = CalRemovablePhis();
4783   if (removable_phis->empty()) {
4784     return;
4785   }
4786   MS_EXCEPTION_IF_NULL(manager);
4787   // Replace all phi node as arg.
4788   ReplacePhiAsArg(*removable_phis, manager);
4789   // Remove the unnecessary phi parameters.
4790   for (const auto &block : func_block_list_) {
4791     MS_EXCEPTION_IF_NULL(block);
4792     MS_LOG(DEBUG) << "Start remove phi of block: " << block->ToString();
4793     // Remove the unnecessary phi parameters.
4794     const auto &need_removed_indexes = RemovePhiParametersAndGetRemoveIndex(block, *removable_phis);
4795     // Remove all block->prev_blocks()'s jump node corresponding args.
4796     RemoveJumpNodeArgs(block, need_removed_indexes, manager);
4797   }
4798 }
4799 
4800 // ParseFunctionAst class code
InitParseAstInfo(const std::string & python_mod_get_parse_method)4801 bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
4802   // Init the type
4803   target_type_ = PARSE_TARGET_UNKNOW;
4804 
4805   // Call python parse, get the parser fn
4806   module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
4807   py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_PARSE_METHOD);
4808 
4809   // Get the obj type
4810   auto type = data_converter::GetObjType(obj_);
4811   if (type == RESOLVE_TYPE_FUNCTION) {
4812     target_type_ = PARSE_TARGET_FUNCTION;
4813     function_ = obj_;
4814   } else if (type == RESOLVE_TYPE_METHOD) {
4815     // Process the method ,need get the method's self obj
4816     target_type_ = PARSE_TARGET_METHOD;
4817     py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
4818     if (py::isinstance<py::none>(method_object)) {
4819       MS_LOG(ERROR) << "Get method's self object instance failed.";
4820       return false;
4821     }
4822     target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
4823     function_ = obj_;
4824     obj_ = method_object;
4825   } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
4826     // 'obj' is class instance, get the method to parse.
4827     function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
4828     if (py::isinstance<py::none>(function_)) {
4829       MS_LOG(ERROR) << "Get obj method function failed.";
4830       return false;
4831     }
4832     target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
4833     // Check the fn is method
4834     auto obj_type = data_converter::GetObjType(function_);
4835     if (obj_type != RESOLVE_TYPE_METHOD) {
4836       MS_LOG(WARNING) << "Parse method function is invalid.";
4837       return false;
4838     }
4839   } else {
4840     MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type: " << type;
4841     return false;
4842   }
4843 
4844   // Call python parse get ast tree
4845   parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
4846   py::tuple ast_info = python_adapter::CallPyObjMethod(parser_, "parse");
4847   const size_t ast_info_size = 2;
4848   if (ast_info.size() != ast_info_size) {
4849     MS_INTERNAL_EXCEPTION(NameError) << "ast info size is not equal to 2.";
4850   }
4851   ast_tokens_ = ast_info[0];
4852   ast_tree_ = ast_info[1];
4853 
4854   // Get fn name and module
4855   function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
4856   function_name_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_name"));
4857   function_filename_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "filename"));
4858   function_line_offset_ = py::cast<int64_t>(python_adapter::GetPyObjAttr(parser_, "line_offset"));
4859 
4860   return true;
4861 }
4862 
4863 // Get ast tree node : is the tree bode list[0]
GetAstNode()4864 py::object ParseFunctionAst::GetAstNode() {
4865   py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
4866   py::object ast_node = tree_body[0];
4867   return ast_node;
4868 }
4869 
4870 // Get ast tokens node text.
GetAstNodeText(const py::object & node_obj)4871 py::str ParseFunctionAst::GetAstNodeText(const py::object &node_obj) {
4872   return python_adapter::CallPyObjMethod(ast_tokens_, "get_text", node_obj);
4873 }
4874 
// Return the argument list of `func_node`, computed by the parse module's PYTHON_PARSE_GET_ARGS helper.
py::list ParseFunctionAst::GetArgs(const py::object &func_node) {
  return python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS, func_node);
}
4879 
// Return the argument default values of `func_node`, computed by the parse module's
// PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES helper.
py::list ParseFunctionAst::GetArgsDefaultValues(const py::object &func_node) {
  return python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
}
4884 
GetNodeType(const py::object & node)4885 AstNodeTypePtr ParseFunctionAst::GetNodeType(const py::object &node) {
4886   py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node);
4887   const size_t list_value_size = 2;
4888   if (list_value.size() < list_value_size) {
4889     MS_LOG(INTERNAL_EXCEPTION) << "The node of python method must has 2 values.";
4890   }
4891   auto node_name = py::cast<std::string>(list_value[0]);
4892   auto type = AstMainType(py::cast<int32_t>(list_value[1]));
4893   return std::make_shared<AstNodeType>(node, node_name, type);
4894 }
4895 
GetOpType(const py::object & node)4896 AstSubType ParseFunctionAst::GetOpType(const py::object &node) {
4897   auto op_type = AstSubType(python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
4898   return op_type;
4899 }
4900 
IsClassMemberOfSelf(const py::object & node)4901 bool ParseFunctionAst::IsClassMemberOfSelf(const py::object &node) {
4902   py::object res = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER_OF_SELF, node);
4903   if (!py::isinstance<py::bool_>(res)) {
4904     MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
4905     return false;
4906   }
4907   return res.cast<bool>();
4908 }
4909 
IsClassMemberRecursive(const py::object & node)4910 bool ParseFunctionAst::IsClassMemberRecursive(const py::object &node) {
4911   py::object res = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER_RECURSIVE, node);
4912   if (!py::isinstance<py::bool_>(res)) {
4913     MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
4914     return false;
4915   }
4916   return res.cast<bool>();
4917 }
4918 
SetMixedPrecisionFlag(const py::object & obj,const FuncGraphPtr & func_graph)4919 void SetMixedPrecisionFlag(const py::object &obj, const FuncGraphPtr &func_graph) {
4920   MS_EXCEPTION_IF_NULL(func_graph);
4921   if (!py::isinstance<Cell>(obj)) {
4922     return;
4923   }
4924   auto cell = py::cast<CellPtr>(obj);
4925   MS_EXCEPTION_IF_NULL(cell);
4926   auto mixed_type = cell->GetMixedPrecisionType();
4927   if (mixed_type != MixedPrecisionType::kNotSet) {
4928     func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_FP16, mixed_type == MixedPrecisionType::kFP16);
4929     func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_FP32, mixed_type == MixedPrecisionType::kFP32);
4930     func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_BF16, mixed_type == MixedPrecisionType::kBF16);
4931   }
4932 }
4933 
UpdateFuncGraphFlags(const py::object & obj,const FuncGraphPtr & func_graph,bool is_construct_function)4934 bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph, bool is_construct_function) {
4935   if (func_graph == nullptr) {
4936     MS_LOG(ERROR) << "FuncGraph is null";
4937     return false;
4938   }
4939 
4940   SetMixedPrecisionFlag(obj, func_graph);
4941 
4942   if (!py::hasattr(obj, PYTHON_FUNC_GRAPH_FLAGS)) {
4943     MS_LOG(DEBUG) << "No flags";
4944     return true;
4945   }
4946   py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_FUNC_GRAPH_FLAGS);
4947   for (auto &item : flags) {
4948     if (!py::isinstance<py::str>(item.first)) {
4949       MS_LOG(ERROR) << "Type error in flags dict convert";
4950       return false;
4951     }
4952     auto name = py::cast<std::string>(item.first);
4953     if (py::isinstance<py::bool_>(item.second)) {
4954       auto value = py::cast<bool>(item.second);
4955       MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
4956       if (!is_construct_function && name == FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) {
4957         continue;
4958       }
4959       func_graph->set_flag(name, value);
4960     } else if (py::isinstance<py::str>(item.second)) {
4961       auto value = py::cast<std::string>(item.second);
4962       MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
4963       func_graph->set_attr(name, MakeValue(value));
4964     } else {
4965       MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
4966       return false;
4967     }
4968   }
4969   return true;
4970 }
4971 
UpdateRecomputeScope(const FuncGraphPtr & func_graph)4972 void UpdateRecomputeScope(const FuncGraphPtr &func_graph) {
4973   MS_EXCEPTION_IF_NULL(func_graph);
4974   auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple);
4975 
4976   for (const auto &node : nodes) {
4977     MS_EXCEPTION_IF_NULL(node);
4978     const auto &origin_scope_name = node->scope()->name();
4979     if (node->isa<CNode>() && origin_scope_name.compare(0, strlen(kAttrRecompute), kAttrRecompute) != 0) {
4980       std::stringstream scope_name_buffer;
4981       scope_name_buffer << kAttrRecompute << "_" << origin_scope_name;
4982       node->set_scope(std::make_shared<Scope>(scope_name_buffer.str()));
4983     }
4984   }
4985 }
4986 
IsSubscriptReferenceType(const py::object & obj)4987 bool Parser::IsSubscriptReferenceType(const py::object &obj) {
4988   py::object slice_node = python_adapter::GetPyObjAttr(obj, "slice");
4989   auto node_type = ast_->GetNodeType(slice_node);
4990   auto node_name = node_type->node_name();
4991   return node_name != "Slice";
4992 }
4993 
4994 struct CompileConfigCollectRegister {
CompileConfigCollectRegistermindspore::parse::CompileConfigCollectRegister4995   CompileConfigCollectRegister() noexcept {
4996     CompileConfigManager::set_collect_func([]() {
4997       std::map<std::string, std::string> compile_config;
4998       const auto module_name = "mindspore._extends.parse.compile_config";
4999       py::list config_list = py::cast<py::list>(python_adapter::GetPyFn(module_name, "__all__"));
5000       for (size_t i = 0; i < config_list.size(); ++i) {
5001         auto config_name = config_list[i].cast<std::string>();
5002         auto config = python_adapter::GetPyFn(module_name, config_name);
5003         if (py::isinstance<py::none>(config)) {
5004           MS_LOG(INTERNAL_EXCEPTION) << config_name << " not found in " << module_name << ".";
5005         }
5006         if (py::isinstance<py::int_>(config)) {
5007           compile_config[config_name] = std::to_string(py::cast<int64_t>(config));
5008         } else {
5009           compile_config[config_name] = config.cast<std::string>();
5010         }
5011       }
5012       return compile_config;
5013     });
5014   }
5015   ~CompileConfigCollectRegister() = default;
5016 } compile_config_collect_register;
5017 }  // namespace parse
5018 }  // namespace mindspore
5019