1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2024 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/ps/parse/parse.h"
20
21 #include <utility>
22 #include <string>
23 #include <memory>
24 #include <sstream>
25 #include <algorithm>
26 #include <stack>
27 #include <regex>
28 #include "mindspore/core/ops/structure_ops.h"
29 #include "mindspore/core/ops/sequence_ops.h"
30 #include "mindspore/core/ops/framework_ops.h"
31 #include "utils/hash_map.h"
32 #include "pipeline/jit/ps/fallback.h"
33 #include "pipeline/jit/ps/parse/resolve.h"
34 #include "pipeline/jit/ps/parse/data_converter.h"
35 #include "frontend/operator/ops.h"
36 #include "frontend/operator/composite/composite.h"
37 #include "utils/ms_context.h"
38 #include "utils/log_adapter.h"
39 #include "utils/compile_config.h"
40 #include "utils/interpret_node_recorder.h"
41 #include "pipeline/jit/ps/debug/trace.h"
42 #include "mindspore/core/ir/cell.h"
43 #include "include/common/fallback.h"
44 #include "include/common/utils/utils.h"
45 #include "include/common/utils/python_adapter.h"
46 #include "include/common/utils/convert_utils_py.h"
47
48 namespace mindspore {
49 namespace parse {
ParsePythonCode(const py::object & obj,const std::string & python_mod_get_parse_method,const ValuePtrList & args_value_list)50 FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method,
51 const ValuePtrList &args_value_list) {
52 (void)python_adapter::set_python_scoped();
53 py::gil_scoped_acquire gil;
54
55 if (!obj || py::isinstance<py::none>(obj)) {
56 MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
57 return nullptr;
58 }
59 MS_LOG(DEBUG) << "Parse ast obj: " << py::str(obj)
60 << ", python_mod_get_parse_method: " << python_mod_get_parse_method;
61
62 auto ast = std::make_shared<ParseFunctionAst>(obj);
63 bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
64 if (!success) {
65 MS_LOG(ERROR) << "Parse code to ast tree failed. obj: " << py::str(obj)
66 << ", python_mod_get_parse_method: " << python_mod_get_parse_method;
67 return nullptr;
68 }
69
70 auto parser = std::make_shared<Parser>(ast, args_value_list);
71
72 FuncGraphPtr func_graph = parser->ParseFuncGraph();
73 if (func_graph == nullptr) {
74 MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
75 py::object node = ast->GetAstNode();
76 const auto &location = parser->GetLocation(node);
77 py::str desc = python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(),
78 location->file_name(), location->line());
79 MS_LOG(ERROR) << "\nlocation:" << desc.cast<std::string>();
80 return nullptr;
81 }
82
83 // Handle no_inline function
84 auto no_inline_value = py::getattr(obj, FUNC_GRAPH_FLAG_NO_INLINE, py::none());
85 if (no_inline_value != py::none()) {
86 func_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, py::cast<bool>(no_inline_value));
87 }
88 // Handle cell_reusing function
89 auto cell_reuse_value = py::getattr(obj, FUNC_GRAPH_FLAG_CELL_REUSE, py::none());
90 if (cell_reuse_value != py::none()) {
91 func_graph->set_flag(FUNC_GRAPH_FLAG_CELL_REUSE, py::cast<bool>(cell_reuse_value));
92 }
93
94 MS_LOG(DEBUG) << "Finish Parsing " << py::str(obj);
95 return func_graph;
96 }
97
98 FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
99
Parser(const std::shared_ptr<ParseFunctionAst> & ast,ValuePtrList args_value_list)100 Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast, ValuePtrList args_value_list)
101 : ast_(ast), errcode_(PARSE_SUCCESS), args_value_list_(std::move(args_value_list)) {
102 BuildMethodMap();
103 }
104
BuildMethodMap()105 void Parser::BuildMethodMap() {
106 stmt_method_map_["Return"] = &Parser::ParseReturn;
107 stmt_method_map_["Expr"] = &Parser::ParseExpr;
108 stmt_method_map_["If"] = &Parser::ParseIf;
109 stmt_method_map_["Assign"] = &Parser::ParseAssign;
110 stmt_method_map_["AnnAssign"] = &Parser::ParseAnnAssign;
111 stmt_method_map_["While"] = &Parser::ParseWhile;
112 stmt_method_map_["For"] = &Parser::ParseFor;
113 stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
114 stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
115 stmt_method_map_["Global"] = &Parser::ParseGlobal;
116 stmt_method_map_["Break"] = &Parser::ParseBreak;
117 stmt_method_map_["Continue"] = &Parser::ParseContinue;
118 stmt_method_map_["Pass"] = &Parser::ParsePass;
119 stmt_method_map_["Raise"] = &Parser::ParseRaise;
120 stmt_method_map_["Assert"] = &Parser::ParseAssert;
121 stmt_method_map_["With"] = &Parser::ParseWith;
122 expr_method_map_["NoneType"] = &Parser::ParseNone;
123 expr_method_map_["BinOp"] = &Parser::ParseBinOp;
124 expr_method_map_["Name"] = &Parser::ParseName;
125 expr_method_map_["Num"] = &Parser::ParseNum;
126 expr_method_map_["Str"] = &Parser::ParseStr;
127 expr_method_map_["Constant"] = &Parser::ParseConstant;
128 expr_method_map_["NameConstant"] = &Parser::ParseNameConstant;
129 expr_method_map_["Call"] = &Parser::ParseCall;
130 expr_method_map_["IfExp"] = &Parser::ParseIfExp;
131 expr_method_map_["Attribute"] = &Parser::ParseAttribute;
132 expr_method_map_["Compare"] = &Parser::ParseCompare;
133 expr_method_map_["BoolOp"] = &Parser::ParseBoolOp;
134 expr_method_map_["Lambda"] = &Parser::ParseLambda;
135 expr_method_map_["Tuple"] = &Parser::ParseTuple;
136 expr_method_map_["List"] = &Parser::ParseList;
137 expr_method_map_["Subscript"] = &Parser::ParseSubscript;
138 expr_method_map_["Slice"] = &Parser::ParseSlice;
139 expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice;
140 expr_method_map_["Index"] = &Parser::ParseIndex;
141 expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
142 expr_method_map_["Dict"] = &Parser::ParseDict;
143 expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
144 expr_method_map_["DictComp"] = &Parser::ParseDictComp;
145 expr_method_map_["ListComp"] = &Parser::ParseListComp;
146 expr_method_map_["GeneratorExp"] = &Parser::ParseListComp; // We treat 'GeneratorExp' the same as 'ListComp'.
147 expr_method_map_["JoinedStr"] = &Parser::ParseJoinedStr;
148 expr_method_map_["FormattedValue"] = &Parser::ParseFormattedValue;
149 expr_method_map_["Starred"] = &Parser::ParseStarred;
150 condition_method_map_["Attribute"] = &Parser::CheckAttributeConstantCond;
151 condition_method_map_["Name"] = &Parser::CheckNameConstantCond;
152 condition_method_map_["UnaryOp"] = &Parser::CheckUnaryOpConstantCond;
153 condition_method_map_["Compare"] = &Parser::CheckCompareConstantCond;
154 condition_method_map_["BoolOp"] = &Parser::CheckBoolOpConstantCond;
155 compare_method_map_["is"] = &Parser::CompareIs;
156 compare_method_map_["is not"] = &Parser::CompareIsNot;
157 compare_method_map_["=="] = &Parser::CompareEqual;
158 compare_method_map_["!="] = &Parser::CompareNotEqual;
159 compare_method_map_[">"] = &Parser::CompareGreater;
160 compare_method_map_[">="] = &Parser::CompareGreaterEqual;
161 compare_method_map_["<"] = &Parser::CompareLess;
162 compare_method_map_["<="] = &Parser::CompareLessEqual;
163 }
164
UpdateTopFuncGraph(const FuncGraphPtr & func_graph)165 void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
166
InitParserEnvironment(const py::object & obj)167 void Parser::InitParserEnvironment(const py::object &obj) {
168 Parser::top_func_graph_ = FuncGraphWeakPtr();
169 ScopeManager::GetInstance().ClearScope();
170 (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj);
171 // CellList need convert to FuncGraph in Parse, add flag for input from top graph.
172 if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
173 py::setattr(obj, PYTHON_CELL_LIST_FROM_TOP, py::bool_(true));
174 }
175 }
176
CleanParserResource()177 void Parser::CleanParserResource() {
178 Parser::top_func_graph_ = FuncGraphWeakPtr();
179 ScopeManager::GetInstance().ClearScope();
180 parse::CleanParameterNameCache();
181 }
182
CheckFuncReturn(const FuncGraphManagerPtr & manager,const FuncGraphPtr & fn)183 void Parser::CheckFuncReturn(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fn) {
184 // Check whether the functions referred by this function and itself are missing 'return' statement
185 MS_EXCEPTION_IF_NULL(manager);
186 MS_EXCEPTION_IF_NULL(ast_);
187 for (const auto &func_graph : manager->func_graphs()) {
188 MS_EXCEPTION_IF_NULL(func_graph);
189 if (func_graph->get_return() != nullptr) {
190 continue;
191 }
192 py::object node = ast_->GetAstNode();
193 const auto &location = GetLocation(node);
194 MS_EXCEPTION_IF_NULL(location);
195 py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
196 location->file_name(), location->line());
197 MS_LOG(INFO) << "Function must has 'return' statement, but missing in " << desc.cast<std::string>()
198 << ". FuncGraph: " << func_graph->ToString() << ", location: " << location->ToString()
199 << "\nWe will add a 'return None' statement automatically.";
200 // If the def function has no return statement, mean that return none.
201 TraceGuard trace_guard_none(location);
202 auto none_node = NewValueNode(kNone);
203 auto return_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), none_node});
204 func_graph->set_return(return_node);
205 }
206 }
207
GetFreeVariable(const FuncGraphPtr & func_graph)208 std::vector<std::pair<CNodePtr, size_t>> GetFreeVariable(const FuncGraphPtr &func_graph) {
209 // Considering the performance, we didn't use Manager here.
210 std::vector<std::pair<CNodePtr, size_t>> free_variables;
211 MS_EXCEPTION_IF_NULL(func_graph);
212 std::vector<AnfNodePtr> nodes =
213 TopoSort(func_graph->get_return(), SuccIncoming, [&func_graph](const AnfNodePtr &node) -> IncludeType {
214 MS_EXCEPTION_IF_NULL(node);
215 // Not follow FV's inputs.
216 if (node->func_graph() != nullptr && node->func_graph() != func_graph) {
217 return NOFOLLOW;
218 }
219 return FOLLOW;
220 });
221 for (auto &node : nodes) {
222 // Only check Non-FV CNode.
223 auto cnode = dyn_cast<CNode>(node);
224 if (cnode == nullptr || cnode->func_graph() != func_graph) {
225 continue;
226 }
227
228 for (size_t i = 0; i < cnode->size(); ++i) {
229 auto &input = cnode->input(i);
230 if (input->func_graph() != nullptr && input->func_graph() != func_graph) {
231 (void)free_variables.emplace_back(std::make_pair(cnode, i));
232 constexpr auto recur_2 = 2;
233 MS_LOG(DEBUG) << "Found FV: input[" << i << "] of " << cnode->DebugString(recur_2);
234 }
235 }
236 }
237 return free_variables;
238 }
239
LiftRolledBodyGraphFV()240 void Parser::LiftRolledBodyGraphFV() {
241 for (auto &rolled_call_pair : rolled_body_calls_) {
242 auto rolled_call_cnode = rolled_call_pair.first;
243 auto rolled_graph = rolled_call_pair.second->func_graph();
244 MS_EXCEPTION_IF_NULL(rolled_graph);
245 const auto &free_variables = GetFreeVariable(rolled_graph);
246 for (auto &free_node_pair : free_variables) {
247 auto &cnode = free_node_pair.first;
248 auto index = free_node_pair.second;
249 // Move the free variable to parent.
250 auto &free_node = cnode->input(index);
251 rolled_call_cnode->add_input(free_node);
252 // Change the free variable to the parameter.
253 auto parameter = rolled_graph->add_parameter();
254 cnode->set_input(index, parameter);
255 constexpr auto recur_2 = 2;
256 MS_LOG(DEBUG) << "Change FV: " << cnode->DebugString(recur_2);
257 }
258 }
259 }
260
LiftIfBranchGraphFV()261 void Parser::LiftIfBranchGraphFV() {
262 for (auto &branch_call_tuple : if_branch_calls_) {
263 auto call_cnode = std::get<0>(branch_call_tuple);
264 auto true_branch_graph = std::get<1>(branch_call_tuple)->func_graph();
265 MS_EXCEPTION_IF_NULL(true_branch_graph);
266 auto false_branch_graph = std::get<2>(branch_call_tuple)->func_graph();
267 MS_EXCEPTION_IF_NULL(false_branch_graph);
268 const auto &true_free_variables = GetFreeVariable(true_branch_graph);
269 const auto &false_free_variables = GetFreeVariable(false_branch_graph);
270 // Handle true branch.
271 for (auto &free_node_pair : true_free_variables) {
272 auto &cnode = free_node_pair.first;
273 MS_EXCEPTION_IF_NULL(cnode);
274 auto index = free_node_pair.second;
275 // Move the free variable to parent.
276 auto &free_node = cnode->input(index);
277 call_cnode->add_input(free_node);
278 // Change the free variable to the parameter.
279 auto parameter = true_branch_graph->add_parameter();
280 cnode->set_input(index, parameter);
281 // Add a unused parameter in other branch.
282 (void)false_branch_graph->add_parameter();
283 constexpr auto recur_2 = 2;
284 MS_LOG(DEBUG) << "True branch, change FV: " << cnode->DebugString(recur_2);
285 }
286 // Handle false branch.
287 for (auto &free_node_pair : false_free_variables) {
288 auto &cnode = free_node_pair.first;
289 MS_EXCEPTION_IF_NULL(cnode);
290 auto index = free_node_pair.second;
291 // Move the free variable to parent.
292 auto &free_node = cnode->input(index);
293 call_cnode->add_input(free_node);
294 // Change the free variable to the parameter.
295 auto parameter = false_branch_graph->add_parameter();
296 cnode->set_input(index, parameter);
297 // Add a unused parameter in other branch.
298 (void)true_branch_graph->add_parameter();
299 constexpr auto recur_2 = 2;
300 MS_LOG(DEBUG) << "False branch, change FV: " << cnode->DebugString(recur_2);
301 }
302 }
303 }
304
305 namespace {
IsDependOfIsolatedNodes(const AnfNodePtr & node)306 bool IsDependOfIsolatedNodes(const AnfNodePtr &node) {
307 if (!IsPrimitiveCNode(node, prim::kPrimDepend)) {
308 return false;
309 }
310 auto cnode = dyn_cast<CNode>(node);
311 MS_EXCEPTION_IF_NULL(cnode);
312 auto attr_sort_rhs_first = cnode->GetAttr(kAttrTopoSortRhsFirst);
313 auto sort_rhs_first =
314 attr_sort_rhs_first != nullptr && attr_sort_rhs_first->isa<BoolImm>() && GetValue<bool>(attr_sort_rhs_first);
315 return sort_rhs_first;
316 }
317
GetRealOutputNodes(const FuncGraphPtr & call_graph)318 std::pair<CNodePtr, AnfNodePtr> GetRealOutputNodes(const FuncGraphPtr &call_graph) {
319 MS_EXCEPTION_IF_NULL(call_graph);
320 auto graph_output = call_graph->output();
321 if (graph_output == nullptr) {
322 MS_LOG(INTERNAL_EXCEPTION) << "graph_output is null, call_graph: " << call_graph->ToString();
323 }
324 auto graph_output_cnode = dyn_cast<CNode>(graph_output);
325 MS_EXCEPTION_IF_NULL(graph_output_cnode);
326 // If output cnode is not the tail call but a Depend CNode, keep the dependency node for later use.
327 AnfNodePtr graph_dependency_node = nullptr;
328 if (IsDependOfIsolatedNodes(graph_output_cnode)) {
329 auto graph_real_output_cnode = dyn_cast<CNode>(graph_output_cnode->input(1));
330 // Get the dependency node;
331 constexpr auto dependency_node_index = 2;
332 graph_dependency_node = graph_output_cnode->input(dependency_node_index);
333 MS_EXCEPTION_IF_NULL(graph_real_output_cnode);
334 graph_output_cnode = graph_real_output_cnode;
335 }
336 return {graph_output_cnode, graph_dependency_node};
337 }
338
TransformParallelCallFormerToMiddle(const FuncGraphPtr & former_call_graph,const FuncGraphPtr & latter_call_graph,size_t middle_graph_output_cnode_size,bool use_arguments_pack)339 void TransformParallelCallFormerToMiddle(const FuncGraphPtr &former_call_graph, const FuncGraphPtr &latter_call_graph,
340 size_t middle_graph_output_cnode_size, bool use_arguments_pack) {
341 // The 'former_graph_output' is middle graph call or depend.
342 const auto &[former_graph_output_cnode, former_graph_dependency_node] = GetRealOutputNodes(former_call_graph);
343 MS_EXCEPTION_IF_NULL(former_graph_output_cnode);
344 MS_EXCEPTION_IF_NULL(former_call_graph);
345 std::vector<AnfNodePtr> inputs({NewValueNode(latter_call_graph)});
346 if (use_arguments_pack) {
347 for (size_t i = 0; i < middle_graph_output_cnode_size - 1; ++i) {
348 auto getitem_input = former_call_graph->NewCNodeInOrder(
349 {NewValueNode(prim::kPrimTupleGetItem), former_graph_output_cnode, NewValueNode(SizeToLong(i))});
350 (void)inputs.emplace_back(getitem_input);
351 }
352 } else {
353 (void)inputs.emplace_back(former_graph_output_cnode);
354 }
355 auto new_output = former_call_graph->NewCNodeBefore(former_call_graph->return_node(), std::move(inputs));
356 if (former_graph_dependency_node != nullptr) {
357 // Adjust the former funcgraph output with Depend.
358 new_output = former_call_graph->NewCNodeAfter(
359 new_output, {NewValueNode(prim::kPrimDepend), new_output, former_graph_dependency_node});
360 // Origin dependency_node has this attribute(refer to function IsDependOfIsolatedNodes), so we keep it.
361 new_output->AddAttr(kAttrTopoSortRhsFirst, MakeValue(true));
362 }
363 former_call_graph->set_output(new_output);
364 }
365
TransformParallelCallMiddleToLatter(const FuncGraphPtr & middle_call_graph,const CNodePtr & middle_graph_output_cnode,const AnfNodePtr & middle_graph_dependency_node,size_t middle_graph_output_cnode_size)366 bool TransformParallelCallMiddleToLatter(const FuncGraphPtr &middle_call_graph,
367 const CNodePtr &middle_graph_output_cnode,
368 const AnfNodePtr &middle_graph_dependency_node,
369 size_t middle_graph_output_cnode_size) {
370 MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
371 MS_EXCEPTION_IF_NULL(middle_call_graph);
372 bool use_arguments_pack = false;
373 constexpr auto output_inputs_num = 2;
374 AnfNodePtr new_middle_graph_output = nullptr;
375 if (middle_graph_output_cnode_size == output_inputs_num) { // Only one argument.
376 new_middle_graph_output = middle_graph_output_cnode->input(1);
377 } else { // More than one argument, pack them with tuple.
378 use_arguments_pack = true;
379 middle_graph_output_cnode->set_input(0, NewValueNode(prim::kPrimMakeTuple));
380 new_middle_graph_output = middle_graph_output_cnode;
381 }
382 // Adjust the middle funcgraph output with Depend.
383 if (middle_graph_dependency_node != nullptr) {
384 new_middle_graph_output = middle_graph_output_cnode->func_graph()->NewCNode(
385 {NewValueNode(prim::kPrimDepend), new_middle_graph_output, middle_graph_dependency_node});
386 }
387 middle_call_graph->set_output(new_middle_graph_output);
388 return use_arguments_pack;
389 }
390
IsValueContainScalar(const ValuePtr & value)391 bool IsValueContainScalar(const ValuePtr &value) {
392 if (value->isa<Scalar>()) {
393 return true;
394 }
395 return false;
396 }
397
IsOutputContainScalar(const CNodePtr & output_cnode)398 bool IsOutputContainScalar(const CNodePtr &output_cnode) {
399 return std::any_of(output_cnode->weak_inputs().cbegin() + 1, output_cnode->weak_inputs().end(),
400 [](const AnfNodeWeakPtr &weak_node) {
401 auto node = weak_node.lock();
402 MS_EXCEPTION_IF_NULL(node);
403 if (node->isa<ValueNode>()) {
404 auto value_node = node->cast<ValueNodePtr>();
405 return IsValueContainScalar(value_node->value());
406 }
407 return false;
408 });
409 }
410
CheckMiddleGraphOutputContainScalar(const std::vector<std::pair<FunctionBlockPtr,FunctionBlockPtr>> & parallel_call_vec)411 bool CheckMiddleGraphOutputContainScalar(
412 const std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> ¶llel_call_vec) {
413 std::vector<bool> contains_scalar;
414 for (auto &call_graphs_pair : parallel_call_vec) {
415 MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
416 auto middle_call_graph = call_graphs_pair.second->func_graph();
417 MS_EXCEPTION_IF_NULL(middle_call_graph);
418 if (middle_call_graph->get_return() == nullptr) {
419 continue;
420 }
421 constexpr auto recur_2 = 2;
422 const auto &middle_graph_output_pair = GetRealOutputNodes(middle_call_graph);
423 const auto middle_graph_output_cnode = middle_graph_output_pair.first;
424 MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
425 auto middle_graph_output_cnode_size = middle_graph_output_cnode->size();
426 if (middle_graph_output_cnode_size <= 1) {
427 MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
428 return false;
429 }
430
431 static const auto transform_if_const_scalar = (common::GetCompileConfig("IF_PARALLEL_CALL") == "2");
432 if (!transform_if_const_scalar && IsOutputContainScalar(middle_graph_output_cnode)) {
433 MS_LOG(DEBUG) << "CNode's inputs contain const scalar, " << middle_graph_output_cnode->DebugString(recur_2);
434 contains_scalar.push_back(true);
435 } else {
436 contains_scalar.push_back(false);
437 }
438 }
439
440 return std::all_of(contains_scalar.cbegin(), contains_scalar.cend(), [](bool is_scalar) { return is_scalar; });
441 }
442
CheckMiddleGraphOutputPyInterpret(const std::vector<std::pair<FunctionBlockPtr,FunctionBlockPtr>> & parallel_call_vec)443 bool CheckMiddleGraphOutputPyInterpret(
444 const std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> ¶llel_call_vec) {
445 bool contain_py_interpret = false;
446 for (auto &call_graphs_pair : parallel_call_vec) {
447 MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
448 auto middle_call_graph = call_graphs_pair.second->func_graph();
449 MS_EXCEPTION_IF_NULL(middle_call_graph);
450 if (middle_call_graph->get_return() == nullptr) {
451 continue;
452 }
453 constexpr auto recur_2 = 2;
454 const auto &middle_graph_output_pair = GetRealOutputNodes(middle_call_graph);
455 const auto middle_graph_output_cnode = middle_graph_output_pair.first;
456 MS_EXCEPTION_IF_NULL(middle_graph_output_cnode);
457 auto middle_graph_output_cnode_size = middle_graph_output_cnode->size();
458 if (middle_graph_output_cnode_size <= 1) {
459 MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
460 return false;
461 }
462 bool exist_interpret = std::any_of(
463 middle_graph_output_cnode->weak_inputs().cbegin() + 1, middle_graph_output_cnode->weak_inputs().cend(),
464 [](const AnfNodeWeakPtr &weak_node) { return IsPrimitiveCNode(weak_node.lock(), prim::kPrimPyInterpret); });
465 contain_py_interpret |= exist_interpret;
466 if (contain_py_interpret) {
467 return true;
468 }
469 }
470
471 return false;
472 }
473 } // namespace
474
475 // Transform tail call to parallel call.
TransformParallelCall()476 void Parser::TransformParallelCall() {
477 mindspore::HashSet<FuncGraphPtr> latter_call_graphs_set;
478 for (auto ¶llel_call_vec : parallel_call_graphs_) {
479 bool all_middle_graphs_output_scalar = CheckMiddleGraphOutputContainScalar(parallel_call_vec);
480 if (all_middle_graphs_output_scalar) {
481 MS_LOG(DEBUG) << "All middle func graph's output contain const scalar, cannot transform to Parallel_If.";
482 continue;
483 }
484 // After Join, Value in Abstract of PyInterpret CNode will be kValueAny, it cannot be PyInterpreted again, so
485 // ignore the transformation.
486 bool is_middle_graphs_output_py_interpret = CheckMiddleGraphOutputPyInterpret(parallel_call_vec);
487 if (is_middle_graphs_output_py_interpret) {
488 MS_LOG(DEBUG) << "Middle func graph's output contain PyInterpret CNode, cannot transform to Parallel_If.";
489 continue;
490 }
491 for (auto &call_graphs_pair : parallel_call_vec) {
492 MS_EXCEPTION_IF_NULL(call_graphs_pair.first);
493 auto former_call_graph = call_graphs_pair.first->func_graph();
494 MS_EXCEPTION_IF_NULL(call_graphs_pair.second);
495 auto middle_call_graph = call_graphs_pair.second->func_graph();
496 // Transform the call of {middle_graph -> latter_graph}.
497 auto middle_graph_return = middle_call_graph->get_return();
498 if (middle_graph_return == nullptr) {
499 MS_LOG(INFO) << "middle_graph_return is null, middle_call_graph: " << middle_call_graph->ToString();
500 continue;
501 }
502 constexpr auto recur_3 = 3;
503 constexpr auto recur_2 = 2;
504 MS_LOG(DEBUG) << "Tail call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
505 << ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
506 const auto &[middle_graph_output_cnode, middle_graph_dependency_node] = GetRealOutputNodes(middle_call_graph);
507 auto middle_graph_output_cnode_size = middle_graph_output_cnode->size();
508 if (middle_graph_output_cnode_size <= 1) {
509 MS_LOG(DEBUG) << "CNode's inputs size should exceed 1, " << middle_graph_output_cnode->DebugString(recur_2);
510 continue;
511 }
512
513 auto latter_graph_node = middle_graph_output_cnode->input(0);
514 bool use_arguments_pack = TransformParallelCallMiddleToLatter(
515 middle_call_graph, middle_graph_output_cnode, middle_graph_dependency_node, middle_graph_output_cnode_size);
516
517 // Transform the call of {former_graph -> middle_graph}.
518 auto latter_call_graph = GetValueNode<FuncGraphPtr>(latter_graph_node);
519 if (latter_call_graph == nullptr) {
520 MS_LOG(ERROR) << "The latter graph node is not FuncGraph, " << latter_graph_node->DebugString(recur_2);
521 continue;
522 }
523 if (latter_call_graphs_set.find(latter_call_graph) != latter_call_graphs_set.end()) {
524 MS_LOG(DEBUG) << "The latter graph is handled before, " << latter_call_graph->ToString();
525 continue;
526 }
527 (void)latter_call_graphs_set.emplace(latter_call_graph);
528 TransformParallelCallFormerToMiddle(former_call_graph, latter_call_graph, middle_graph_output_cnode_size,
529 use_arguments_pack);
530
531 MS_LOG(DEBUG) << "Parallel call graphs return: {former: " << former_call_graph->get_return()->DebugString(recur_3)
532 << ", middle: " << middle_call_graph->get_return()->DebugString(recur_3) << "}";
533 }
534 }
535
536 // Lift inner, then lift outer.
537 LiftIfBranchGraphFV();
538 LiftRolledBodyGraphFV();
539 }
540
ParseFuncGraph()541 FuncGraphPtr Parser::ParseFuncGraph() {
542 // Get ast FunctionDef node
543 py::object node = ast_->GetAstNode();
544 constexpr char function_def_name[] = "FunctionDef";
545 constexpr char lambda_name[] = "Lambda";
546 FunctionBlockPtr fn_block = nullptr;
547 MS_EXCEPTION_IF_NULL(ast_->GetNodeType(node));
548 if (ast_->GetNodeType(node)->node_name() == function_def_name) {
549 fn_block = ParseDefFunction(node);
550 } else {
551 auto lambda_node = python_adapter::GetPyObjAttr(node, "value");
552 if (py::isinstance<py::none>(lambda_node) || ast_->GetNodeType(lambda_node)->node_name() != lambda_name) {
553 MS_INTERNAL_EXCEPTION(TypeError) << "Parse Lambda Function Fail. Node type must be Lambda, but got "
554 << ast_->GetNodeType(lambda_node)->node_name() << ". Please check lambda"
555 << " expression to make sure it is defined on a separate line.\n For example, "
556 << "the code 'func = nn.ReLU() if y < 1 else lambda x: x + 1' rewritten as\n"
557 << "'if y < 1:\n func = nn.ReLU()\nelse:\n func = lambda x: x + 1\n'"
558 << "will solve the problem.";
559 }
560 fn_block = ParseLambdaFunction(lambda_node);
561 }
562 if (errcode() != PARSE_SUCCESS) {
563 MS_LOG(ERROR) << "Parse function error, code is " << errcode();
564 return nullptr;
565 }
566 for (auto &func_block_item : func_block_list_) {
567 MS_EXCEPTION_IF_NULL(func_block_item);
568 MS_EXCEPTION_IF_NULL(func_block_item->func_graph());
569 if (!func_block_item->isolated_nodes().empty()) {
570 // Find unused variables.
571 func_block_item->FindIsolatedNodes();
572 // Attach all isolated nodes.
573 func_block_item->AttachIsolatedNodesBeforeReturn();
574 }
575 }
576 MS_EXCEPTION_IF_NULL(fn_block);
577 auto manager = Manage(fn_block->func_graph(), false);
578 RemoveUnnecessaryPhis(manager);
579 CheckFuncReturn(manager, fn_block->func_graph());
580 TransformParallelCall();
581 return fn_block->func_graph();
582 }
583
584 // If any mixed precision flag add a cast node after the parameter node.
GetMixedPrecisionCastHelp(const FuncGraphPtr & func_graph,const AnfNodePtr & param)585 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) {
586 MS_EXCEPTION_IF_NULL(func_graph);
587 TypePtr dst_type;
588 if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
589 dst_type = kFloat32;
590 } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
591 dst_type = kFloat16;
592 } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_BF16)) {
593 dst_type = kBFloat16;
594 } else {
595 return param;
596 }
597 auto cast_helper = prim::kPrimMixedPrecisionCast;
598 auto cast = func_graph->NewCNodeAfter(param, {NewValueNode(cast_helper), NewValueNode(dst_type), param});
599 return cast;
600 }
601
GenerateArgsNodeForFunction(const FunctionBlockPtr & block,const py::object & fn_node)602 void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
603 py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
604 py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
605 MS_EXCEPTION_IF_NULL(block);
606 auto block_fg = block->func_graph();
607 block_fg->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
608
609 py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
610 block_fg->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
611
612 py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
613 block_fg->set_kwonlyargs_count(SizeToInt(kwonly_args.size()));
614
615 MS_EXCEPTION_IF_NULL(ast_);
616 py::list args = ast_->GetArgs(fn_node);
617 for (std::size_t i = 0; i < args.size(); i++) {
618 std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
619 if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
620 if (arg_name == "self") {
621 continue;
622 }
623 }
624 TraceGuard guard(GetLocation(args[i]));
625 auto para_node = std::make_shared<Parameter>(block_fg);
626 MS_EXCEPTION_IF_NULL(para_node);
627 para_node->set_name(arg_name);
628 MS_EXCEPTION_IF_NULL(para_node->debug_info());
629 para_node->debug_info()->set_name(arg_name);
630 block_fg->add_parameter(para_node);
631 AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block_fg, para_node);
632 MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
633 block->WriteVariable(arg_name, para_after_cast);
634 }
635 }
636
GenerateArgsDefaultValueForFunction(const FunctionBlockPtr & block,const py::object & fn_node)637 void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
638 MS_EXCEPTION_IF_NULL(block);
639 py::list defaults = ast_->GetArgsDefaultValues(fn_node);
640 py::list args = ast_->GetArgs(fn_node);
641 std::vector<std::string> namelist_for_default_value;
642 std::vector<AnfNodePtr> default_values;
643 for (std::size_t i = 0; i < args.size(); i++) {
644 std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
645 if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
646 if (arg_name == "self") {
647 continue;
648 }
649 }
650
651 namelist_for_default_value.push_back(arg_name);
652 if (i >= defaults.size()) {
653 MS_LOG(INTERNAL_EXCEPTION) << "Index: " << i << " out of range: " << defaults.size();
654 }
655 if (py::isinstance<py::none>(defaults[i])) {
656 default_values.push_back(NewValueNode(kNull));
657 } else {
658 AnfNodePtr arg_node = ParseExprNode(block, defaults[i]);
659 default_values.push_back(arg_node);
660 }
661 }
662 MS_EXCEPTION_IF_NULL(block->func_graph());
663 block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values);
664 }
665
GetScopeForParseFunction()666 ScopePtr Parser::GetScopeForParseFunction() {
667 ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
668 if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
669 py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
670 if (!py::isinstance<py::none>(scope_str)) {
671 auto scope_name = py::cast<std::string>(scope_str);
672 scope = std::make_shared<Scope>(scope_name);
673 }
674 }
675 return scope;
676 }
677
// Rebuild the recorded getattr nodes of every (object, attribute) pair that is also assigned in the graph.
// For each pair found in both setattr_nodes_map_ and getattr_nodes_map_, every recorded getattr node that has a
// func graph is replaced (through the graph manager) by a new GetAttr CNode which reads from the object input of
// the setattr node and is tagged with the user data fallback::kObjectAttrChange.
void Parser::ConvertGetattrNodes() {
  // If obj.attr has been set a new value in graph, convert all getattr node to PyExecute.
  AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
  for (const auto &setattr_node_pair : setattr_nodes_map_) {
    const auto &obj_str = setattr_node_pair.first;
    const auto &attr_map = setattr_node_pair.second;
    auto getattr_nodes_map_iter = getattr_nodes_map_.find(obj_str);
    // If the same object is not in both setattr map and getattr map, no need to convert getattr node.
    if (getattr_nodes_map_iter == getattr_nodes_map_.end()) {
      continue;
    }
    const auto &getattr_map = getattr_nodes_map_iter->second;
    for (const auto &attr_pair : attr_map) {
      const auto &attr_str = attr_pair.first;
      auto getattr_map_iter = getattr_map.find(attr_str);
      // If the same attr for the same obj is not in both setattr map and getattr map, no need to convert getattr node.
      if (getattr_map_iter == getattr_map.end()) {
        continue;
      }
      const auto &setattr_node = attr_pair.second;
      auto setattr_cnode = setattr_node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(setattr_cnode);
      // NOTE(review): this takes a copy of the recorded node container, not a reference; presumably so the
      // iteration is independent of later changes — confirm before turning it into a reference.
      const auto getattr_nodes = getattr_map_iter->second;
      // Input 1 of the setattr CNode is used as the object the new getattr nodes read from.
      constexpr size_t obj_index = 1;
      const auto &setattr_cnode_obj_node = setattr_cnode->input(obj_index);
      // The most recently created replacement node for this (object, attribute) pair.
      AnfNodePtr cur_getattr_node = nullptr;
      for (const auto &getattr_node : getattr_nodes) {
        auto getattr_node_fg = getattr_node->func_graph();
        if (getattr_node_fg == nullptr) {
          MS_LOG(DEBUG) << "Has no func graph, getattr_node: " << getattr_node->DebugString();
          continue;
        }
        std::vector<AnfNodePtr> new_getattr_node_inputs = {op_node, setattr_cnode_obj_node, NewValueNode(attr_str)};
        // When the previous replacement lives in the same func graph, append it as an extra input of the new node.
        if (cur_getattr_node != nullptr && cur_getattr_node->func_graph() == getattr_node_fg) {
          (void)new_getattr_node_inputs.emplace_back(cur_getattr_node);
        }
        auto new_getattr_node = getattr_node_fg->NewCNode(new_getattr_node_inputs);
        new_getattr_node->set_user_data<bool>(fallback::kObjectAttrChange, std::make_shared<bool>(true));
        // Keep the debug info of the node being replaced.
        new_getattr_node->set_debug_info(getattr_node->debug_info());
        MS_LOG(DEBUG) << "Generate new getattr node: " << new_getattr_node->DebugString();
        const auto &manager = Manage(getattr_node_fg, false);
        MS_EXCEPTION_IF_NULL(manager);
        (void)manager->Replace(getattr_node, new_getattr_node);
        cur_getattr_node = new_getattr_node;
      }
    }
  }
}
726
ParseDefFunction(const py::object & node,const FunctionBlockPtr & block)727 FunctionBlockPtr Parser::ParseDefFunction(const py::object &node, const FunctionBlockPtr &block) {
728 ScopePtr scope = GetScopeForParseFunction();
729 // The node created in the parsefunction context, will inherit the scope created using scope_guard
730 ScopeGuard scope_guard(scope);
731 const auto debug_info = std::make_shared<DebugInfo>(GetLocation(node));
732 TraceGuard trace_guard(std::make_shared<TraceParse>(debug_info));
733 FunctionBlockPtr func_block = MakeFunctionBlock();
734 if (block != nullptr) {
735 func_block->AddPrevBlock(block);
736 } else {
737 func_graph_ = func_block->func_graph();
738 }
739 func_block->Mature();
740 auto current_fg = func_block->func_graph();
741 auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
742 MS_LOG(DEBUG) << "The function name is " << function_name << ", loc: " << GetLocation(node)->ToString();
743 // Replace the construct function name with the cell name
744 constexpr auto cell_construct = "construct";
745 bool is_construct_function = false;
746 if (function_name == cell_construct) {
747 is_construct_function = true;
748 // 'py_class_name' format is like: <class 'x.x.xxx'>
749 std::string py_class_name = py::cast<std::string>(py::str(ast()->obj().get_type()));
750 constexpr auto py_class_prefix_len = 8; // <class '
751 constexpr auto py_class_suffix_len = 2; // '>
752 auto py_class_len = py_class_name.length();
753 // Exclude class prefix and suffix.
754 auto class_name =
755 py_class_name.substr(py_class_prefix_len, py_class_len - py_class_prefix_len - py_class_suffix_len);
756 function_name = class_name + '_' + cell_construct;
757 MS_LOG(DEBUG) << "The generated function full name: " << function_name;
758 }
759 // Normalize the name.
760 std::replace(function_name.begin(), function_name.end(), '.', '_');
761 std::replace(function_name.begin(), function_name.end(), '<', '_');
762 std::replace(function_name.begin(), function_name.end(), '>', '_');
763
764 // Save the function node to block
765 func_block->WriteVariable(function_name, NewValueNode(current_fg));
766 MS_EXCEPTION_IF_NULL(current_fg->debug_info());
767 current_fg->debug_info()->set_name(function_name);
768 py::list deco_list = node.attr("decorator_list");
769 if (!deco_list.empty()) {
770 current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
771 }
772 MS_EXCEPTION_IF_NULL(ast_);
773 bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
774 if (!ast_->obj().is(ast_->function())) {
775 set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg, is_construct_function);
776 }
777
778 if (!set_flag) {
779 MS_LOG(ERROR) << "Set flags failed";
780 return nullptr;
781 }
782 GenerateArgsNodeForFunction(func_block, node);
783
784 // When parsing the top graph of construct, save the top graph
785 if (GetTopFuncGraph() == nullptr) {
786 UpdateTopFuncGraph(func_block->func_graph());
787 }
788
789 py::object func_obj = python_adapter::GetPyObjAttr(node, "body");
790 (void)ParseStatements(func_block, func_obj);
791 if (current_fg->get_return() == nullptr) {
792 // If the def function has no return statement, mean that return none.
793 py::object location_node = ast_->GetAstNode();
794 const auto &location = GetLocation(location_node);
795 MS_EXCEPTION_IF_NULL(location);
796 py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(),
797 location->file_name(), location->line());
798 MS_LOG(INFO) << "Function must has 'return' statement, but missing in " << desc.cast<std::string>()
799 << ". FuncGraph: " << current_fg->ToString() << ", location: " << location->ToString()
800 << ". We will add a 'return None' statement automatically.";
801 TraceGuard trace_guard_none(location);
802 auto none_node = NewValueNode(kNone);
803 auto return_node = current_fg->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), none_node});
804 current_fg->set_return(return_node);
805 }
806
807 // Add unused variables as isolate nodes.
808 for (auto &func_block_item : func_block_list_) {
809 MS_EXCEPTION_IF_NULL(func_block_item);
810 MS_EXCEPTION_IF_NULL(func_block_item->func_graph());
811 if (func_block_item->func_graph()->get_return() != nullptr) {
812 // Find unused variables.
813 func_block_item->FindIsolatedNodes();
814 // Attach all isolated nodes.
815 func_block_item->AttachIsolatedNodesBeforeReturn();
816 }
817 }
818
819 ConvertGetattrNodes();
820 GenerateArgsDefaultValueForFunction(func_block, node);
821 return func_block;
822 }
823
ParseLambdaFunction(const py::object & node,const FunctionBlockPtr & block)824 FunctionBlockPtr Parser::ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block) {
825 MS_EXCEPTION_IF_NULL(ast_);
826 ScopePtr scope = GetScopeForParseFunction();
827 ScopeGuard scope_guard(scope);
828 const auto debug_info = std::make_shared<DebugInfo>(GetLocation(node));
829 TraceGuard trace_guard(std::make_shared<TraceParse>(debug_info));
830 FunctionBlockPtr func_block = MakeFunctionBlock();
831 MS_EXCEPTION_IF_NULL(func_block);
832 if (block != nullptr) {
833 func_block->AddPrevBlock(block);
834 } else {
835 func_graph_ = func_block->func_graph();
836 }
837 func_block->Mature();
838 auto current_fg = func_block->func_graph();
839
840 MS_EXCEPTION_IF_NULL(current_fg);
841 auto lambda_function_name = ast_->function_name();
842 // Normalize the name.
843 std::replace(lambda_function_name.begin(), lambda_function_name.end(), '.', '_');
844 std::replace(lambda_function_name.begin(), lambda_function_name.end(), '<', '_');
845 std::replace(lambda_function_name.begin(), lambda_function_name.end(), '>', '_');
846 constexpr auto lambda_suffix = "_lambda_"; // Represent <lambda>.
847 auto function_name = lambda_function_name + "_" + lambda_suffix;
848 MS_LOG(DEBUG) << "The function name is " << function_name;
849 MS_EXCEPTION_IF_NULL(current_fg->debug_info());
850 current_fg->debug_info()->set_name(function_name);
851 GenerateArgsNodeForFunction(func_block, node);
852
853 // When parsing the top graph of construct, save the top graph
854 if (GetTopFuncGraph() == nullptr) {
855 UpdateTopFuncGraph(func_block->func_graph());
856 }
857
858 py::object body_node = python_adapter::GetPyObjAttr(node, "body");
859 AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
860 current_fg->set_output(lambda_body_node);
861
862 // Add unused variables as isolate nodes.
863 for (auto &func_block_item : func_block_list_) {
864 MS_EXCEPTION_IF_NULL(func_block_item);
865 MS_EXCEPTION_IF_NULL(func_block_item->func_graph());
866 if (!func_block_item->isolated_nodes().empty()) {
867 // Find unused variables.
868 func_block_item->FindIsolatedNodes();
869 // Attach all isolated nodes.
870 func_block_item->AttachIsolatedNodesBeforeReturn();
871 }
872 }
873
874 GenerateArgsDefaultValueForFunction(func_block, node);
875 return func_block;
876 }
877
// Parse a list of statement AST nodes one after another, starting in `block`.
//   block: the block the first statement is parsed in.
//   nodes: a Python list of statement AST nodes.
// Returns the block that is current after the last parsed statement.
// Parsing stops after the first `Return`, `Break` or `Continue` statement. Depending on the compile config
// "BOOST_PARSE", it also stops at a dead block, or raises when statements follow a return/dead block.
FunctionBlockPtr Parser::ParseStatements(const FunctionBlockPtr &block, const py::object &nodes) {
  auto node_list = py::cast<py::list>(nodes);
  size_t count = py::len(node_list);
  MS_LOG(DEBUG) << "The nodes count is " << count;
  auto sub_block = block;
  for (size_t i = 0; i < count; ++i) {
    MS_LOG(DEBUG) << "Start parse statement[" << i << "]: " << py::str(node_list[i])
                  << ", block: " << sub_block->ToString();
    auto node = node_list[i];
    // Flag of return statement is set on sub_block inside ParseStatement, so use next_block
    // to store the returned block temporarily.
    auto next_block = ParseStatement(sub_block, node);
    MS_EXCEPTION_IF_NULL(next_block);
    MS_EXCEPTION_IF_NULL(next_block->func_graph());
    // Propagate flag of return statement back;
    if (sub_block != block && sub_block->is_return_statement_inside()) {
      MS_LOG(DEBUG) << "Sub block: " << sub_block->ToString()
                    << " has return statement inside, propagate flag back to block: " << block->ToString();
      block->set_is_return_statement_inside();
    }
    // Propagate flag of break or continue statement back;
    if (sub_block != block && sub_block->is_break_continue_statement_inside()) {
      MS_LOG(DEBUG) << "Sub block: " << sub_block->ToString()
                    << " has break or continue statement inside, propagate flag back to block: " << block->ToString();
      block->set_break_continue_statement_inside();
    }
    sub_block = next_block;

    static const auto boost_parse = common::GetCompileConfig("BOOST_PARSE");
    if (boost_parse != "0" && sub_block->is_dead_block()) {
      break;
    }
    if (boost_parse == "0") {
      // Insert appropriate depended items for the function block if it has a return node
      if (sub_block->func_graph()->get_return() != nullptr || sub_block->is_dead_block()) {
        // If break is not the last expr.
        if (i != count - 1) {
          TraceGuard trace_guard(GetLocation(node_list[i + 1]));
          MS_LOG(EXCEPTION) << "Dead code exist, please remove it. [" << (i + 1) << "/" << count
                            << "], node: " << py::str(node_list[i]) << ", block: " << sub_block->ToString()
                            << ", has_return: " << (sub_block->func_graph()->get_return() != nullptr)
                            << ", is_dead_block: " << sub_block->is_dead_block();
        }
        // Skip statements after 'return' (or 'break', 'continue').
        break;
      }
    }
    // If the current block has multi return statements,
    // only parse the statements before first return statement.
    // Statements after the continue and break statements are also not parsed.
    // Query the node type once instead of once per compared name.
    const std::string stmt_name = ast_->GetNodeType(node)->node_name();
    if (stmt_name == "Break" || stmt_name == "Continue" || stmt_name == "Return") {
      break;
    }
  }
  return sub_block;
}
935
// Parse one statement AST node by dispatching to the handler registered in stmt_method_map_ for its node name.
// A node whose main type is not a statement is only logged and `block` is returned unchanged.
// Raises when no handler is registered for the node name.
FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) {
  TraceGuard trace_guard(GetLocation(node));
  auto ast_node_type = ast_->GetNodeType(node);

  // Check the node type
  AstMainType main_type = ast_node_type->main_type();
  if (main_type != AST_MAIN_TYPE_STMT) {
    MS_LOG(INFO) << "Node type is error : " << main_type;
    return block;
  }
  // Call the process function
  std::string node_name = ast_node_type->node_name();
  MS_LOG(DEBUG) << "Ast node is " << node_name << ", location:" << GetLocation(node)->ToString();
  const auto handler_iter = stmt_method_map_.find(node_name);
  if (handler_iter == stmt_method_map_.end()) {
    errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
    MS_LOG(EXCEPTION) << "Unsupported statement '" << node_name
                      << "'.\nMore details please refer to syntax support at https://www.mindspore.cn";
  }
  return (this->*(handler_iter->second))(block, node);
}
958
ParseExprNode(const FunctionBlockPtr & block,const py::object & node)959 AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) {
960 MS_LOG(DEBUG) << "Process ast expr.";
961 TraceGuard trace_guard(GetLocation(node));
962 auto node_type = ast_->GetNodeType(node);
963 // Check the node type
964 AstMainType node_main_type = node_type->main_type();
965 if (node_main_type != AST_MAIN_TYPE_EXPR) {
966 errcode_ = PARSE_NODE_TYPE_NO_MATCH;
967 MS_LOG(INTERNAL_EXCEPTION) << "Node type is error : " << node_main_type;
968 }
969 // Call the process function
970 const std::string &node_type_name = node_type->node_name();
971 MS_LOG(DEBUG) << "Ast node is " << node_type_name << ", location:" << GetLocation(node)->ToString();
972 if (expr_method_map_.count(node_type_name) != 0) {
973 auto expr_node = (this->*expr_method_map_[node_type_name])(block, node);
974 MS_LOG(DEBUG) << "Get parsed anf node:" << expr_node->DebugString();
975 return expr_node;
976 } else {
977 errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
978 MS_LOG(EXCEPTION) << "Unsupported expression '" << node_type_name
979 << "'.\nMore details please refer to syntax support at https://www.mindspore.cn";
980 }
981 }
982
983 // If self.attr.func is inplace operation, then
984 // self.attr.func(inputs)
985 // -->
986 // self.attr = self.attr.func(inputs)
987 // setattr(self, "attr", self.attr.func(inputs))
HandleSetAttrClassMemberForInplace(const FunctionBlockPtr & block,const AnfNodePtr & node)988 bool Parser::HandleSetAttrClassMemberForInplace(const FunctionBlockPtr &block, const AnfNodePtr &node) {
989 if (!node->isa<CNode>()) {
990 return false;
991 }
992 auto cnode = node->cast<CNodePtr>();
993 auto call_node = cnode->input(0);
994 if (!IsPrimitiveCNode(call_node, prim::kPrimGetAttr)) {
995 return false;
996 }
997 constexpr int recursive_level = 2;
998 // call_cnode: self.attr.func
999 auto call_cnode = call_node->cast<CNodePtr>();
1000 MS_LOG(DEBUG) << "call cnode: " << call_cnode->DebugString(recursive_level);
1001 const auto &call_cnode_inputs = call_cnode->inputs();
1002 constexpr size_t attr_node_index = 1;
1003 constexpr size_t func_str_index = 2;
1004 auto func_str_node = call_cnode_inputs[func_str_index];
1005 if (!IsValueNode<StringImm>(func_str_node)) {
1006 return false;
1007 }
1008 const auto &func_str = GetValue<std::string>(GetValueNode(func_str_node));
1009 std::vector<std::string> inplace_ops{"extend", "pop", "insert", "reverse"};
1010 MS_LOG(DEBUG) << "func str: " << func_str;
1011 if (std::all_of(inplace_ops.begin(), inplace_ops.end(),
1012 [&func_str](const std::string &ops) { return func_str != ops; })) {
1013 return false;
1014 }
1015 auto attr_node = call_cnode_inputs[attr_node_index];
1016 if (!attr_node->isa<CNode>()) {
1017 return false;
1018 }
1019 // attr_cnode: self.attr
1020 auto attr_cnode = attr_node->cast<CNodePtr>();
1021 MS_LOG(DEBUG) << "attr cnode: " << attr_cnode->DebugString(recursive_level);
1022 const auto &attr_cnode_inputs = attr_cnode->inputs();
1023 constexpr size_t target_index = 1;
1024 constexpr size_t attr_index = 2;
1025 auto target_node = attr_cnode_inputs[target_index];
1026 MS_LOG(DEBUG) << "target node: " << target_node->DebugString();
1027 auto target_attr_node = attr_cnode_inputs[attr_index];
1028 auto symbol_val = GetValueNode<SymbolPtr>(target_attr_node);
1029 if (symbol_val == nullptr) {
1030 return false;
1031 }
1032 const auto &target_symbol_str = symbol_val->symbol();
1033 MakeSetAttrNode(block, target_node, cnode, "self", target_symbol_str);
1034 return true;
1035 }
1036
1037 // Process the expr statement and expand it
ParseExpr(const FunctionBlockPtr & block,const py::object & node)1038 FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
1039 MS_LOG(DEBUG) << "Process ast Expr";
1040 // Expr only have value, no target
1041 py::tuple expand_info = ast_->CallParseModFunction(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
1042
1043 // Refer python function expand_expr_statement, expand_info is one of the following:
1044 // True, expr.value, x
1045 // True, expr.value
1046 // False, None, None
1047 //
1048 // Check the expand info result
1049 if (expand_info.empty()) {
1050 MS_LOG(INTERNAL_EXCEPTION) << "Empty expand_info.";
1051 }
1052 auto is_expand = py::cast<bool>(expand_info[0]);
1053 if (is_expand) {
1054 // Process the expr statement
1055 constexpr size_t expect_size = 2;
1056 if (expand_info.size() < expect_size) {
1057 MS_LOG(INTERNAL_EXCEPTION) << "expand_info size:" << expand_info.size() << " less than " << expect_size << ".";
1058 }
1059 py::object value_object = expand_info[1];
1060 // Make a Expr CNode.
1061 AnfNodePtr call_node = ParseExprNode(block, value_object);
1062 if (py::len(expand_info) == expect_size) {
1063 // list_x.pop(a) does not write the return value of pop.
1064 // --> list_x = list_x.pop(a) need renew the list_x.
1065 if (IsPopOperation(call_node)) {
1066 if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast_->IsClassMemberOfSelf(list_pop_target_obj_)) {
1067 // self.list_x = [xx, xx]
1068 // self.list_x.pop()
1069 MS_LOG(DEBUG) << "The variables whose type is not parameter do not support pop operation.";
1070 } else {
1071 auto func_graph = block->func_graph();
1072 MS_EXCEPTION_IF_NULL(func_graph);
1073 auto new_list = func_graph->NewCNodeInOrder(
1074 {NewValueNode(prim::kPrimTupleGetItem), call_node, NewValueNode(SizeToLong(0))});
1075 WriteAssignVars(block, list_pop_target_obj_, new_list);
1076 block->AddIsolatedNode(call_node);
1077 return block;
1078 }
1079 }
1080 // Expression that not assigned to any variable.
1081 // This is usually a call with side effects.
1082 // e.g.: print(x)
1083 // We save it as an isolated node.
1084 auto &no_return_node = call_node;
1085 MS_LOG(INFO) << "Isolated node found(NoReturn), no_return_node: " << no_return_node->DebugString()
1086 << ", block: " << block << "/"
1087 << (block->func_graph() ? block->func_graph()->ToString() : "FG(Null)")
1088 << ", Line: " << trace::GetDebugInfoStr(no_return_node->debug_info(), "", kSourceLineTipDiscard);
1089 block->AddIsolatedNode(no_return_node);
1090 } else {
1091 // Expand the assign statement,
1092 // e.g.: x.append(y) -> x = x.append(y)
1093 py::object target_node = expand_info[2];
1094 // Check whether the target_node is class member recursively.
1095 // e.g.: self.a1.a1.update()
1096 if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast_->IsClassMemberRecursive(target_node)) {
1097 // self.x = [xx, xx]
1098 // self.x.append()
1099 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
1100 if (!allow_fallback_runtime || !HandleSetAttrClassMemberForInplace(block, call_node)) {
1101 block->AddIsolatedNode(call_node);
1102 }
1103 } else {
1104 WriteAssignVars(block, target_node, call_node);
1105 }
1106 }
1107 }
1108 return block;
1109 }
1110
GetLocation(const py::object & node) const1111 LocationPtr Parser::GetLocation(const py::object &node) const {
1112 MS_EXCEPTION_IF_NULL(ast_);
1113 py::list res = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
1114 constexpr size_t list_size = 6;
1115 if (res.size() < list_size) {
1116 MS_LOG(INTERNAL_EXCEPTION) << "List size should not be less than 5.";
1117 }
1118 constexpr size_t file_name_index = 0;
1119 constexpr size_t line_index = 1;
1120 constexpr size_t column_index = 2;
1121 constexpr size_t line_end_index = 3;
1122 constexpr size_t column_end_index = 4;
1123 constexpr size_t expr_src_index = 5;
1124 constexpr size_t comments_index = 6;
1125 // Deal with the comments.
1126 std::vector<std::string> comments_str_list;
1127 const auto comments_list = res[comments_index].cast<py::list>();
1128 for (size_t i = 0; i < comments_list.size(); ++i) {
1129 (void)comments_str_list.emplace_back(comments_list[i].cast<std::string>());
1130 }
1131 if (!comments_str_list.empty()) {
1132 MS_LOG(DEBUG) << "@jit comments: " << comments_str_list;
1133 }
1134 // Refer to Location::Location() for each member of res: line, column, line_end, column_end, expr_src.
1135 auto location = std::make_shared<Location>(res[file_name_index].cast<std::string>(), res[line_index].cast<int64_t>(),
1136 res[column_index].cast<int64_t>(), res[line_end_index].cast<int64_t>(),
1137 res[column_end_index].cast<int64_t>(),
1138 res[expr_src_index].cast<std::string>(), std::move(comments_str_list));
1139 MS_LOG(DEBUG) << "node: " << py::str(node) << ",\n" << location->DebugString();
1140 return location;
1141 }
1142
1143 // NOTICE: Must add a TraceGuard before call it.
MakeFunctionBlock()1144 FunctionBlockPtr Parser::MakeFunctionBlock() {
1145 FunctionBlockPtr block = std::make_shared<FunctionBlock>(*this);
1146 // In order to keep effect order in the sub-graphs which generated by control flow.
1147 // We copy the flags from the top graph to the sub-graphs.
1148 if (func_graph_ && !func_graph_->attrs().empty()) {
1149 for (const auto &attr : func_graph_->attrs()) {
1150 // The flag FUNC_GRAPH_OUTPUT_NO_RECOMPUTE should be only set in the top graph.
1151 if (attr.first != FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) {
1152 block->func_graph()->set_attr(attr.first, attr.second);
1153 }
1154 }
1155 }
1156 func_block_list_.push_back(block);
1157 return block;
1158 }
1159
MakeFunctionBlock(const TraceInfoPtr & trace_info)1160 FunctionBlockPtr Parser::MakeFunctionBlock(const TraceInfoPtr &trace_info) {
1161 TraceGuard trace_guard(trace_info);
1162 FunctionBlockPtr block = MakeFunctionBlock();
1163 return block;
1164 }
1165
MakeConditionBlocks(const FunctionBlockPtr & pre_block,const FunctionBlockPtr & true_block,const FunctionBlockPtr & false_block) const1166 void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
1167 const FunctionBlockPtr &false_block) const {
1168 MS_EXCEPTION_IF_NULL(true_block);
1169 MS_EXCEPTION_IF_NULL(false_block);
1170 true_block->AddPrevBlock(pre_block);
1171 true_block->Mature();
1172
1173 false_block->AddPrevBlock(pre_block);
1174 false_block->Mature();
1175
1176 true_block->UpdateGlobalPyParam(pre_block->global_py_params());
1177 false_block->UpdateGlobalPyParam(pre_block->global_py_params());
1178 }
1179
// Parse a `return` statement: parse its value expression, create the Return CNode, set it as the return node
// of the func graph of `block` and mark the block as containing a return statement.
// Returns `block`.
FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast return";
  MS_EXCEPTION_IF_NULL(block);
  // Parse the return Statements value.
  py::object value_object = python_adapter::GetPyObjAttr(node, "value");
  AnfNodePtr return_expr_node = ParseExprNode(block, value_object);
  // Create the `return` CNode.
  auto func_graph = block->func_graph();
  // The func graph is dereferenced right below, so validate it first.
  MS_EXCEPTION_IF_NULL(func_graph);
  CNodePtr return_cnode = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), return_expr_node});
  func_graph->set_return(return_cnode);
  MS_LOG(DEBUG) << "Inside the block has return statement, block: " << block->ToString();
  block->set_is_return_statement_inside();
  return block;
}
1194
1195 // Process binary operators,eg: `a + b`, `a | b`, etc.
ParseBinOp(const FunctionBlockPtr & block,const py::object & node)1196 AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
1197 MS_LOG(DEBUG) << "Process ast BinOP";
1198
1199 MS_EXCEPTION_IF_NULL(block);
1200 py::object left = python_adapter::GetPyObjAttr(node, "left");
1201 py::object right = python_adapter::GetPyObjAttr(node, "right");
1202 py::object op = python_adapter::GetPyObjAttr(node, "op");
1203 // Create left and right ANF node
1204 AnfNodePtr left_node = ParseExprNode(block, left);
1205 if (left_node == nullptr) {
1206 MS_LOG(INTERNAL_EXCEPTION) << "DoBinOp process left node failed: " << errcode();
1207 }
1208 AnfNodePtr right_node = ParseExprNode(block, right);
1209 if (right_node == nullptr) {
1210 MS_LOG(INTERNAL_EXCEPTION) << "DoBinOp process right node failed:" << errcode();
1211 }
1212 // Resolve the op
1213 const auto &ns = block->GetAstOpNameSpace(op);
1214 auto op_node = block->MakeResolveAstOpNameSpace(ns);
1215
1216 // Create apply node
1217 MS_EXCEPTION_IF_NULL(block->func_graph());
1218 AnfNodePtr new_node = block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
1219 // Handling % symbol in formatted string values by JIT Fallback.
1220 // The string AnfNode may be created by ParseJoinedStr or ParseStr.
1221 // For example, string % var, f"The string is: %s." % str or "The number is: %d." % num
1222 constexpr size_t symbol_index = 1;
1223 SymbolPtr symbol = std::make_shared<Symbol>(ns[symbol_index].cast<std::string>());
1224 // Only support the pattern (string % xxx) by fallback.
1225 if (symbol != nullptr && symbol->symbol() == "mod") {
1226 if (IsPrimitiveCNode(left_node, prim::kPrimJoinedStr)) {
1227 // left_node created by ParseJoinedStr
1228 auto inputs = left_node->cast<CNodePtr>()->inputs();
1229 if (inputs.size() <= 1) {
1230 MS_LOG(INTERNAL_EXCEPTION) << "Unexpected maketuple node:" << left_node->DebugString();
1231 }
1232 auto str_node = inputs[1];
1233 if (IsValueNode<StringImm>(str_node)) {
1234 new_node->set_interpret(true);
1235 auto new_interpret_node = HandleInterpret(block, new_node, node);
1236 return new_interpret_node;
1237 }
1238 } else if (IsValueNode<StringImm>(left_node)) {
1239 // left_node created by ParseStr
1240 new_node->set_interpret(true);
1241 auto new_interpret_node = HandleInterpret(block, new_node, node);
1242 return new_interpret_node;
1243 }
1244 }
1245 return new_node;
1246 }
1247
ParseName(const FunctionBlockPtr & block,const py::object & node)1248 AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
1249 MS_LOG(DEBUG) << "Process ast Name";
1250 auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
1251 MS_LOG(DEBUG) << "The Name id is " << name_id;
1252 MS_EXCEPTION_IF_NULL(block);
1253 // The Tensor object will be parsed into an Interpret node. For example, Tensor(0).astype("int32")
1254 if (block->IsGlobalVar(name_id) || name_id == "Tensor") {
1255 MS_LOG(DEBUG) << "name_id: " << name_id;
1256 AnfNodePtr res = block->MakeResolveSymbol(name_id);
1257 block->CheckUndefinedSymbol(name_id, res);
1258 return res;
1259 }
1260
1261 AnfNodePtr res = block->ReadVariable(name_id);
1262 block->CheckUndefinedSymbol(name_id, res);
1263 return res;
1264 }
1265
ParseNone(const FunctionBlockPtr &,const py::object &)1266 AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
1267 MS_LOG(DEBUG) << "Process ast NoneType";
1268 return NewValueNode(kNone);
1269 }
1270
ParseEllipsis(const FunctionBlockPtr &,const py::object &)1271 AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
1272 MS_LOG(DEBUG) << "Process ast Ellipsis";
1273 return NewValueNode(kEllipsis);
1274 }
1275
ParseNum(const FunctionBlockPtr &,const py::object & node)1276 AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
1277 MS_LOG(DEBUG) << "Process ast Num";
1278 py::object obj = python_adapter::GetPyObjAttr(node, "n");
1279 if (py::isinstance<py::int_>(obj)) {
1280 MS_LOG(INFO) << "The Num is int64_t:" << (std::string)py::str(obj);
1281 auto data = py::cast<int64_t>(obj);
1282 return NewValueNode(data);
1283 } else if (py::isinstance<py::float_>(obj)) {
1284 MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
1285 auto data = py::cast<float>(obj);
1286 auto res = NewValueNode(data);
1287 auto fp32_val = res->value()->cast<FP32ImmPtr>();
1288 if (fp32_val != nullptr) {
1289 MS_LOG(DEBUG) << "Set float64 value to FP32Imm.";
1290 fp32_val->set_prim_value(py::cast<double>(obj));
1291 }
1292 return res;
1293 } else {
1294 // no else actually
1295 errcode_ = PARSE_NODE_TYPE_UNKNOWN;
1296 MS_EXCEPTION(TypeError) << "Only support 'Number' type of 'int` and 'float', but got type: " << obj.get_type()
1297 << " Value:" << py::str(obj);
1298 }
1299 }
1300
ParseStr(const FunctionBlockPtr &,const py::object & node)1301 AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
1302 MS_LOG(DEBUG) << "Process ast Str";
1303 auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
1304 return NewValueNode(str_s);
1305 }
1306
ParseConstant(const FunctionBlockPtr &,const py::object & node)1307 AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &node) {
1308 MS_LOG(DEBUG) << "Process ast Constant";
1309 py::object obj = python_adapter::GetPyObjAttr(node, "value");
1310 if (py::isinstance<py::bool_>(obj)) {
1311 MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj);
1312 return NewValueNode(py::cast<bool>(obj));
1313 } else if (py::isinstance<py::int_>(obj)) {
1314 MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj);
1315 return NewValueNode(py::cast<int64_t>(obj));
1316 } else if (py::isinstance<py::float_>(obj)) {
1317 MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj);
1318 auto data = py::cast<float>(obj);
1319 auto res = NewValueNode(data);
1320 auto fp32_val = res->value()->cast<FP32ImmPtr>();
1321 if (fp32_val != nullptr) {
1322 MS_LOG(DEBUG) << "Set float64 value to FP32Imm.";
1323 fp32_val->set_prim_value(py::cast<double>(obj));
1324 }
1325 return res;
1326 } else if (py::isinstance<py::str>(obj)) {
1327 MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj);
1328 return NewValueNode(py::cast<std::string>(obj));
1329 } else if (py::isinstance<py::none>(obj)) {
1330 MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj);
1331 return NewValueNode(kNone);
1332 } else if (py::isinstance<py::ellipsis>(obj)) {
1333 MS_LOG(INFO) << "The Constance is ellipsis:" << (std::string)py::str(obj);
1334 return NewValueNode(kEllipsis);
1335 } else {
1336 // no else actually
1337 MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << (std::string)py::str(obj);
1338 }
1339 }
1340
ParseNameConstant(const FunctionBlockPtr &,const py::object & node)1341 AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
1342 MS_LOG(DEBUG) << "Process ast NameConstant";
1343 py::object obj = python_adapter::GetPyObjAttr(node, "value");
1344 if (py::isinstance<py::bool_>(obj)) {
1345 MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
1346 auto data = py::cast<bool>(obj);
1347 return NewValueNode(data);
1348 } else if (py::isinstance<py::none>(obj)) {
1349 MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
1350 return NewValueNode(kNone);
1351 }
1352 // no else actually
1353 errcode_ = PARSE_NODE_TYPE_UNKNOWN;
1354 MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
1355 }
1356
GenerateMakeTuple(const FunctionBlockPtr & block,const std::vector<AnfNodePtr> & element_nodes)1357 AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
1358 MS_EXCEPTION_IF_NULL(block);
1359 AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
1360 std::vector<AnfNodePtr> make_tuple_nodes;
1361 make_tuple_nodes.push_back(make_tuple_op);
1362 (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes),
1363 [](AnfNodePtr arg) -> AnfNodePtr { return arg; });
1364 MS_EXCEPTION_IF_NULL(block->func_graph());
1365 return block->func_graph()->NewCNodeInOrder(std::move(make_tuple_nodes));
1366 }
1367
ParseSuper(const FunctionBlockPtr & block,const py::list & args)1368 AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
1369 MS_EXCEPTION_IF_NULL(block);
1370 py::object father_class;
1371 const size_t expect_args_size = 2;
1372 if (args.empty()) {
1373 father_class = py::none();
1374 } else if (args.size() == expect_args_size) {
1375 father_class = args[0];
1376 auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[1])));
1377 if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
1378 MS_EXCEPTION(ArgumentError) << "Argument 2 of 'super()' must be 'self', but got '"
1379 << py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) << "'.";
1380 }
1381 } else {
1382 MS_EXCEPTION(ArgumentError) << "Arguments number of 'super()' should be 0 or 2, but got " << args.size() << ".";
1383 }
1384 py::object target_class_instance = ast_->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast_->obj());
1385 py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
1386 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
1387 SymbolPtr symbol = std::make_shared<Symbol>("namespace");
1388 MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
1389 return block->MakeResolve(name_space, symbol);
1390 }
1391
HandleStrInError(const FunctionBlockPtr & block,const py::list & args,std::vector<AnfNodePtr> * str_nodes)1392 void Parser::HandleStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes) {
1393 for (size_t i = 0; i < args.size(); ++i) {
1394 AnfNodePtr node = ParseExprNode(block, args[i]);
1395 (void)str_nodes->emplace_back(node);
1396 }
1397 }
1398
HandleException(const FunctionBlockPtr & block,const py::list & args,const std::string & name)1399 std::vector<AnfNodePtr> Parser::HandleException(const FunctionBlockPtr &block, const py::list &args,
1400 const std::string &name) {
1401 auto exception_type_node = NewValueNode(name);
1402 std::vector<AnfNodePtr> node_inputs = {exception_type_node};
1403 HandleStrInError(block, args, &node_inputs);
1404 return node_inputs;
1405 }
1406
ParseRaiseCall(const FunctionBlockPtr & block,const py::object & node)1407 std::vector<AnfNodePtr> Parser::ParseRaiseCall(const FunctionBlockPtr &block, const py::object &node) {
1408 MS_LOG(DEBUG) << "Process ast Call, the current node is raise.";
1409 // Process function call
1410 py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
1411 // Process raise ValueError
1412 if (py::isinstance<py::none>(function_ast_node)) {
1413 auto name = python_adapter::GetPyObjAttr(node, "id");
1414 auto name_id = py::cast<std::string>(name);
1415 if (block->CheckHasVariable(name_id)) {
1416 auto error_node = block->ReadVariable(name_id);
1417 block->CheckUndefinedSymbol(name_id, error_node);
1418 return {NewValueNode(name_id), error_node};
1419 } else if (exception_types_map.find(name_id) != exception_types_map.end()) {
1420 auto str_value = std::make_shared<StringImm>("None");
1421 return {NewValueNode(name_id), NewValueNode(str_value)};
1422 } else {
1423 MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id
1424 << ". Raise only support some Python standard exception types: "
1425 << SupportedExceptionsToString();
1426 }
1427 }
1428
1429 py::list args = python_adapter::GetPyObjAttr(node, "args");
1430
1431 auto arg_type =
1432 AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
1433 if (arg_type == AST_SUB_TYPE_NAME) {
1434 auto name = python_adapter::GetPyObjAttr(function_ast_node, "id");
1435 auto name_id = py::cast<std::string>(name);
1436 MS_LOG(DEBUG) << "The name of call node is: " << name_id;
1437 auto node_list = HandleException(block, args, name_id);
1438 if (block->CheckHasVariable(name_id)) {
1439 auto error_node = block->ReadVariable(name_id);
1440 block->CheckUndefinedSymbol(name_id, error_node);
1441 (void)node_list.emplace_back(error_node);
1442 return node_list;
1443 } else if (exception_types_map.find(name_id) != exception_types_map.end()) {
1444 auto str_value = std::make_shared<StringImm>("None");
1445 (void)node_list.emplace_back(NewValueNode(str_value));
1446 return node_list;
1447 } else {
1448 MS_LOG(EXCEPTION) << "Unsupported exception type: " << name_id
1449 << ". Raise only support some Python standard exception types: "
1450 << SupportedExceptionsToString();
1451 }
1452 }
1453 return {};
1454 }
1455
CompareIs(const FunctionBlockPtr &,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1456 bool Parser::CompareIs(const FunctionBlockPtr &, const py::object &left_obj, const py::object &comparator_obj,
1457 bool *bool_res) const {
1458 auto comparator_type_name = ast_->GetNodeType(comparator_obj)->node_name();
1459 // The type_name is "Constant" in py3.9.
1460 if (comparator_type_name != "NameConstant" && comparator_type_name != "Constant") {
1461 return false;
1462 }
1463 // xxx is None, the comparator must be a NameConstant.
1464 py::object name_constant_value = python_adapter::GetPyObjAttr(comparator_obj, "value");
1465 MS_LOG(DEBUG) << "name_constant_value: " << py::str(name_constant_value);
1466
1467 // Compare with None.
1468 if (py::isinstance<py::none>(name_constant_value)) {
1469 *bool_res = py::isinstance<py::none>(left_obj);
1470 return true;
1471 }
1472 // To add more NameConstants.
1473 return false;
1474 }
1475
CompareIsNot(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1476 bool Parser::CompareIsNot(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
1477 bool *bool_res) const {
1478 if (!CompareIs(block, left_obj, comparator_obj, bool_res)) {
1479 return false;
1480 }
1481 *bool_res = !(*bool_res);
1482 return true;
1483 }
1484
CompareEqual(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1485 bool Parser::CompareEqual(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
1486 bool *bool_res) const {
1487 auto left_obj_type_name = ast_->GetNodeType(left_obj)->node_name();
1488 if (left_obj_type_name == "Tensor" || left_obj_type_name == "Parameter") {
1489 return false;
1490 }
1491 auto comparator_type_name = ast_->GetNodeType(comparator_obj)->node_name();
1492 MS_LOG(DEBUG) << "comparator_type_name: " << comparator_type_name;
1493 if (comparator_type_name == "Num") {
1494 py::object num_value = python_adapter::GetPyObjAttr(comparator_obj, "n");
1495 MS_LOG(DEBUG) << "num_value: " << py::str(num_value);
1496 if (!py::isinstance<py::int_>(num_value) && !py::isinstance<py::float_>(num_value)) {
1497 return false;
1498 }
1499 *bool_res = left_obj.equal(num_value);
1500 return true;
1501 }
1502 if (comparator_type_name == "Str") {
1503 if (!py::isinstance<py::str>(left_obj)) {
1504 *bool_res = false;
1505 return true;
1506 }
1507 py::object str_value = python_adapter::GetPyObjAttr(comparator_obj, "s");
1508 auto left_obj_str = left_obj.cast<std::string>();
1509 *bool_res = (left_obj_str == str_value.cast<std::string>());
1510 return true;
1511 }
1512 if (comparator_type_name == "NameConstant") {
1513 py::object name_constant_value = python_adapter::GetPyObjAttr(comparator_obj, "value");
1514 MS_LOG(DEBUG) << "name_constant_value: " << py::str(name_constant_value);
1515 if (!py::isinstance<py::none>(name_constant_value)) {
1516 return false;
1517 }
1518 *bool_res = py::isinstance<py::none>(left_obj);
1519 return true;
1520 }
1521 if (comparator_type_name == "Attribute") {
1522 bool is_constant;
1523 auto attr_cond = GetPyObjForAstAttr(block, comparator_obj, &is_constant);
1524 if (!is_constant) {
1525 return false;
1526 }
1527 *bool_res = left_obj.equal(attr_cond);
1528 return true;
1529 }
1530 // The type_name is "Constant" in py3.9.
1531 if (comparator_type_name == "Constant") {
1532 py::object constant_value = python_adapter::GetPyObjAttr(comparator_obj, "value");
1533 MS_LOG(DEBUG) << "constant_value: " << py::str(constant_value);
1534 if (!py::isinstance<py::int_>(constant_value) && !py::isinstance<py::float_>(constant_value) &&
1535 !py::isinstance<py::str>(constant_value)) {
1536 return false;
1537 }
1538 *bool_res = left_obj.equal(constant_value);
1539 return true;
1540 }
1541 return false;
1542 }
1543
CompareNotEqual(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1544 bool Parser::CompareNotEqual(const FunctionBlockPtr &block, const py::object &left_obj,
1545 const py::object &comparator_obj, bool *bool_res) const {
1546 if (!CompareEqual(block, left_obj, comparator_obj, bool_res)) {
1547 return false;
1548 }
1549 *bool_res = !(*bool_res);
1550 return true;
1551 }
1552
CompareGreater(const FunctionBlockPtr &,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1553 bool Parser::CompareGreater(const FunctionBlockPtr &, const py::object &left_obj, const py::object &comparator_obj,
1554 bool *bool_res) const {
1555 auto comparator_type_name = ast_->GetNodeType(comparator_obj)->node_name();
1556 if (comparator_type_name != "Num" || (!py::isinstance<py::int_>(left_obj) && !py::isinstance<py::float_>(left_obj))) {
1557 return false;
1558 }
1559 py::object num_value = python_adapter::GetPyObjAttr(comparator_obj, "n");
1560 MS_LOG(DEBUG) << "num_value: " << py::str(num_value);
1561
1562 if (!py::isinstance<py::int_>(num_value) && !py::isinstance<py::float_>(num_value)) {
1563 return false;
1564 }
1565 *bool_res = (left_obj > num_value);
1566 return true;
1567 }
1568
CompareGreaterEqual(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1569 bool Parser::CompareGreaterEqual(const FunctionBlockPtr &block, const py::object &left_obj,
1570 const py::object &comparator_obj, bool *bool_res) const {
1571 bool greater = false;
1572 bool equal = false;
1573 if (!CompareGreater(block, left_obj, comparator_obj, &greater) ||
1574 !CompareEqual(block, left_obj, comparator_obj, &equal)) {
1575 return false;
1576 }
1577 if (greater || equal) {
1578 *bool_res = true;
1579 } else {
1580 *bool_res = false;
1581 }
1582 return true;
1583 }
1584
CompareLess(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1585 bool Parser::CompareLess(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
1586 bool *bool_res) const {
1587 bool greater = false;
1588 bool equal = false;
1589 if (!CompareGreater(block, left_obj, comparator_obj, &greater) ||
1590 !CompareEqual(block, left_obj, comparator_obj, &equal)) {
1591 return false;
1592 }
1593 if (greater || equal) {
1594 *bool_res = false;
1595 } else {
1596 *bool_res = true;
1597 }
1598 return true;
1599 }
1600
CompareLessEqual(const FunctionBlockPtr & block,const py::object & left_obj,const py::object & comparator_obj,bool * bool_res) const1601 bool Parser::CompareLessEqual(const FunctionBlockPtr &block, const py::object &left_obj,
1602 const py::object &comparator_obj, bool *bool_res) const {
1603 bool greater = false;
1604 if (!CompareGreater(block, left_obj, comparator_obj, &greater)) {
1605 return false;
1606 }
1607 if (greater) {
1608 *bool_res = false;
1609 } else {
1610 *bool_res = true;
1611 }
1612 return true;
1613 }
1614
GetParameterValue(const AnfNodePtr & parameter) const1615 ValuePtr Parser::GetParameterValue(const AnfNodePtr ¶meter) const {
1616 if (args_value_list_.empty()) {
1617 return nullptr;
1618 }
1619 const auto ¶meters = func_graph_->parameters();
1620 for (size_t i = 0; i < parameters.size(); ++i) {
1621 if (parameters.at(i) == parameter && i < args_value_list_.size()) {
1622 return args_value_list_[i];
1623 }
1624 }
1625 return nullptr;
1626 }
1627
GetBoolObjForAstCompare(const FunctionBlockPtr & block,const py::object & node,bool * bool_res) const1628 bool Parser::GetBoolObjForAstCompare(const FunctionBlockPtr &block, const py::object &node, bool *bool_res) const {
1629 MS_EXCEPTION_IF_NULL(bool_res);
1630 MS_EXCEPTION_IF_NULL(block);
1631 py::list ops = python_adapter::GetPyObjAttr(node, "ops");
1632 if (ops.size() != 1) {
1633 return false;
1634 }
1635 py::object op = ops[0];
1636 py::tuple namespace_var = ast()->CallParseModFunction(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
1637 constexpr size_t namespace_size = 3;
1638 if (namespace_var.size() != namespace_size) {
1639 MS_LOG(INTERNAL_EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << namespace_var.size();
1640 }
1641 constexpr size_t op_str_index = 2;
1642 std::string op_str = py::str(namespace_var[op_str_index]);
1643 MS_LOG(DEBUG) << "op: " << py::str(op) << ", " << op_str;
1644 auto func_iter = compare_method_map_.find(op_str);
1645 if (func_iter == compare_method_map_.end()) {
1646 return false;
1647 }
1648
1649 py::object left = python_adapter::GetPyObjAttr(node, "left");
1650 py::object left_obj;
1651 auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, left)));
1652 if (arg_type == AST_SUB_TYPE_ATTRIBUTE) {
1653 bool is_constant;
1654 left_obj = GetPyObjForAstAttr(block, left, &is_constant);
1655 if (!is_constant) {
1656 return false;
1657 }
1658 } else {
1659 MS_LOG(DEBUG) << "Not attribute, attr_ast_node: " << py::str(left);
1660 py::object id = python_adapter::GetPyObjAttr(left, "id");
1661 if (!py::isinstance<py::str>(id)) {
1662 return false;
1663 }
1664
1665 auto anf_node = block->ReadVariable(id.cast<std::string>());
1666 if (anf_node == nullptr) {
1667 return false;
1668 }
1669 if (anf_node->isa<ValueNode>()) {
1670 MS_LOG(DEBUG) << "left value node: " << anf_node->DebugString();
1671 left_obj = ValueToPyData(anf_node->cast_ptr<ValueNode>()->value());
1672 } else if (anf_node->isa<Parameter>()) {
1673 MS_LOG(DEBUG) << "left parameter node: " << anf_node->DebugString();
1674 auto value = GetParameterValue(anf_node);
1675 if (value == nullptr || value->ContainsValueAny()) {
1676 return false;
1677 }
1678 left_obj = ValueToPyData(value);
1679 } else {
1680 return false;
1681 }
1682 }
1683 MS_LOG(DEBUG) << "left_obj: " << py::str(left_obj);
1684
1685 py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
1686 if (comparators.size() != 1) {
1687 return false;
1688 }
1689 return (this->*(func_iter->second))(block, left_obj, comparators[0], bool_res);
1690 }
1691
// Resolves an ast Attribute node of the form `<name>.<attr>` to a Python object at parse time.
// On success *is_constant is set to true and the attribute value, fetched through
// PYTHON_MOD_GET_ATTR_FROM_OBJ, is returned. In every other case *is_constant is set to false
// and None is returned:
//  - the attribute's value is not a plain Name;
//  - `<name>` (other than `self`) is bound in `block` to a Parameter or a MixedPrecisionCast node;
//  - `self.<attr>` has an entry in setattr_nodes_map_;
//  - no owning object is found, or it lacks the attribute.
// `is_constant` must not be null; `block` must not be null when `<name>` is not `self`.
py::object Parser::GetPyObjForAstAttr(const FunctionBlockPtr &block, const py::object &attr_ast_node,
                                      bool *is_constant) const {
  MS_EXCEPTION_IF_NULL(is_constant);
  // Default to "not constant"; only the successful path at the end flips it.
  *is_constant = false;

  auto attr_value = python_adapter::GetPyObjAttr(attr_ast_node, "value");
  auto attr_value_type =
    AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, attr_value)));
  if (attr_value_type != AST_SUB_TYPE_NAME) {
    MS_LOG(DEBUG) << "attr_value: " << py::str(attr_value);
    return py::none();
  }

  auto value_name = py::cast<std::string>(python_adapter::GetPyObjAttr(attr_value, "id"));
  auto attr_name = py::cast<std::string>(python_adapter::GetPyObjAttr(attr_ast_node, "attr"));
  MS_LOG(DEBUG) << "attr name: " << value_name << "." << attr_name;
  py::object py_obj_attr_value = py::none();
  if (value_name != "self") {
    MS_EXCEPTION_IF_NULL(block);
    auto node = block->ReadVariable(value_name);
    if (node != nullptr && (node->isa<Parameter>() || IsPrimitiveCNode(node, prim::kPrimMixedPrecisionCast))) {
      return py::none();
    }
    py::tuple attr_namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value_name);
    constexpr size_t global_info_size = 4;
    // Handle nested function def.
    if (attr_namespace_info.size() == global_info_size) {
      constexpr size_t value_index = 2;
      py_obj_attr_value = attr_namespace_info[value_index];
    }
  } else {
    auto iter = setattr_nodes_map_.find(value_name);
    if (iter != setattr_nodes_map_.end() && iter->second.find(attr_name) != iter->second.end()) {
      MS_LOG(DEBUG) << "The self." << attr_name << " has been modified.";
      return py::none();
    }
    py_obj_attr_value = ast_->obj();
  }
  if (py::isinstance<py::none>(py_obj_attr_value) || !py::hasattr(py_obj_attr_value, py::str(attr_name))) {
    MS_LOG(DEBUG) << "Not found object for attribute, attr_ast_node: " << py::str(attr_ast_node);
    return py::none();
  }
  *is_constant = true;
  return python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_ATTR_FROM_OBJ, py_obj_attr_value,
                                     py::str(attr_name));
}
1740
1741 // Process function call, eg : f1(x, y) ...
ParseCall(const FunctionBlockPtr & block,const py::object & node)1742 AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
1743 MS_LOG(DEBUG) << "Process ast Call";
1744 // Process function call
1745 py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
1746 py::list args = python_adapter::GetPyObjAttr(node, "args");
1747
1748 std::string name_id = "";
1749 auto arg_type =
1750 AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
1751 if (arg_type == AST_SUB_TYPE_NAME) {
1752 name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
1753 MS_LOG(DEBUG) << "The name of call node is: " << name_id;
1754 if (name_id == "super") {
1755 return ParseSuper(block, args);
1756 }
1757 }
1758 MS_LOG(DEBUG) << "Process ast Call, name_id: " << name_id;
1759 auto call_function_node = ParseExprNode(block, function_ast_node);
1760
1761 // Function call arguments should be passed in as groups and unpacked later using unpack call
1762 ArgsContext args_context = ArgsContext();
1763 ParseArgsInCall(block, args, &args_context);
1764 ParseKeywordsInCall(block, node, &args_context);
1765
1766 // If the expression is to create Tensor(including adapter tensor) without jit annotation,
1767 // using functional api to create corresponding tensor since the functional api has jit annotation.
1768 auto class_tensor_object = call_function_node->user_data<py::object>(kClassTensorObject);
1769 ClassInstanceType class_tensor_type = CLASS_INSTANCE_TYPE_INVALID;
1770 if (class_tensor_object != nullptr) {
1771 auto call_location = GetLocation(node);
1772 MS_EXCEPTION_IF_NULL(call_location);
1773 const auto &comments = call_location->comments();
1774 if (comments.empty()) {
1775 class_tensor_type = ClassInstanceType(
1776 ast_->CallParserObjMethod(PYTHON_PARSE_GET_CLASS_TENSOR_TYPE, *class_tensor_object).cast<int32_t>());
1777 AnfNodePtr new_call_function_node = nullptr;
1778 if (class_tensor_type == CLASS_INSTANCE_TYPE_TENSOR) {
1779 constexpr auto tensor_func_str = "__ms_tensor_func__";
1780 new_call_function_node = block->MakeResolveSymbol(tensor_func_str);
1781 } else if (class_tensor_type == CLASS_INSTANCE_TYPE_ADAPTER_TENSOR) {
1782 constexpr auto adapter_convert_function = "get_adapter_convert_function";
1783 py::object generate_func = ast_->CallParserObjMethod(adapter_convert_function, *class_tensor_object);
1784 if (!py::isinstance<py::none>(generate_func)) {
1785 new_call_function_node = NewValueNode(ParsePythonCode(generate_func));
1786 }
1787 }
1788 if (new_call_function_node != nullptr) {
1789 MS_LOG(INFO) << "Convert Tensor call node " << call_function_node->DebugString()
1790 << " to functional tensor call node " << new_call_function_node->DebugString();
1791 return GenerateAnfNodeForCall(block, new_call_function_node, args_context);
1792 }
1793 }
1794 }
1795
1796 auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, args_context);
1797 MS_EXCEPTION_IF_NULL(call_cnode);
1798 MS_LOG(DEBUG) << "call_cnode: " << call_cnode->DebugString()
1799 << ", call_function_node: " << call_function_node->DebugString();
1800
1801 // Process bulitin function, for example, sum(np.array(xx))
1802 py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, name_id);
1803 constexpr size_t global_info_size = 4;
1804 if (namespace_info.size() == global_info_size) {
1805 constexpr size_t flag_index = 3;
1806 auto syntax_support = namespace_info[flag_index].cast<int32_t>();
1807 // For print, the inputs to the function determine whether the call_cnode is
1808 // a graph node or the interpret node. If the inputs contain interpret node (not Tensor), the call_cnode will
1809 // be interpretive executived. Otherwise, call_cnode will be a graph node.
1810 if (name_id == "print" && args_context.has_interpret_without_internal) {
1811 call_cnode->set_interpret(true);
1812 // Ensure the order of print
1813 call_cnode = fallback::ConvertCNodeToPyExecuteForPrim(call_cnode->cast<CNodePtr>(), name_id);
1814 return call_cnode;
1815 } else if (syntax_support != SYNTAX_SUPPORTED) {
1816 call_cnode->set_interpret(true);
1817 call_cnode = HandleInterpret(block, call_cnode, node);
1818 // For the unsupported type function, if the input to the function contains tensor, the return value of
1819 // the function should be graph node too.
1820 if (args_context.has_interpret_internal) {
1821 call_cnode->set_interpret_internal_type(true);
1822 }
1823 }
1824 }
1825 if (class_tensor_type == CLASS_INSTANCE_TYPE_ADAPTER_TENSOR) {
1826 MS_LOG(DEBUG) << "Current adapter tensor node: " << call_cnode->DebugString();
1827 call_cnode->set_user_data<bool>(fallback::kAdapterTensor, std::make_shared<bool>(true));
1828 }
1829 return call_cnode;
1830 }
1831
MakeUnpackCall(const FuncGraphPtr & func_graph,const AnfNodePtr & call_function_node,const std::vector<AnfNodePtr> & packed_arguments)1832 CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_node,
1833 const std::vector<AnfNodePtr> &packed_arguments) {
1834 MS_EXCEPTION_IF_NULL(func_graph);
1835 std::vector<AnfNodePtr> unpack_call_nodes;
1836 auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
1837 unpack_call_nodes.push_back(unpack_call_op);
1838 unpack_call_nodes.push_back(call_function_node);
1839 (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
1840 [](AnfNodePtr node) -> AnfNodePtr { return node; });
1841 CNodePtr unpack_call = func_graph->NewCNodeInOrder(std::move(unpack_call_nodes));
1842 return unpack_call;
1843 }
1844
// Generates the node for calling `call_function_node` with the arguments in `args_context`.
//  - need_unpack set: an UnpackCall over the packed arguments.
//  - no arguments and the callee is a PyInterpret CNode: the callee itself is returned,
//    because calling an Interpret node is invalid (%1 = Interpret_node; %2 = %1()).
//  - otherwise: a plain call CNode {callee, group_arguments...}.
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node,
                                          const ArgsContext &args_context) const {
  MS_EXCEPTION_IF_NULL(block);
  if (args_context.need_unpack) {
    return MakeUnpackCall(block->func_graph(), call_function_node, args_context.packed_arguments);
  }

  const auto &plain_args = args_context.group_arguments;
  if (plain_args.empty() && IsPrimitiveCNode(call_function_node, prim::kPrimPyInterpret)) {
    return call_function_node;
  }

  std::vector<AnfNodePtr> call_inputs;
  call_inputs.reserve(plain_args.size() + 1);
  call_inputs.push_back(call_function_node);
  (void)call_inputs.insert(call_inputs.end(), plain_args.begin(), plain_args.end());
  const auto graph = block->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  return graph->NewCNodeInOrder(std::move(call_inputs));
}
1868
ParseArgsInCall(const FunctionBlockPtr & block,const py::list & args,ArgsContext * args_context)1869 void Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, ArgsContext *args_context) {
1870 MS_LOG(DEBUG) << "Process ast args in call";
1871 MS_EXCEPTION_IF_NULL(args_context);
1872 for (size_t i = 0; i < args.size(); i++) {
1873 auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[i])));
1874 if (arg_node == AST_SUB_TYPE_STARRED) {
1875 if (!args_context->group_arguments.empty()) {
1876 args_context->packed_arguments.push_back(GenerateMakeTuple(block, args_context->group_arguments));
1877 }
1878 args_context->packed_arguments.push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value")));
1879 args_context->group_arguments.clear();
1880 args_context->need_unpack = true;
1881 } else {
1882 MS_LOG(DEBUG) << "args[" << i << "]: " << py::str(args[i]);
1883 AnfNodePtr node = ParseExprNode(block, args[i]);
1884 auto internal = node->interpret_internal_type();
1885 auto interpret_without_internal =
1886 ((node->interpret() || IsPrimitiveCNode(node, prim::kPrimPyInterpret)) && !internal);
1887 if (internal) {
1888 args_context->has_interpret_internal = true;
1889 } else if (interpret_without_internal) {
1890 args_context->has_interpret_without_internal = true;
1891 }
1892 args_context->group_arguments.push_back(node);
1893 }
1894 }
1895 if (!args_context->group_arguments.empty()) {
1896 args_context->packed_arguments.push_back(GenerateMakeTuple(block, args_context->group_arguments));
1897 }
1898 }
1899
ParseKeywordsInCall(const FunctionBlockPtr & block,const py::object & node,ArgsContext * args_context)1900 void Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, ArgsContext *args_context) {
1901 MS_LOG(DEBUG) << "Process ast key words in call";
1902 py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
1903 if (!keywords.empty()) {
1904 MS_EXCEPTION_IF_NULL(block);
1905 args_context->need_unpack = true;
1906 std::vector<AnfNodePtr> keys;
1907 std::vector<AnfNodePtr> values;
1908 for (size_t index = 0; index < keywords.size(); index++) {
1909 auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
1910 auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
1911 if (py::isinstance<py::none>(kw_key)) {
1912 args_context->packed_arguments.push_back(ParseExprNode(block, kw_value));
1913 } else {
1914 auto kw_key_c = kw_key.cast<std::string>();
1915 keys.push_back(NewValueNode(kw_key_c));
1916 auto ret_node = ParseExprNode(block, kw_value);
1917 values.push_back(ret_node);
1918 }
1919 }
1920 if (!keys.empty()) {
1921 auto keys_tuple = GenerateMakeTuple(block, keys);
1922 auto values_tuple = GenerateMakeTuple(block, values);
1923 auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
1924 std::vector<AnfNodePtr> make_dict_nodes = {make_dict_op, keys_tuple, values_tuple};
1925 MS_EXCEPTION_IF_NULL(block->func_graph());
1926 args_context->packed_arguments.push_back(block->func_graph()->NewCNodeInOrder(std::move(make_dict_nodes)));
1927 }
1928 }
1929 }
1930
ProcessAttributeWithClassMember(const FunctionBlockPtr & block,const py::object & node) const1931 AnfNodePtr Parser::ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) const {
1932 MS_EXCEPTION_IF_NULL(block);
1933 std::string var_name = "self.";
1934 std::string attr_name = node.attr("attr").cast<std::string>();
1935 (void)var_name.append(attr_name);
1936 MS_LOG(DEBUG) << "var_name: " << var_name;
1937 auto attr_obj = ast()->obj().attr(attr_name.c_str());
1938 bool check_need_resolve = py::hasattr(ast()->obj(), attr_name.c_str()) &&
1939 (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
1940 py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
1941 py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj));
1942 if (check_need_resolve) {
1943 AnfNodePtr res = block->MakeResolveSymbol(var_name);
1944 block->CheckUndefinedSymbol(var_name, res);
1945 return res;
1946 }
1947 auto var_node = block->ReadVariable(var_name);
1948 block->CheckUndefinedSymbol(var_name, var_node);
1949 // Process numpy array, eg: self.x = np.array([1, 2])
1950 if (py::hasattr(ast()->obj(), attr_name.c_str()) && data_converter::IsNumpyArrayInstance(attr_obj)) {
1951 var_node->set_interpret(true);
1952 }
1953 return var_node;
1954 }
1955
ParseMsTensor(const FunctionBlockPtr & block,const py::object & node,const py::object & value_body,const AnfNodePtr & value_node)1956 AnfNodePtr Parser::ParseMsTensor(const FunctionBlockPtr &block, const py::object &node, const py::object &value_body,
1957 const AnfNodePtr &value_node) {
1958 if (py::hasattr(value_body, "id")) {
1959 std::string module_name = py::cast<std::string>(python_adapter::GetPyObjAttr(value_body, "id"));
1960 py::dict global_dict = const_cast<py::dict &>(block->global_py_params());
1961 if (global_dict.contains(module_name)) {
1962 py::object module_obj = global_dict[py::str(module_name)];
1963 std::string module_str = py::cast<std::string>(py::str(module_obj));
1964 // The module of Tensor imported from MsAdapter could be:
1965 // module 'msadapter' or module 'msadapter.pytorch' and so on.
1966 if (module_str.find("module 'mindspore'") != std::string::npos ||
1967 module_str.find("module 'mindtorch") != std::string::npos ||
1968 module_str.find("module 'msadapter") != std::string::npos) {
1969 std::string script_text = py::cast<std::string>(ast()->GetAstNodeText(node));
1970 AnfNodePtr interpret_node = MakeInterpretNode(block, value_node, script_text);
1971 interpret_node->set_interpret(true);
1972 interpret_node->set_interpret_internal_type(true);
1973 if ((module_str.find("module 'mindtorch") != std::string::npos ||
1974 module_str.find("module 'msadapter") != std::string::npos) &&
1975 py::hasattr(module_obj, "Tensor")) {
1976 py::object tensor_obj = py::getattr(module_obj, "Tensor");
1977 interpret_node->set_user_data<py::object>(kClassTensorObject, std::make_shared<py::object>(tensor_obj));
1978 }
1979 return interpret_node;
1980 }
1981 }
1982 }
1983 return nullptr;
1984 }
1985
ParseNull(const FunctionBlockPtr & block,const py::object & value_body) const1986 AnfNodePtr Parser::ParseNull(const FunctionBlockPtr &block, const py::object &value_body) const {
1987 if (py::hasattr(value_body, "id")) {
1988 std::string module_name = py::cast<std::string>(python_adapter::GetPyObjAttr(value_body, "id"));
1989 py::dict global_dict = const_cast<py::dict &>(block->global_py_params());
1990 if (global_dict.contains(module_name)) {
1991 py::object module_obj = global_dict[py::str(module_name)];
1992 std::string module_str = py::cast<std::string>(py::str(module_obj));
1993 if (module_str.find("module 'mindspore.common.dtype'") != std::string::npos) {
1994 return NewValueNode(std::make_shared<TypeNull>());
1995 }
1996 }
1997 }
1998 return nullptr;
1999 }
2000
GetGetAttrVectotFromMap(const std::string & obj_name,const std::string & attr_name)2001 std::vector<AnfNodePtr> Parser::GetGetAttrVectotFromMap(const std::string &obj_name, const std::string &attr_name) {
2002 std::vector<AnfNodePtr> getattr_nodes;
2003 auto iter = getattr_nodes_map_.find(obj_name);
2004 if (iter != getattr_nodes_map_.end()) {
2005 auto attr_iter = iter->second.find(attr_name);
2006 if (attr_iter != iter->second.end()) {
2007 getattr_nodes = attr_iter->second;
2008 }
2009 }
2010 return getattr_nodes;
2011 }
2012
GetSetAttrFromMap(const std::string & obj_name,const std::string & attr_name)2013 AnfNodePtr Parser::GetSetAttrFromMap(const std::string &obj_name, const std::string &attr_name) {
2014 auto iter = setattr_nodes_map_.find(obj_name);
2015 if (iter != setattr_nodes_map_.end()) {
2016 auto attr_iter = iter->second.find(attr_name);
2017 if (attr_iter != iter->second.end()) {
2018 return attr_iter->second;
2019 }
2020 }
2021 return nullptr;
2022 }
2023
MakeGetAttrWithInterpret(const std::string & obj_name,const std::string & attr_name,const py::object & getattr_obj,const FuncGraphPtr & cur_fg)2024 AnfNodePtr Parser::MakeGetAttrWithInterpret(const std::string &obj_name, const std::string &attr_name,
2025 const py::object &getattr_obj, const FuncGraphPtr &cur_fg) {
2026 AnfNodePtr setattr_node = GetSetAttrFromMap(obj_name, attr_name);
2027 AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
2028 AnfNodePtr attr_node = NewValueNode(attr_name);
2029 AnfNodePtr ret_node = nullptr;
2030 if (setattr_node != nullptr) {
2031 const auto &interpreted_obj = std::make_shared<InterpretedObject>(getattr_obj);
2032 AnfNodePtr value_node = NewValueNode(interpreted_obj);
2033 auto prev_setattr_fg = setattr_node->func_graph();
2034 MS_EXCEPTION_IF_NULL(prev_setattr_fg);
2035 if (prev_setattr_fg != cur_fg) {
2036 ret_node = cur_fg->NewCNodeInOrder({op_node, value_node, attr_node});
2037 } else {
2038 // Only add to new setattr node input if two nodes is in the same graph.
2039 ret_node = cur_fg->NewCNodeInOrder({op_node, value_node, attr_node, setattr_node});
2040 }
2041 ret_node->set_user_data<bool>(fallback::kObjectAttrChange, std::make_shared<bool>(true));
2042 }
2043 return ret_node;
2044 }
2045
// Turn an access to the attribute `attr_str` of `property_net_obj` into a call of its getter.
// The attribute is fetched from the object's class, its `fget` is parsed into a func graph,
// and a CNode calling that graph with `node` as the only argument is created in `fg`.
//
// fg: graph in which the call node is created.
// node: node passed as the single argument of the getter.
// property_net_obj: python object whose class holds the attribute.
// attr_str: attribute name.
// Returns the created call node. Raises when the attribute cannot be fetched from the class
// or when the getter cannot be parsed into a graph.
AnfNodePtr TransPropertyToFunc(const FuncGraphPtr &fg, const AnfNodePtr &node, const py::object &property_net_obj,
                               std::string attr_str) {
  MS_EXCEPTION_IF_NULL(fg);
  py::object property_func = py::none();
  try {
    property_func = property_net_obj.attr("__class__").attr(py::str(attr_str));
  } catch (const std::exception &e) {
    // Fail here with the real cause. Continuing would only fail later on `None.fget`
    // with a message unrelated to the missing attribute.
    MS_LOG(INTERNAL_EXCEPTION) << property_net_obj << " has no attribute " << attr_str << ". Detail: " << e.what();
  }
  py::object property_func_fget = property_func.attr(py::str("fget"));
  auto inner_fg = ParsePythonCode(property_func_fget);
  MS_EXCEPTION_IF_NULL(inner_fg);
  std::vector<AnfNodePtr> new_inputs = {NewValueNode(inner_fg)};
  new_inputs.push_back(node);
  AnfNodePtr call_func_node = fg->NewCNodeInOrder(new_inputs);
  MS_LOG(DEBUG) << "call_func_node:" << call_func_node->DebugString();
  return call_func_node;
}
2062
2063 // Process call attributes of class type define, eg: x.y()
ParseAttribute(const FunctionBlockPtr & block,const py::object & node)2064 AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
2065 MS_LOG(DEBUG) << "Process ast Attribute";
2066 auto cur_fg = block->func_graph();
2067 MS_EXCEPTION_IF_NULL(cur_fg);
2068
2069 // Process the get attr
2070 // Use the Primitive replace the operation resolve node (getattr),
2071 // because the getattr will eventually be converted to Primitive node
2072 AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
2073
2074 // Process the node attr
2075 auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
2076 AnfNodePtr attr_node = NewValueNode(attr_str);
2077
2078 // Process the attr body
2079 py::object value_body = python_adapter::GetPyObjAttr(node, "value");
2080 MS_LOG(DEBUG) << "node: " << node << ", attr: " << attr_str << ", value: " << value_body;
2081
2082 // if getting class value 'self', eg: self.xx, use self object.
2083 std::string obj_name;
2084 py::object getattr_obj;
2085 const bool &is_self = ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast()->IsClassMemberOfSelf(node);
2086 if (is_self) {
2087 // Check if the current Attribute is decorated by @property.
2088 py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
2089 bool is_property =
2090 (python_adapter::CallPyModFn(mod, PYTHON_PARSE_CHECK_ATTR_IS_PROPERTY, ast()->obj(), attr_str)).cast<bool>();
2091 if (is_property) {
2092 py::object property_net_obj = ast()->obj();
2093 AnfNodePtr value_node = ParseExprNode(block, value_body);
2094 return TransPropertyToFunc(cur_fg, value_node, property_net_obj, attr_str);
2095 }
2096 obj_name = "self";
2097 getattr_obj = ast()->obj();
2098 AnfNodePtr ret_node;
2099 AnfNodePtr getattr_node = MakeGetAttrWithInterpret(obj_name, attr_str, getattr_obj, cur_fg);
2100 // If setattr before, should make the getattr call into PyExecute also.
2101 if (getattr_node != nullptr) {
2102 ret_node = getattr_node;
2103 // if processing class value 'self', but did not find setattr before getattr, convert getattr later
2104 } else {
2105 ret_node = ProcessAttributeWithClassMember(block, node);
2106 (void)getattr_nodes_map_["self"][attr_str].emplace_back(ret_node);
2107 }
2108 ret_node->set_user_data<py::object>("__getattr__", std::make_shared<py::object>(getattr_obj));
2109 return ret_node;
2110 }
2111 // If not self.xx, process the obj, eg: obj.xx
2112 AnfNodePtr value_node = ParseExprNode(block, value_body);
2113 if (value_node == nullptr) {
2114 MS_LOG(INTERNAL_EXCEPTION) << "Parse attribute failed";
2115 }
2116 // Process xxx.Tensor() and xxx is mindspore.
2117 if (attr_str == "Tensor") {
2118 auto res = ParseMsTensor(block, node, value_body, value_node);
2119 if (res != nullptr) {
2120 return res;
2121 }
2122 }
2123 // For stype._null, return TypeNull value node directly.
2124 if (attr_str == "_null") {
2125 auto res = ParseNull(block, value_body);
2126 if (res != nullptr) {
2127 return res;
2128 }
2129 }
2130 // Create the apply node
2131 AnfNodePtr attr_cnode = cur_fg->NewCNodeInOrder({op_node, value_node, attr_node});
2132
2133 auto value_id_str = GetLocation(value_body)->expr_src();
2134 auto iter = setattr_nodes_map_.find(value_id_str);
2135 if (iter != setattr_nodes_map_.end() && iter->second.find(attr_str) != iter->second.end()) {
2136 attr_cnode->set_user_data<bool>(fallback::kObjectAttrChange, std::make_shared<bool>(true));
2137 }
2138
2139 // Directly resolve the symbol.
2140 if (IsValueNode<parse::NameSpace>(value_node)) {
2141 auto name_space = GetValueNode<parse::NameSpacePtr>(value_node);
2142 MS_EXCEPTION_IF_NULL(name_space);
2143 auto symbol = std::make_shared<parse::Symbol>(attr_str);
2144 attr_cnode = block->DoResolve(attr_cnode, name_space, symbol);
2145 }
2146
2147 if (attr_str == "pop") {
2148 list_pop_target_obj_ = value_body;
2149 }
2150 if (py::hasattr(value_body, "id")) {
2151 // Check the value is side effect operate from third-party module. eg: np.load(xx) or ts.save(xxx)
2152 auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(value_body, "id"));
2153 MS_LOG(DEBUG) << "The Name id is " << name_id;
2154 bool is_third_party_side_effect =
2155 ast_->CallParserObjMethod(PYTHON_PARSE_CHECK_THIRD_PARTY_LIBRARY_SIDE_EFFECT, name_id, attr_str).cast<bool>();
2156 if (is_third_party_side_effect) {
2157 auto pyexecute_node = fallback::ConvertCNodeToPyExecuteForPrim(attr_cnode->cast<CNodePtr>(), "getattr");
2158 MS_LOG(DEBUG) << "pyexecute_node:" << pyexecute_node->DebugString();
2159 return pyexecute_node;
2160 }
2161 }
2162 // if getting other object, eg: obj.xx, find object from namespace by name
2163 obj_name = GetLocation(value_body)->expr_src();
2164 try {
2165 py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, obj_name);
2166 constexpr size_t value_index = 2;
2167 getattr_obj = namespace_info[value_index];
2168 } catch (const std::exception &e) {
2169 MS_LOG(DEBUG) << obj_name << " is not supported in JIT Fallback. Original steps are processing instead.";
2170 getattr_obj = py::none();
2171 }
2172 const bool got_obj = !py::isinstance<py::none>(getattr_obj);
2173 if (got_obj) {
2174 AnfNodePtr getattr_node = MakeGetAttrWithInterpret(obj_name, attr_str, getattr_obj, cur_fg);
2175 // If setattr before, should make the getattr call into PyExecute also.
2176 if (getattr_node != nullptr) {
2177 attr_cnode = getattr_node;
2178 } else {
2179 // if getting attr from other obj, but did not find setattr before getattr, convert getattr later
2180 (void)getattr_nodes_map_[GetLocation(value_body)->expr_src()][attr_str].emplace_back(attr_cnode);
2181 }
2182 attr_cnode->set_user_data<py::object>("__getattr__", std::make_shared<py::object>(getattr_obj));
2183 }
2184 return attr_cnode;
2185 }
2186
2187 // Process comparison expression : a == b. a > b etc.
ParseCompare(const FunctionBlockPtr & block,const py::object & node)2188 AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) {
2189 MS_LOG(DEBUG) << "Process ast Compare";
2190
2191 py::list ops = python_adapter::GetPyObjAttr(node, "ops");
2192 py::object left = python_adapter::GetPyObjAttr(node, "left");
2193 py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
2194 if (ops.size() == 0) {
2195 MS_LOG(INTERNAL_EXCEPTION) << "Parse ast Compare failed, found no ops.";
2196 }
2197 if (comparators.size() == 0) {
2198 MS_LOG(INTERNAL_EXCEPTION) << "Parse ast Compare failed, found no comparators.";
2199 }
2200 if (ops.size() != comparators.size()) {
2201 MS_LOG(INTERNAL_EXCEPTION) << "Parse ast Compare failed, length of ops and comparators not equal, len of ops: "
2202 << ops.size() << " and length of comparators: " << comparators.size();
2203 }
2204
2205 auto first_left = left;
2206 auto first_right = comparators[0];
2207 auto first_op = ops[0];
2208 auto first_compare_node = ParseSingleCompare(block, first_left, first_right, first_op);
2209 if (ops.size() == 1) {
2210 // For single compare, such as x < y.
2211 return first_compare_node;
2212 }
2213
2214 // For multiple compare, such as x < y <= z,
2215 // convert it to x < y and y <= z.
2216 std::vector<AnfNodePtr> compare_nodes = {first_compare_node};
2217 for (size_t i = 1; i < ops.size(); ++i) {
2218 auto cur_left = comparators[i - 1];
2219 auto cur_right = comparators[i];
2220 auto cur_op = ops[i];
2221 auto cur_compare_node = ParseSingleCompare(block, cur_left, cur_right, cur_op);
2222 (void)compare_nodes.emplace_back(cur_compare_node);
2223 }
2224
2225 AnfNodePtr ret_node = compare_nodes[0];
2226 for (size_t i = 1; i < compare_nodes.size(); ++i) {
2227 ret_node = ConnectSingleCompare(block, ret_node, compare_nodes[i]);
2228 }
2229
2230 return ret_node;
2231 }
2232
ParseSingleCompare(const FunctionBlockPtr & block,const py::object & left,const py::object & right,const py::object & compare_op)2233 AnfNodePtr Parser::ParseSingleCompare(const FunctionBlockPtr &block, const py::object &left, const py::object &right,
2234 const py::object &compare_op) {
2235 MS_LOG(DEBUG) << "Process ast Compare with single comparators";
2236
2237 AnfNodePtr left_node = ParseExprNode(block, left);
2238 AnfNodePtr right_node = ParseExprNode(block, right);
2239
2240 MS_EXCEPTION_IF_NULL(block);
2241 const auto &ns = block->GetAstOpNameSpace(compare_op);
2242 auto op_node = block->MakeResolveAstOpNameSpace(ns);
2243
2244 MS_EXCEPTION_IF_NULL(block->func_graph());
2245 return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
2246 }
2247
// Connect two compare results with 'and'.
// Emits Switch(cond(left_node), true_graph, false_graph)() in the block's graph, where the
// true branch outputs right_node and the false branch outputs left_node.
AnfNodePtr Parser::ConnectSingleCompare(const FunctionBlockPtr &block, const AnfNodePtr &left_node,
                                        const AnfNodePtr &right_node) {
  MS_LOG(DEBUG) << "Connect single compare node.";

  MS_EXCEPTION_IF_NULL(left_node);
  MS_EXCEPTION_IF_NULL(right_node);
  auto block_fg = block->func_graph();
  MS_EXCEPTION_IF_NULL(block_fg);

  // Each branch block is created under its own trace guard.
  FunctionBlockPtr then_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block_fg->debug_info()));
    then_block = MakeFunctionBlock();
  }
  FunctionBlockPtr else_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block_fg->debug_info()));
    else_block = MakeFunctionBlock();
  }
  MakeConditionBlocks(block, then_block, else_block);
  MS_EXCEPTION_IF_NULL(then_block->func_graph());
  MS_EXCEPTION_IF_NULL(else_block->func_graph());
  then_block->func_graph()->set_output(right_node);

  // This guard stays alive until the function returns.
  TraceGuard trace_guard(std::make_shared<TraceCopy>(left_node->debug_info()));
  else_block->func_graph()->set_output(left_node);

  AnfNodePtr cond_node = block->ForceToCondNode(left_node);

  auto switch_prim = NewValueNode(prim::kPrimSwitch);
  auto then_graph = NewValueNode(then_block->func_graph());
  auto else_graph = NewValueNode(else_block->func_graph());
  auto switch_node = block_fg->NewCNodeInOrder({switch_prim, cond_node, then_graph, else_graph});

  // Calling the switch node with no arguments runs the selected branch graph.
  std::vector<AnfNodePtr> call_inputs{switch_node};
  return block_fg->NewCNodeInOrder(std::move(call_inputs));
}
2284
ProcessBoolOpValueList(const FunctionBlockPtr & block,const py::list & value_list,AstSubType mode)2285 AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
2286 // If there is only one bool op now
2287 MS_EXCEPTION_IF_NULL(block);
2288 if (value_list.empty()) {
2289 MS_LOG(INTERNAL_EXCEPTION) << "value list is empty.";
2290 }
2291 if (value_list.size() == 1) {
2292 AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
2293 return first_node;
2294 } else {
2295 py::object first = value_list[0];
2296 py::list rest;
2297 for (size_t i = 1; i < value_list.size(); i++) {
2298 rest.append(value_list[i]);
2299 }
2300 FunctionBlockPtr true_block = nullptr;
2301 FunctionBlockPtr false_block = nullptr;
2302 auto block_fg = block->func_graph();
2303 {
2304 TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block_fg->debug_info()));
2305 true_block = MakeFunctionBlock();
2306 }
2307 {
2308 TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block_fg->debug_info()));
2309 false_block = MakeFunctionBlock();
2310 }
2311 MakeConditionBlocks(block, true_block, false_block);
2312 FunctionBlockPtr b1;
2313 FunctionBlockPtr b2;
2314
2315 // If it is and, we need to process the rest nodes;
2316 // If it is or, we continue to next
2317 if (mode == AST_SUB_TYPE_AND) {
2318 b1 = true_block;
2319 b2 = false_block;
2320 } else if (mode == AST_SUB_TYPE_OR) {
2321 b2 = true_block;
2322 b1 = false_block;
2323 } else {
2324 MS_LOG(ERROR) << "Not supported mode: " << mode;
2325 return nullptr;
2326 }
2327 AnfNodePtr test_node = ParseExprNode(block, first);
2328 AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
2329 MS_EXCEPTION_IF_NULL(b1->func_graph());
2330 MS_EXCEPTION_IF_NULL(b2->func_graph());
2331 b1->func_graph()->set_output(rest_node);
2332 TraceGuard trace_guard(GetLocation(value_list[1]));
2333 b2->func_graph()->set_output(test_node);
2334
2335 AnfNodePtr cond_node = block->ForceToCondNode(test_node);
2336 auto switch_app =
2337 block_fg->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
2338 NewValueNode(false_block->func_graph())});
2339
2340 std::vector<AnfNodePtr> call_graph_nodes{switch_app};
2341 auto switch_app_call = block_fg->NewCNodeInOrder(std::move(call_graph_nodes));
2342 return switch_app_call;
2343 }
2344 }
2345
2346 // Process comparison expression : a and b. a or b .
ParseBoolOp(const FunctionBlockPtr & block,const py::object & node)2347 AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
2348 MS_LOG(DEBUG) << "Process ast BoolOp";
2349 py::object op_node = python_adapter::GetPyObjAttr(node, "op");
2350 AstSubType op_type = ast_->GetOpType(op_node);
2351 if (op_type == AST_SUB_TYPE_UNKNOWN) {
2352 MS_LOG(INTERNAL_EXCEPTION) << "ProcessBoolOp, got unknown op type";
2353 }
2354 py::list op_values = python_adapter::GetPyObjAttr(node, "values");
2355 return ProcessBoolOpValueList(block, op_values, op_type);
2356 }
2357
2358 // Process a function def
ParseFunctionDef(const FunctionBlockPtr & block,const py::object & node)2359 FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
2360 MS_LOG(DEBUG) << "Process ast FunctionDef";
2361 FunctionBlockPtr function_block = ParseDefFunction(node, block);
2362 MS_EXCEPTION_IF_NULL(function_block);
2363
2364 // Get function name
2365 py::str name = python_adapter::GetPyObjAttr(node, "name");
2366 std::string function_name = name;
2367 ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph());
2368 block->WriteVariable(function_name, valuenode_graph);
2369 return block;
2370 }
2371
2372 // Process a lambda expression . like lambda x,y: x + y
ParseLambda(const FunctionBlockPtr & block,const py::object & node)2373 AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
2374 MS_LOG(DEBUG) << "Process ast Lambda";
2375 FunctionBlockPtr function_block = ParseLambdaFunction(node, block);
2376 MS_EXCEPTION_IF_NULL(function_block);
2377
2378 auto block_fg = function_block->func_graph();
2379 ValueNodePtr const_graph = NewValueNode(block_fg);
2380 return const_graph;
2381 }
2382
2383 // a = *[1, 2], (3, 4)
2384 // StarredUnpackMerge(assign_node1, assign_node2, starred_flags_node, is_tuple)
2385 // StarredUnpackMerge(StarredUnpack(*[1, 2]), (3, 4), (1, 0), 1)
2386 // --> StarredUnpackMerge((1, 2), (3, 4), (1, 0), 1)
2387 // --> (1, 2, (3, 4))
ParseTupleOrListWithStarred(const FunctionBlockPtr & block,const py::object & node,bool is_tuple,const std::vector<AnfNodePtr> & starred_flags)2388 AnfNodePtr Parser::ParseTupleOrListWithStarred(const FunctionBlockPtr &block, const py::object &node, bool is_tuple,
2389 const std::vector<AnfNodePtr> &starred_flags) {
2390 auto prim = std::make_shared<prim::StarredUnpackMerge>(NAMED_METAGRAPH_STARRED_UNPACK_MERGE);
2391 std::vector<AnfNodePtr> unpack_merge_inputs{NewValueNode(prim)};
2392 auto starred_flags_node = block->func_graph()->NewCNodeInOrder(starred_flags);
2393 py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
2394 for (size_t i = 0; i < elts.size(); i++) {
2395 AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
2396 auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elts[i])));
2397 if (elt_type == AST_SUB_TYPE_STARRED) {
2398 auto starred_unpack_prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
2399 CNodePtr unpack_node = block->func_graph()->NewCNodeInOrder({NewValueNode(starred_unpack_prim), node_ptr});
2400 (void)unpack_merge_inputs.emplace_back(unpack_node);
2401 } else {
2402 (void)unpack_merge_inputs.emplace_back(node_ptr);
2403 }
2404 }
2405 (void)unpack_merge_inputs.emplace_back(starred_flags_node);
2406 if (is_tuple) {
2407 auto is_tuple_node = NewValueNode(static_cast<int64_t>(1));
2408 (void)unpack_merge_inputs.emplace_back(is_tuple_node);
2409 } else {
2410 auto is_tuple_node = NewValueNode(static_cast<int64_t>(0));
2411 (void)unpack_merge_inputs.emplace_back(is_tuple_node);
2412 }
2413
2414 CNodePtr unpack_merge_node = block->func_graph()->NewCNodeInOrder(unpack_merge_inputs);
2415 return unpack_merge_node;
2416 }
2417
ParseTupleOrList(const FunctionBlockPtr & block,const py::object & node,bool is_tuple)2418 AnfNodePtr Parser::ParseTupleOrList(const FunctionBlockPtr &block, const py::object &node, bool is_tuple) {
2419 MS_EXCEPTION_IF_NULL(block);
2420 py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
2421 if (elts.empty()) {
2422 if (is_tuple) {
2423 auto empty_tuple = std::vector<ValuePtr>();
2424 return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
2425 }
2426 auto empty_list = std::vector<ValuePtr>();
2427 return NewValueNode(std::make_shared<ValueList>(empty_list));
2428 }
2429
2430 bool exist_starred_expression = false;
2431 std::vector<AnfNodePtr> starred_flags{NewValueNode(prim::kPrimMakeTuple)};
2432 for (size_t i = 0; i < elts.size(); i++) {
2433 auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elts[i])));
2434 if (elt_type == AST_SUB_TYPE_STARRED) {
2435 exist_starred_expression = true;
2436 starred_flags.push_back(NewValueNode(static_cast<int64_t>(1)));
2437 } else {
2438 starred_flags.push_back(NewValueNode(static_cast<int64_t>(0)));
2439 }
2440 }
2441
2442 if (!exist_starred_expression) {
2443 std::vector<AnfNodePtr> sequence_vec;
2444 AnfNodePtr sequence_op = nullptr;
2445 if (is_tuple) {
2446 sequence_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
2447 } else {
2448 sequence_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
2449 }
2450 (void)sequence_vec.emplace_back(sequence_op);
2451 for (size_t i = 0; i < elts.size(); i++) {
2452 AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
2453 (void)sequence_vec.emplace_back(node_ptr);
2454 }
2455 MS_EXCEPTION_IF_NULL(block->func_graph());
2456 CNodePtr sequence_app = block->func_graph()->NewCNodeInOrder(std::move(sequence_vec));
2457 return sequence_app;
2458 }
2459 return ParseTupleOrListWithStarred(block, node, is_tuple, starred_flags);
2460 }
2461
2462 // Process a tuple
ParseTuple(const FunctionBlockPtr & block,const py::object & node)2463 AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
2464 MS_LOG(DEBUG) << "Process ast Tuple";
2465 return ParseTupleOrList(block, node, true);
2466 }
2467
2468 // Process a list
ParseList(const FunctionBlockPtr & block,const py::object & node)2469 AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
2470 MS_LOG(DEBUG) << "Process ast List";
2471 return ParseTupleOrList(block, node, false);
2472 }
2473
// Look up the name held by value_node (cast to a string) through the parser object's
// namespace-symbol method and return the element at index 2 of the resulting tuple.
// Returns py::none() when that element is None.
py::object Parser::GetValuePythonObject(const py::object &value_node) {
  const auto symbol_name = py::cast<std::string>(value_node);
  py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, symbol_name);
  // Handle nested function def.
  constexpr size_t value_index = 2;
  py::object symbol_value = namespace_info[value_index];
  if (py::isinstance<py::none>(symbol_value)) {
    return py::none();
  }
  return symbol_value;
}
2485
2486 // Process a subscript, such as x[y] , node expressed as value[slice]
ParseSubscript(const FunctionBlockPtr & block,const py::object & node)2487 AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) {
2488 MS_LOG(DEBUG) << "Process ast Subscript";
2489 MS_EXCEPTION_IF_NULL(block);
2490 AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
2491 py::object value_node = python_adapter::GetPyObjAttr(node, "value");
2492 py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
2493 AnfNodePtr value = ParseExprNode(block, value_node);
2494 AnfNodePtr slice = ParseExprNode(block, slice_node);
2495 MS_EXCEPTION_IF_NULL(block->func_graph());
2496 auto value_id = python_adapter::GetPyObjAttr(value_node, "id");
2497 AnfNodePtr getitem_node;
2498 py::object value_obj = py::none();
2499 auto str_getitem = std::make_shared<StringImm>("__getitem__");
2500 AnfNodePtr new_node;
2501 // value[slice]. The value of subscript must not be built-in functions
2502 // if the id of value object has the same name as built-in function, should not get the value_obj.
2503 bool value_id_is_builtins =
2504 py::cast<bool>(ast_->CallParserObjMethod(PYTHON_PARSE_IS_BUILTIN_FUNCTION_NAME, py::str(value_id)));
2505 if (!py::isinstance<py::none>(value_id) && !value_id_is_builtins) {
2506 value_obj = GetValuePythonObject(value_id);
2507 }
2508 bool is_adapter = false;
2509 if (!py::isinstance<py::none>(value_obj)) {
2510 if (py::hasattr(value_obj, "adapter_flag")) {
2511 is_adapter = py::cast<bool>(py::getattr(value_obj, "adapter_flag"));
2512 }
2513 }
2514 if (!py::isinstance<py::none>(value_obj) && !is_adapter) {
2515 getitem_node =
2516 block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), value, NewValueNode(str_getitem)});
2517 new_node = block->func_graph()->NewCNodeInOrder({getitem_node, slice});
2518 getitem_node->set_user_data<py::object>("__getitem__", std::make_shared<py::object>(value_obj));
2519 } else {
2520 new_node = block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
2521 }
2522
2523 return new_node;
2524 }
2525
2526 // Process a slice, get the slice value
ParseSlice(const FunctionBlockPtr & block,const py::object & node)2527 AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) {
2528 MS_LOG(DEBUG) << "Process ast Slice";
2529 MS_EXCEPTION_IF_NULL(block);
2530 AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE);
2531 py::object start = python_adapter::GetPyObjAttr(node, "lower");
2532 py::object stop = python_adapter::GetPyObjAttr(node, "upper");
2533 py::object step = python_adapter::GetPyObjAttr(node, "step");
2534 AnfNodePtr start_node = ParseExprNode(block, start);
2535 AnfNodePtr stop_node = ParseExprNode(block, stop);
2536 AnfNodePtr step_node = ParseExprNode(block, step);
2537 MS_EXCEPTION_IF_NULL(block->func_graph());
2538 return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node});
2539 }
2540
2541 // Process a extslice
ParseExtSlice(const FunctionBlockPtr & block,const py::object & node)2542 AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) {
2543 MS_LOG(DEBUG) << "Process ast ExtSlice";
2544 MS_EXCEPTION_IF_NULL(block);
2545 AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
2546 py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims");
2547
2548 std::vector<AnfNodePtr> node_vec;
2549 (void)node_vec.emplace_back(make_tuple_op);
2550 for (size_t i = 0; i < slice_tuple.size(); i++) {
2551 AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]);
2552 (void)node_vec.emplace_back(node_ptr);
2553 }
2554 MS_EXCEPTION_IF_NULL(block->func_graph());
2555 CNodePtr tuple_conde = block->func_graph()->NewCNodeInOrder(std::move(node_vec));
2556 return tuple_conde;
2557 }
2558
2559 // Process a index, get the index number
ParseIndex(const FunctionBlockPtr & block,const py::object & node)2560 AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) {
2561 MS_LOG(DEBUG) << "Process ast Index";
2562 py::object value_node = python_adapter::GetPyObjAttr(node, "value");
2563 return ParseExprNode(block, value_node);
2564 }
2565
2566 // Process a UnaryOp, +a, -b
ParseUnaryOp(const FunctionBlockPtr & block,const py::object & node)2567 AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
2568 MS_LOG(DEBUG) << "Process ast UnaryOp";
2569 py::object op = python_adapter::GetPyObjAttr(node, "op");
2570
2571 MS_EXCEPTION_IF_NULL(block);
2572 // Resolve the op
2573 const auto &ns = block->GetAstOpNameSpace(op);
2574 auto op_node = block->MakeResolveAstOpNameSpace(ns);
2575
2576 py::object operand = python_adapter::GetPyObjAttr(node, "operand");
2577 AnfNodePtr operand_node = ParseExprNode(block, operand);
2578 MS_EXCEPTION_IF_NULL(block->func_graph());
2579 return block->func_graph()->NewCNodeInOrder({op_node, operand_node});
2580 }
2581
2582 // Process a dict ast node expression
ParseDictByKeysAndValues(const FunctionBlockPtr & block,const std::vector<AnfNodePtr> & key_nodes,const std::vector<AnfNodePtr> & value_nodes)2583 AnfNodePtr Parser::ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
2584 const std::vector<AnfNodePtr> &value_nodes) {
2585 auto keys_tuple = GenerateMakeTuple(block, key_nodes);
2586 auto values_tuple = GenerateMakeTuple(block, value_nodes);
2587 MS_EXCEPTION_IF_NULL(block);
2588 auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
2589 MS_EXCEPTION_IF_NULL(block->func_graph());
2590 return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
2591 }
2592
GetRealKeysValuesFromName(const FunctionBlockPtr & block,const py::object & node)2593 std::pair<AnfNodePtr, AnfNodePtr> Parser::GetRealKeysValuesFromName(const FunctionBlockPtr &block,
2594 const py::object &node) {
2595 MS_EXCEPTION_IF_NULL(block);
2596 auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
2597 AnfNodePtr dict = block->ReadVariable(name_id);
2598 auto keys = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimDictGetKeys), dict});
2599 auto values = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimDictGetValues), dict});
2600 // Using the MakeTuple node, pass the need_unpack tag from the AnfNode to the abstract
2601 auto tuple_keys = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), keys});
2602 auto tuple_values = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), values});
2603 constexpr auto need_unpack = "need_unpack";
2604 tuple_keys->set_user_data<bool>(need_unpack, std::make_shared<bool>(true));
2605 tuple_values->set_user_data<bool>(need_unpack, std::make_shared<bool>(true));
2606 return {tuple_keys, tuple_values};
2607 }
2608
GetRealKeysValues(const FunctionBlockPtr & block,const py::object & node)2609 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> Parser::GetRealKeysValues(const FunctionBlockPtr &block,
2610 const py::object &node) {
2611 py::list keys = node.attr("keys");
2612 py::list values = node.attr("values");
2613 if (keys.size() != values.size()) {
2614 MS_LOG(INTERNAL_EXCEPTION) << "The keys' size is not equal to the values' size.";
2615 }
2616 std::vector<AnfNodePtr> inner_key_nodes;
2617 std::vector<AnfNodePtr> inner_value_nodes;
2618 for (size_t index = 0; index < keys.size(); ++index) {
2619 auto inner_key_node_type = ast_->GetNodeType(keys[index]);
2620 const std::string &inner_key_node_type_name = inner_key_node_type->node_name();
2621 // The key does not exist, mean the value is a dict which need unpack.
2622 if (inner_key_node_type_name == "NoneType") {
2623 auto unpack_dict = values[index];
2624 auto inner_value_node_type =
2625 AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, unpack_dict)));
2626 if (inner_value_node_type == AST_SUB_TYPE_DICT) {
2627 auto [unpack_keys, unpack_values] = GetRealKeysValues(block, unpack_dict);
2628 for (size_t i = 0; i < unpack_keys.size(); ++i) {
2629 inner_key_nodes.push_back(unpack_keys[i]);
2630 inner_value_nodes.push_back(unpack_values[i]);
2631 }
2632 } else if (inner_value_node_type == AST_SUB_TYPE_NAME) {
2633 auto [unpack_key, unpack_value] = GetRealKeysValuesFromName(block, unpack_dict);
2634 inner_key_nodes.push_back(unpack_key);
2635 inner_value_nodes.push_back(unpack_value);
2636 } else {
2637 MS_LOG(INTERNAL_EXCEPTION) << "The input of dict which need unpack must be dict, but got "
2638 << inner_value_node_type;
2639 }
2640 } else {
2641 AnfNodePtr key_node = ParseExprNode(block, keys[index]);
2642 inner_key_nodes.push_back(key_node);
2643 AnfNodePtr value_node = ParseExprNode(block, values[index]);
2644 inner_value_nodes.push_back(value_node);
2645 }
2646 }
2647 return {inner_key_nodes, inner_value_nodes};
2648 }
2649
ParseDict(const FunctionBlockPtr & block,const py::object & node)2650 AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
2651 MS_LOG(DEBUG) << "Process ast Dict";
2652 auto [key_nodes, value_nodes] = GetRealKeysValues(block, node);
2653 return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
2654 }
2655
2656 // Process a augment assign such as a += b or mat[stride_slice] += b.
ParseAugAssign(const FunctionBlockPtr & block,const py::object & node)2657 FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
2658 MS_LOG(DEBUG) << "Process ast AugAssign";
2659 MS_EXCEPTION_IF_NULL(block);
2660 MS_EXCEPTION_IF_NULL(ast_);
2661
2662 py::object target_object = python_adapter::GetPyObjAttr(node, "target");
2663 py::object op_object = python_adapter::GetPyObjAttr(node, "op");
2664 py::object value_object = python_adapter::GetPyObjAttr(node, "value");
2665 AnfNodePtr target_node = nullptr;
2666
2667 const auto &ns = block->GetAstOpNameSpace(op_object);
2668 auto op_node = block->MakeResolveAstOpNameSpace(ns);
2669
2670 AnfNodePtr value_node = ParseExprNode(block, value_object);
2671 {
2672 TraceGuard trace_guard(GetLocation(target_object));
2673 target_node = ParseExprNode(block, target_object);
2674 }
2675
2676 if (target_node == nullptr) {
2677 MS_LOG(INTERNAL_EXCEPTION) << "Can not get target node ";
2678 }
2679 MS_EXCEPTION_IF_NULL(block->func_graph());
2680 AnfNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
2681
2682 // b += list_x.pop(a)
2683 // --> list_x, b = list_x, b + list_x.pop(a) need renew the list_x.
2684 if (IsPopOperation(value_node)) {
2685 ProcessPopOperationInAugAssign(block, value_node, target_node, op_node, target_object);
2686 return block;
2687 }
2688
2689 WriteAssignVars(block, target_object, augassign_app);
2690 return block;
2691 }
2692
2693 // Process global declaration such as 'global x';
ParseGlobal(const FunctionBlockPtr & block,const py::object & node)2694 FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
2695 MS_LOG(DEBUG) << "Process ast Global";
2696 MS_EXCEPTION_IF_NULL(block);
2697 py::list vars = python_adapter::GetPyObjAttr(node, "names");
2698 for (auto &item : vars) {
2699 block->AddGlobalVar(py::cast<std::string>(item));
2700 }
2701 return block;
2702 }
2703
CheckControlFlowAlterationInIf(std::pair<FunctionBlockPtr,FunctionBlockPtr> * branch_graphs_pair,const FunctionBlockPtr & branch_block,const FunctionBlockPtr & branch_end,const FunctionBlockPtr & after_block,const FunctionBlockPtr & block) const2704 void Parser::CheckControlFlowAlterationInIf(std::pair<FunctionBlockPtr, FunctionBlockPtr> *branch_graphs_pair,
2705 const FunctionBlockPtr &branch_block, const FunctionBlockPtr &branch_end,
2706 const FunctionBlockPtr &after_block, const FunctionBlockPtr &block) const {
2707 if (branch_block->is_return_statement_inside()) {
2708 MS_LOG(DEBUG)
2709 << "Inside the branch block has return statement, ignore for transformation to parallel-if call, branch block:"
2710 << branch_block->ToString() << ", block: " << block->ToString();
2711 block->set_is_return_statement_inside();
2712 return;
2713 }
2714 if (branch_block->is_break_continue_statement_inside()) {
2715 MS_LOG(DEBUG) << "Inside the branch block has break or continue statement, ignore for transformation to "
2716 "parallel-if call, branch block: "
2717 << branch_block->ToString() << ", branch end: " << branch_end->ToString()
2718 << ", block: " << block->ToString();
2719 MS_LOG(DEBUG) << "Propagate flag of break or continue statement from branch block to block, branch block:"
2720 << branch_block->ToString() << ", block: " << block->ToString();
2721 block->set_break_continue_statement_inside();
2722 } else if (branch_end->func_graph()->get_return() != nullptr) {
2723 // Currently, this can only happen with raise statement inside. As try/expect is not supported now,
2724 // and contional for raise will be evaluated in Compile time. If raise condition is met, it will
2725 // cause compile fail, so no need to propagate the flag back.
2726 MS_LOG(DEBUG) << "Ignore the block as branch_end will not call after_block, branch_block: "
2727 << branch_block->ToString() << ", branch_end: " << branch_end->ToString()
2728 << ", after_block: " << after_block->ToString();
2729 } else {
2730 branch_graphs_pair->second = branch_end;
2731 }
2732 }
2733
2734 // Check constant bool constant attr, such as:
2735 // if self.has_bias
CheckAttributeConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2736 bool Parser::CheckAttributeConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2737 bool *is_true_cond) const {
2738 bool is_constant;
2739 auto attr_cond = GetPyObjForAstAttr(block, test_node, &is_constant);
2740 if (!is_constant) {
2741 return false;
2742 }
2743 if (!py::isinstance<py::bool_>(attr_cond)) {
2744 return false;
2745 }
2746 *is_true_cond = py::cast<bool>(attr_cond);
2747 return true;
2748 }
2749
2750 // Check constant local var, such as:
2751 // if has_bias
CheckNameConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2752 bool Parser::CheckNameConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2753 bool *is_true_cond) const {
2754 auto id = python_adapter::GetPyObjAttr(test_node, "id");
2755 if (!py::isinstance<py::str>(id)) {
2756 return false;
2757 }
2758 auto anf_node = block->ReadVariable(id.cast<std::string>());
2759 if (anf_node == nullptr) {
2760 return false;
2761 }
2762 MS_LOG(DEBUG) << "CheckNameConstantCond anf_node: " << anf_node->DebugString();
2763 ValuePtr value = nullptr;
2764 if (anf_node->isa<ValueNode>()) {
2765 value = anf_node->cast<ValueNodePtr>()->value();
2766 } else if (anf_node->isa<Parameter>()) {
2767 value = GetParameterValue(anf_node);
2768 if (value == nullptr || value->ContainsValueAny()) {
2769 return false;
2770 }
2771 MS_LOG(DEBUG) << "Found constant value: " << value->ToString() << " for anf_node: " << anf_node;
2772 }
2773 if (value == nullptr || !value->isa<BoolImm>()) {
2774 return false;
2775 }
2776 MS_LOG(DEBUG) << "CheckNameConstantCond value: " << value->ToString();
2777 *is_true_cond = GetValue<bool>(value);
2778 return true;
2779 }
2780
2781 // Check constant unary op result, such as:
2782 // if not self.has_bias
CheckUnaryOpConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2783 bool Parser::CheckUnaryOpConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2784 bool *is_true_cond) const {
2785 auto op = python_adapter::GetPyObjAttr(test_node, "op");
2786 auto op_node_type = ast()->GetNodeType(op);
2787 const auto &op_node_type_name = op_node_type->node_name();
2788 MS_LOG(DEBUG) << "op_node_type_name: " << op_node_type_name;
2789 if (op_node_type_name != "Not") {
2790 return false;
2791 }
2792 auto operand = python_adapter::GetPyObjAttr(test_node, "operand");
2793 auto check_constant_cond = CheckConstantCondition(block, operand, is_true_cond);
2794 if (!check_constant_cond) {
2795 return false;
2796 }
2797 *is_true_cond = !(*is_true_cond);
2798 return true;
2799 }
2800
2801 // Check constant compare result, such as:
2802 // if self.has_bias is None
CheckCompareConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2803 bool Parser::CheckCompareConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2804 bool *is_true_cond) const {
2805 return GetBoolObjForAstCompare(block, test_node, is_true_cond);
2806 }
2807
2808 // Check constant bool op result, such as:
2809 // if self.has_bias is None and self.beta == 1
CheckBoolOpConstantCond(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond) const2810 bool Parser::CheckBoolOpConstantCond(const FunctionBlockPtr &block, const py::object &test_node,
2811 bool *is_true_cond) const {
2812 auto op = python_adapter::GetPyObjAttr(test_node, "op");
2813 auto op_node_type = ast()->GetNodeType(op);
2814 const auto &op_node_type_name = op_node_type->node_name();
2815 MS_LOG(DEBUG) << "op_node_type_name: " << op_node_type_name;
2816 py::list values = python_adapter::GetPyObjAttr(test_node, "values");
2817 bool determined = false;
2818 for (size_t i = 0; i < values.size(); ++i) {
2819 bool sub_is_true_cond;
2820 auto check_constant_cond = CheckConstantCondition(block, values[i], &sub_is_true_cond);
2821 if (!check_constant_cond) {
2822 return false;
2823 }
2824 if (op_node_type_name == "Or" && sub_is_true_cond) {
2825 determined = true;
2826 break;
2827 } else if (op_node_type_name == "And" && !sub_is_true_cond) {
2828 determined = true;
2829 break;
2830 }
2831 }
2832 if (op_node_type_name == "Or") {
2833 *is_true_cond = determined;
2834 } else if (op_node_type_name == "And") {
2835 *is_true_cond = !determined;
2836 }
2837 return true;
2838 }
2839
GetConstantConditionFromComment(const FunctionBlockPtr & block,const py::object & if_node,bool * is_true_cond) const2840 bool Parser::GetConstantConditionFromComment(const FunctionBlockPtr &block, const py::object &if_node,
2841 bool *is_true_cond) const {
2842 auto location = GetLocation(if_node);
2843 MS_EXCEPTION_IF_NULL(location);
2844 const auto &comments = location->comments();
2845 if (comments.empty()) {
2846 return false;
2847 }
2848 const auto &comment = comments.back();
2849 MS_LOG(DEBUG) << "The comment of if statement: " << comment << ", block: " << block->ToString();
2850 std::regex regex("^#\\s*@jit.cond:\\s*([A-Za-z]+)$");
2851 std::smatch matched_results;
2852 if (!std::regex_match(comment, matched_results, regex)) {
2853 return false;
2854 }
2855 constexpr auto container_match_count = 2;
2856 if (matched_results.size() != container_match_count) {
2857 return false;
2858 }
2859 const auto &cond_str = matched_results[1].str();
2860 MS_LOG(DEBUG) << "The cond string of comment is " << cond_str;
2861 if (cond_str != "True" && cond_str != "False") {
2862 return false;
2863 }
2864 *is_true_cond = (cond_str == "True");
2865 return true;
2866 }
2867
2868 // Return true if it's constant condition and the condition value returned by is_true_cond, otherwise return false.
CheckConstantCondition(const FunctionBlockPtr & block,const py::object & test_node,bool * is_true_cond,const py::object & if_node) const2869 bool Parser::CheckConstantCondition(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond,
2870 const py::object &if_node) const {
2871 static const auto boost_parse = common::GetCompileConfig("BOOST_PARSE");
2872 if (boost_parse == "0") {
2873 return false;
2874 }
2875 MS_EXCEPTION_IF_NULL(block);
2876 MS_EXCEPTION_IF_NULL(is_true_cond);
2877 // Try to get the constant condition from the comment "@jit.cond: True/False".
2878 if (if_node != py::none() && GetConstantConditionFromComment(block, if_node, is_true_cond)) {
2879 return true;
2880 }
2881 auto node_type = ast()->GetNodeType(test_node);
2882 const std::string &node_type_name = node_type->node_name();
2883 MS_LOG(DEBUG) << "node_type_name: " << node_type_name;
2884
2885 auto func_iter = condition_method_map_.find(node_type_name);
2886 if (func_iter == condition_method_map_.end()) {
2887 return false;
2888 }
2889 auto check_constant = (this->*(func_iter->second))(block, test_node, is_true_cond);
2890 if (check_constant) {
2891 MS_LOG(DEBUG) << "Has constant condition, is_true_cond: " << *is_true_cond;
2892 return true;
2893 }
2894 MS_LOG(DEBUG) << "Has no constant condition, test_node: " << py::str(test_node);
2895 return false;
2896 }
2897
2898 // Process a if statement
ParseIf(const FunctionBlockPtr & block,const py::object & node)2899 FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) {
2900 MS_LOG(DEBUG) << "Process ast If";
2901 MS_EXCEPTION_IF_NULL(block);
2902 py::object test_node = python_adapter::GetPyObjAttr(node, "test");
2903 bool is_true_cond = false;
2904 bool is_bool_const_cond = CheckConstantCondition(block, test_node, &is_true_cond, node);
2905
2906 // Make condition node.
2907 AnfNodePtr bool_node = nullptr;
2908 if (!is_bool_const_cond) {
2909 AnfNodePtr condition_node = ParseExprNode(block, test_node);
2910 bool_node = block->ForceToCondNode(condition_node);
2911 }
2912
2913 FunctionBlockPtr true_block = nullptr;
2914 FunctionBlockPtr false_block = nullptr;
2915 FunctionBlockPtr after_block = nullptr;
2916 auto block_fg = block->func_graph();
2917 MS_EXCEPTION_IF_NULL(block_fg);
2918 if (!is_bool_const_cond || is_true_cond) {
2919 TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block_fg->debug_info()));
2920 true_block = MakeFunctionBlock();
2921 MS_LOG(DEBUG) << "Make true branch, " << true_block->ToString();
2922 }
2923 if (!is_bool_const_cond || !is_true_cond) {
2924 TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block_fg->debug_info()));
2925 false_block = MakeFunctionBlock();
2926 MS_LOG(DEBUG) << "Make false branch, " << false_block->ToString();
2927 }
2928
2929 if (!is_bool_const_cond) {
2930 MakeConditionBlocks(block, true_block, false_block);
2931 } else if (is_true_cond) {
2932 MS_LOG(DEBUG) << "Connect true branch, " << true_block->ToString();
2933 block->Jump(true_block, {});
2934 true_block->Mature();
2935 true_block->UpdateGlobalPyParam(block->global_py_params());
2936 } else { // !is_true_cond
2937 MS_LOG(DEBUG) << "Connect false branch, " << false_block->ToString();
2938 block->Jump(false_block, {});
2939 false_block->Mature();
2940 false_block->UpdateGlobalPyParam(block->global_py_params());
2941 }
2942
2943 {
2944 TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block_fg->debug_info()));
2945 after_block = MakeFunctionBlock();
2946 }
2947
2948 if (!is_bool_const_cond && MsContext::GetInstance()->backend_policy() != "ge") {
2949 // For backends excludes 'ge', it can handle multi graph call, use this flag to
2950 // generate call not inline `after_block` graph to reduce if by if switch expansion.
2951 MS_EXCEPTION_IF_NULL(after_block->func_graph());
2952 after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
2953 }
2954
2955 // Process the if-true branch
2956 std::pair<FunctionBlockPtr, FunctionBlockPtr> true_branch_graphs;
2957 if (!is_bool_const_cond || is_true_cond) {
2958 py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
2959 FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
2960 std::string true_branch_name = "true branch";
2961 true_end->set_block_name(true_branch_name);
2962 MS_EXCEPTION_IF_NULL(true_end->func_graph());
2963 CheckControlFlowAlterationInIf(&true_branch_graphs, true_block, true_end, after_block, block);
2964 // If the return_ is set, it has its own continuation block
2965 if (true_end->func_graph()->get_return() == nullptr) {
2966 TraceGuard trace_guard_true(GetLocation(bodyNode));
2967 true_end->Jump(after_block, {});
2968 MS_LOG(DEBUG) << "The true_end block jump to after, true_block: " << true_block->ToString()
2969 << ", true_end: " << true_end->ToString() << ", after: " << after_block->ToString();
2970 after_block->UpdateGlobalPyParam(true_end->global_py_params());
2971 }
2972 }
2973
2974 // Process the orelse branch
2975 std::pair<FunctionBlockPtr, FunctionBlockPtr> false_branch_graphs;
2976 if (!is_bool_const_cond || !is_true_cond) {
2977 py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
2978 FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
2979 std::string false_branch_name = "false branch";
2980 false_end->set_block_name(false_branch_name);
2981 MS_EXCEPTION_IF_NULL(false_end->func_graph());
2982 CheckControlFlowAlterationInIf(&false_branch_graphs, false_block, false_end, after_block, block);
2983 // If the return_ is set, it has its own continuation block
2984 if (false_end->func_graph()->get_return() == nullptr) {
2985 if (py::len_hint(orelseNode) != 0) {
2986 TraceGuard trace_guard_false(GetLocation(orelseNode));
2987 false_end->Jump(after_block, {});
2988 } else {
2989 false_end->Jump(after_block, {});
2990 }
2991 MS_LOG(DEBUG) << "The false_end block jump to after, false_block: " << false_block->ToString()
2992 << ", false_end: " << false_end->ToString() << ", after: " << after_block->ToString();
2993 after_block->UpdateGlobalPyParam(false_end->global_py_params());
2994 }
2995 }
2996
2997 if (!is_bool_const_cond) {
2998 MS_EXCEPTION_IF_NULL(bool_node);
2999 auto switch_app = block->ConditionalJump(bool_node, true_block, false_block);
3000
3001 // Record the former, middle, latter graphs info.
3002 static const auto transform_tail_call_to_parallel_call = (common::GetCompileConfig("IF_PARALLEL_CALL") != "0");
3003 if (transform_tail_call_to_parallel_call && true_branch_graphs.second != nullptr &&
3004 false_branch_graphs.second != nullptr) {
3005 true_branch_graphs.first = block;
3006 false_branch_graphs.first = block;
3007 MS_LOG(DEBUG) << "Record tail call {former: " << block->func_graph()->ToString()
3008 << ", true middle: " << true_branch_graphs.second->func_graph()->ToString()
3009 << ", false middle: " << false_branch_graphs.second->func_graph()->ToString() << "}";
3010 std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> branch_graphs_vec{true_branch_graphs,
3011 false_branch_graphs};
3012 (void)parallel_call_graphs_.emplace_back(branch_graphs_vec);
3013 }
3014
3015 static const auto transform_for_half_unroll_call = (common::GetCompileConfig("FOR_HALF_UNROLL") == "1");
3016 if (transform_for_half_unroll_call) {
3017 // Lift the if branches in for statement.
3018 (void)if_branch_calls_.emplace_back(std::make_tuple(switch_app, true_block, false_block));
3019 }
3020 }
3021
3022 if (after_block->prev_blocks().empty()) {
3023 MS_LOG(DEBUG) << "After block's previous block is null";
3024 after_block->SetAsDeadBlock();
3025 }
3026 after_block->Mature();
3027 return after_block;
3028 }
3029
CheckReturnInLoop(const FunctionBlockPtr & block,const FunctionBlockPtr & body_block) const3030 void Parser::CheckReturnInLoop(const FunctionBlockPtr &block, const FunctionBlockPtr &body_block) const {
3031 MS_EXCEPTION_IF_NULL(block);
3032 MS_EXCEPTION_IF_NULL(body_block);
3033 // Propagate flag of return statement in body_block back.
3034 if (body_block->is_return_statement_inside()) {
3035 MS_LOG(DEBUG) << "Propagate flag of return statement in body_block back, body_block: " << body_block->ToString()
3036 << ", block: " << block->ToString();
3037 block->set_is_return_statement_inside();
3038 }
3039 }
3040
// Process a while statement.
// Three blocks are created: the header evaluates the condition and jumps either to the body or to the after
// block, and the body jumps back to the header unless it already ends with a return.
// Returns the end block of the loop context if there is one (a 'break' was met), otherwise the after block.
FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast While";
  MS_EXCEPTION_IF_NULL(block);
  std::string while_block_name = "while";
  block->set_block_name(while_block_name);
  auto cur_func_graph = block->func_graph();
  MS_EXCEPTION_IF_NULL(cur_func_graph);

  FunctionBlockPtr header_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceWhileHeader>(cur_func_graph->debug_info()));
    header_block = MakeFunctionBlock();
    auto header_func_graph = header_block->func_graph();
    MS_EXCEPTION_IF_NULL(header_func_graph);
    header_func_graph->set_flag(GRAPH_FLAG_IS_WHILE_HEADER, true);
  }
  FunctionBlockPtr body_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceWhileBody>(cur_func_graph->debug_info()));
    body_block = MakeFunctionBlock();
  }
  FunctionBlockPtr after_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceWhileAfter>(cur_func_graph->debug_info()));
    after_block = MakeFunctionBlock();
  }

  // Both the body and the after block are entered from the header.
  body_block->AddPrevBlock(header_block);
  after_block->AddPrevBlock(header_block);
  block->Jump(header_block, {});

  py::object test_node = python_adapter::GetPyObjAttr(node, "test");
  header_block->UpdateGlobalPyParam(block->global_py_params());
  body_block->UpdateGlobalPyParam(block->global_py_params());
  after_block->UpdateGlobalPyParam(block->global_py_params());

  // The loop condition is evaluated in the header block.
  AnfNodePtr condition_node = ParseExprNode(header_block, test_node);
  AnfNodePtr while_condition_node = nullptr;
  {
    TraceGuard trace_guard(std::make_shared<TraceForceWhileCond>(condition_node->debug_info()));
    while_condition_node = header_block->ForceToCondNode(condition_node, true);
  }
  (void)header_block->ConditionalJump(while_condition_node, body_block, after_block);

  body_block->Mature();
  // Parse loop body statements with loop context.
  LoopContext loop_context{&loops_, header_block, nullptr};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body = ParseStatements(body_block, body_node);
  MS_EXCEPTION_IF_NULL(after_body->func_graph());
  const bool body_has_return = (after_body->func_graph()->get_return() != nullptr);
  if (!body_has_return) {
    // Close the loop.
    after_body->Jump(header_block, {});
  }
  header_block->Mature();
  after_block->Mature();

  py::object orelse_obj = python_adapter::GetPyObjAttr(node, "orelse");
  if (py::len_hint(orelse_obj) != 0) {
    TraceGuard trace_guard(GetLocation(orelse_obj));
    MS_LOG(EXCEPTION) << "The 'while...else...' statement is not supported now.";
  }

  auto &end_block = loop_context.EndBlock();
  if (!end_block) {
    // No 'break', no end_block.
    CheckReturnInLoop(block, body_block);
    return after_block;
  }
  // end_block exists if we encounter 'break' in loop body.
  after_block->Jump(end_block, {});
  end_block->Mature();
  CheckReturnInLoop(block, body_block);
  return end_block;
}
3110
// Process a for statement.
// 'for...else...' is rejected with an exception. Otherwise the statement is handled by ParseForRepeat when the
// compile config FOR_HALF_UNROLL is "1", and by ParseForUnroll in all other cases.
FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
  // The block is dereferenced below, so check it up front like the other statement parsers do.
  MS_EXCEPTION_IF_NULL(block);
  // Check for-else
  py::object orelse_obj = python_adapter::GetPyObjAttr(node, "orelse");
  if (py::len_hint(orelse_obj) != 0) {
    TraceGuard trace_guard(GetLocation(orelse_obj));
    MS_LOG(EXCEPTION) << "The 'for...else...' statement is not supported now.";
  }
  std::string for_block_name = "for";
  block->set_block_name(for_block_name);
  static const auto transform_for_half_unroll_call = (common::GetCompileConfig("FOR_HALF_UNROLL") == "1");
  if (transform_for_half_unroll_call) {
    return ParseForRepeat(block, node);
  }
  return ParseForUnroll(block, node);
}
3126
3127 // Implement unroll for statement with tuple/getitem.
ParseForUnroll(const FunctionBlockPtr & block,const py::object & node)3128 FunctionBlockPtr Parser::ParseForUnroll(const FunctionBlockPtr &block, const py::object &node) {
3129 MS_LOG(DEBUG) << "Process ast For by loop variable";
3130 MS_EXCEPTION_IF_NULL(block);
3131 AnfNodePtr op_len = block->MakeResolveOperation(NAMED_PRIMITIVE_LEN);
3132 AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
3133 AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
3134
3135 // Get variable name of 'x' in statement 'for x in xs'
3136 py::object target_node = python_adapter::GetPyObjAttr(node, "target");
3137
3138 // Create statement 'len(xs)'
3139 py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
3140 MS_LOG(DEBUG) << "Parse Recursive Iter, iter_obj: " << py::str(iter_obj);
3141 AnfNodePtr origin_iter_node = ParseExprNode(block, iter_obj);
3142 MS_EXCEPTION_IF_NULL(origin_iter_node);
3143 auto iter_node = block->func_graph()->NewCNodeInOrder({op_iter, origin_iter_node});
3144 CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
3145 FunctionBlockPtr header_block =
3146 MakeFunctionBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
3147 MS_EXCEPTION_IF_NULL(header_block);
3148 // Create loop variable 'i'
3149 ParameterPtr loop_var = header_block->func_graph()->add_parameter();
3150
3151 std::string less_module_name = "mindspore.ops.composite.multitype_ops.less_impl";
3152 ValuePtr less_op = prim::GetPythonOps("less", less_module_name);
3153 CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(less_op), loop_var, scalar_len});
3154
3155 // Generate the body of the for statement
3156 FunctionBlockPtr body_block = MakeFunctionBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
3157 MS_EXCEPTION_IF_NULL(body_block);
3158 body_block->AddPrevBlock(header_block);
3159 // Create 'x = xs[i]'
3160 auto body_func_graph = body_block->func_graph();
3161 MS_EXCEPTION_IF_NULL(body_func_graph);
3162 auto target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
3163 header_block->UpdateGlobalPyParam(block->global_py_params());
3164 body_block->UpdateGlobalPyParam(block->global_py_params());
3165 WriteAssignVars(body_block, target_node, target_var);
3166
3167 // Create 'i = i + 1'
3168 std::string add_module_name = "mindspore.ops.composite.multitype_ops.add_impl";
3169 ValuePtr add_op = prim::GetPythonOps("add", add_module_name);
3170 auto add_one = NewValueNode(static_cast<int64_t>(1));
3171 CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({NewValueNode(add_op), loop_var, add_one});
3172
3173 body_block->WriteVariable(loop_var->name(), loop_var_inc);
3174
3175 // Link the variable name with the target
3176 auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
3177 loop_var->debug_info()->set_trace_info(it_info);
3178
3179 FunctionBlockPtr after_block = nullptr;
3180 {
3181 TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
3182 after_block = MakeFunctionBlock();
3183 }
3184 MS_EXCEPTION_IF_NULL(after_block);
3185 after_block->AddPrevBlock(header_block);
3186 block->Jump(header_block, {NewValueNode(static_cast<int64_t>(0))});
3187 body_block->Mature();
3188 after_block->UpdateGlobalPyParam(block->global_py_params());
3189
3190 (void)header_block->ConditionalJump(cond_node, body_block, after_block);
3191
3192 // Parse loop body statements with loop context.
3193 LoopContext loop_context{&loops_, header_block, loop_var_inc};
3194 py::object body_node = python_adapter::GetPyObjAttr(node, "body");
3195 FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
3196 after_body_block->UpdateGlobalPyParam(block->global_py_params());
3197 if (after_body_block->func_graph()->get_return() == nullptr) {
3198 after_body_block->Jump(header_block, {loop_var_inc});
3199 }
3200
3201 header_block->Mature();
3202 after_block->Mature();
3203 auto &end_block = loop_context.EndBlock();
3204 if (end_block) {
3205 // end_block exists if we encounter 'break' in loop body.
3206 after_block->Jump(end_block, {});
3207 end_block->Mature();
3208 CheckReturnInLoop(block, body_block);
3209 return end_block;
3210 }
3211 // No 'break', no end_block.
3212 CheckReturnInLoop(block, body_block);
3213 return after_block;
3214 }
3215
3216 // Implement for statement with repeat calling sub graph.
ParseForRepeat(const FunctionBlockPtr & block,const py::object & node)3217 FunctionBlockPtr Parser::ParseForRepeat(const FunctionBlockPtr &block, const py::object &node) {
3218 MS_LOG(DEBUG) << "Process ast For by loop variable";
3219 MS_EXCEPTION_IF_NULL(block);
3220 FunctionBlockPtr header_block =
3221 MakeFunctionBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
3222 MS_EXCEPTION_IF_NULL(header_block);
3223
3224 // Create statement 'len(xs)'
3225 py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
3226 AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
3227 MS_EXCEPTION_IF_NULL(iter_node);
3228 // Generate node for loop count and convert it to tensor, to make the loop not unroll
3229 ParameterPtr header_iter_param = header_block->func_graph()->add_parameter();
3230 AnfNodePtr header_len = header_block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
3231 header_block->CheckUndefinedSymbol(NAMED_PRIMITIVE_LEN, header_len);
3232 CNodePtr scalar_len = header_block->func_graph()->NewCNodeInOrder({header_len, header_iter_param});
3233
3234 // Create loop variable 'i'
3235 ParameterPtr loop_var = header_block->func_graph()->add_parameter();
3236 // Create loop condition 'i < len(xs)'
3237 std::string less_module_name = "mindspore.ops.composite.multitype_ops.less_impl";
3238 ValuePtr less_op = prim::GetPythonOps("less", less_module_name);
3239 CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(less_op), loop_var, scalar_len});
3240
3241 // Generate the body of the for statement
3242 FunctionBlockPtr body_block = MakeFunctionBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
3243 MS_EXCEPTION_IF_NULL(body_block);
3244 body_block->AddPrevBlock(header_block);
3245 // Create 'x = xs[i]'
3246 auto body_func_graph = body_block->func_graph();
3247 AnfNodePtr body_getitem = body_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
3248 CNodePtr target_var = body_func_graph->NewCNodeInOrder({body_getitem, header_iter_param, loop_var});
3249
3250 header_block->UpdateGlobalPyParam(block->global_py_params());
3251 body_block->UpdateGlobalPyParam(block->global_py_params());
3252
3253 // Get variable name of 'x' in statement 'for x in xs'
3254 py::object target_node = python_adapter::GetPyObjAttr(node, "target");
3255 WriteAssignVars(body_block, target_node, target_var);
3256
3257 // Create 'i = i + 1'
3258 std::string add_module_name = "mindspore.ops.composite.multitype_ops.add_impl";
3259 ValuePtr add_op = prim::GetPythonOps("add", add_module_name);
3260 CNodePtr loop_var_inc =
3261 body_func_graph->NewCNodeInOrder({NewValueNode(add_op), loop_var, NewValueNode(static_cast<int64_t>(1))});
3262 body_block->WriteVariable(loop_var->name(), loop_var_inc);
3263
3264 // Link the variable name with the target
3265 auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
3266 loop_var->debug_info()->set_trace_info(it_info);
3267
3268 FunctionBlockPtr after_block = nullptr;
3269 {
3270 TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
3271 after_block = MakeFunctionBlock();
3272 }
3273 MS_EXCEPTION_IF_NULL(after_block);
3274 after_block->AddPrevBlock(header_block);
3275 block->Jump(header_block, {iter_node, NewValueNode(static_cast<int64_t>(0))});
3276 body_block->Mature();
3277 after_block->UpdateGlobalPyParam(block->global_py_params());
3278 (void)header_block->ConditionalJump(cond_node, body_block, after_block);
3279
3280 // Generate the body of the for statement
3281 FunctionBlockPtr rolled_body_block =
3282 MakeFunctionBlock(std::make_shared<TraceForRolledBody>(body_block->func_graph()->debug_info()));
3283 MS_EXCEPTION_IF_NULL(rolled_body_block);
3284
3285 rolled_body_block->Mature();
3286 body_block->Jump(rolled_body_block, {});
3287 auto rolled_body_call = dyn_cast<CNode>(body_block->func_graph()->output());
3288 rolled_body_block->UpdateGlobalPyParam(block->global_py_params());
3289
3290 // Parse loop body statements with loop context.
3291 LoopContext loop_context{&loops_, header_block, loop_var_inc};
3292 py::object body_node = python_adapter::GetPyObjAttr(node, "body");
3293 FunctionBlockPtr after_body_block = ParseStatements(rolled_body_block, body_node);
3294 after_body_block->UpdateGlobalPyParam(block->global_py_params());
3295 MS_LOG(DEBUG) << "Finish rolled block, after_body_block: " << after_body_block->ToString()
3296 << ", rolled_body_block: " << rolled_body_block->ToString();
3297 if (after_body_block->func_graph()->get_return() == nullptr) {
3298 after_body_block->Jump(header_block, {header_iter_param, loop_var_inc});
3299 }
3300
3301 // Record the former/middle/latter graphs for later transforming.
3302 static const auto transform_for_half_unroll_call = (common::GetCompileConfig("FOR_HALF_UNROLL") == "1");
3303 if (transform_for_half_unroll_call) {
3304 std::pair<FunctionBlockPtr, FunctionBlockPtr> loop_graphs;
3305 loop_graphs.first = body_block;
3306 loop_graphs.second = after_body_block;
3307 std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>> loop_graphs_vec{loop_graphs};
3308 (void)parallel_call_graphs_.emplace_back(loop_graphs_vec);
3309 MS_LOG(DEBUG) << "Record tail call graphs, loop: {former: " << loop_graphs.first->func_graph()->ToString()
3310 << ", middle: " << loop_graphs.second->func_graph()->ToString() << "}";
3311 // Record the rolled body function, for later lifting operation.
3312 if (rolled_body_call != nullptr) {
3313 (void)rolled_body_calls_.emplace_back(std::make_pair(rolled_body_call, rolled_body_block));
3314 constexpr int recursive_level = 2;
3315 MS_LOG(DEBUG) << "Record rolled body call: {CNode: " << rolled_body_call->DebugString(recursive_level)
3316 << ", rolled_graph: " << rolled_body_block->ToString() << "}";
3317 }
3318 auto rolled_body_func_graph = rolled_body_block->func_graph();
3319 rolled_body_func_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
3320 rolled_body_func_graph->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUE, true);
3321 }
3322
3323 header_block->Mature();
3324 after_block->Mature();
3325 auto &end_block = loop_context.EndBlock();
3326 if (end_block) {
3327 // end_block exists if we encounter 'break' in loop body.
3328 after_block->Jump(end_block, {});
3329 end_block->Mature();
3330 return end_block;
3331 }
3332 // No 'break', no end_block.
3333 return after_block;
3334 }
3335
ParseIfExp(const FunctionBlockPtr & block,const py::object & node)3336 AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
3337 MS_LOG(DEBUG) << "Process ast IfExp";
3338 MS_EXCEPTION_IF_NULL(block);
3339 py::object test_node = python_adapter::GetPyObjAttr(node, "test");
3340 AnfNodePtr condition_node = ParseExprNode(block, test_node);
3341
3342 AnfNodePtr bool_node = block->ForceToCondNode(condition_node);
3343 FunctionBlockPtr true_block = nullptr;
3344 FunctionBlockPtr false_block = nullptr;
3345 MS_EXCEPTION_IF_NULL(block->func_graph());
3346 {
3347 TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
3348 true_block = MakeFunctionBlock();
3349 }
3350 {
3351 TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
3352 false_block = MakeFunctionBlock();
3353 }
3354
3355 MakeConditionBlocks(block, true_block, false_block);
3356
3357 // Process the if-true branch
3358 py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
3359 MS_EXCEPTION_IF_NULL(true_block->func_graph());
3360 MS_EXCEPTION_IF_NULL(true_block->func_graph()->debug_info());
3361 true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode));
3362 AnfNodePtr true_node = ParseExprNode(true_block, bodyNode);
3363
3364 // Process the orelse branch
3365 py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
3366 MS_EXCEPTION_IF_NULL(false_block->func_graph());
3367 MS_EXCEPTION_IF_NULL(false_block->func_graph()->debug_info());
3368 false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode));
3369 AnfNodePtr false_node = ParseExprNode(false_block, orelseNode);
3370
3371 true_block->func_graph()->set_output(true_node);
3372 false_block->func_graph()->set_output(false_node);
3373
3374 // Use the Primitive replace the operation resolve node (switch),
3375 // because the switch will eventually be converted to Primitive node
3376 CNodePtr switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), bool_node,
3377 NewValueNode(true_block->func_graph()),
3378 NewValueNode(false_block->func_graph())});
3379 std::vector<AnfNodePtr> call_graph_nodes{switch_app};
3380 CNodePtr switch_app_call = block->func_graph()->NewCNodeInOrder(std::move(call_graph_nodes));
3381 return switch_app_call;
3382 }
3383
// Build the loop graphs for the single `comprehension` (generator) of a ListComp.
// Four blocks are created: a top block (returned to the caller), a header block with the parameters
// (`iter`, `list`), a body block and an after block. The `iter` expression is parsed in `block`; the header
// tests hasnext(iter) and jumps to the body or to the after block, and the after block outputs `list`.
// `node` is the ListComp node itself; it is only forwarded to ParseListCompIfs().
FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
                                           const py::object &generator_node) {
  // Create the top block, entry of the comprehension.
  MS_EXCEPTION_IF_NULL(block->func_graph());
  FunctionBlockPtr top_block = MakeFunctionBlock(std::make_shared<TraceListComp>(block->func_graph()->debug_info()));
  top_block->AddPrevBlock(block);

  // Parse the `iter` expression in the enclosing block and wrap it as iter(<expr>).
  py::object iter_obj = python_adapter::GetPyObjAttr(generator_node, "iter");
  AnfNodePtr iterable = ParseExprNode(block, iter_obj);
  MS_EXCEPTION_IF_NULL(iterable);
  AnfNodePtr iter_op = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  MS_EXCEPTION_IF_NULL(iter_op);
  AnfNodePtr iterator = block->func_graph()->NewCNodeInOrder({iter_op, iterable});
  MS_EXCEPTION_IF_NULL(iterator);

  // Header graph.
  FunctionBlockPtr header_block =
    MakeFunctionBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));

  // First header parameter: the iterator. The loop condition is hasnext(iter).
  AnfNodePtr hasnext_op = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  auto header_graph = header_block->func_graph();
  MS_EXCEPTION_IF_NULL(header_graph);
  ParameterPtr iter_param = header_graph->add_parameter();
  constexpr auto kIterParamName = "iter";
  iter_param->set_name(kIterParamName);
  MS_EXCEPTION_IF_NULL(iter_param->debug_info());
  iter_param->debug_info()->set_name(kIterParamName);
  CNodePtr has_next = header_graph->NewCNodeInOrder({hasnext_op, iter_param});

  // Second header parameter: the list built so far. The top block enters the header with an empty list.
  ParameterPtr list_param = header_graph->add_parameter();
  constexpr auto kListParamName = "list";
  list_param->set_name(kListParamName);
  MS_EXCEPTION_IF_NULL(list_param->debug_info());
  list_param->debug_info()->set_name(kListParamName);
  std::vector<ValuePtr> no_elements;
  AnfNodePtr initial_list = NewValueNode(std::make_shared<ValueList>(no_elements));
  top_block->Jump(header_block, {iterator, initial_list});

  // Body graph: next(iter) yields a pair, item [0] is the element and item [1] is the advanced iterator.
  FunctionBlockPtr body_block = MakeFunctionBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  body_block->AddPrevBlock(header_block);
  AnfNodePtr next_op = top_block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  auto body_graph = body_block->func_graph();
  MS_EXCEPTION_IF_NULL(body_graph);
  CNodePtr next_result = body_graph->NewCNodeInOrder({next_op, iter_param});
  AnfNodePtr getitem_op = top_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  CNodePtr current_item =
    body_graph->NewCNodeInOrder({getitem_op, next_result, NewValueNode(static_cast<int64_t>(0))});
  CNodePtr advanced_iter =
    body_graph->NewCNodeInOrder({getitem_op, next_result, NewValueNode(static_cast<int64_t>(1))});

  // Bind the current element to the comprehension `target`.
  py::object target_obj = python_adapter::GetPyObjAttr(generator_node, "target");
  WriteAssignVars(body_block, target_obj, current_item);

  // Apply the `ifs` filters and `elt`, then loop back to the header with the updated list.
  auto updated_list = ParseListCompIfs(body_block, list_param, node, generator_node);
  body_block->Jump(header_block, {advanced_iter, updated_list});

  // After graph: returns the accumulated list.
  FunctionBlockPtr after_block =
    MakeFunctionBlock(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  after_block->AddPrevBlock(header_block);
  MS_EXCEPTION_IF_NULL(after_block->func_graph());
  after_block->func_graph()->set_output(list_param);

  // The header selects between the body and the after graph.
  (void)header_block->ConditionalJump(has_next, body_block, after_block);

  top_block->Mature();
  header_block->Mature();
  body_block->Mature();
  after_block->Mature();
  return top_block;
}
3460
ParseListCompIfs(const FunctionBlockPtr & list_body_block,const ParameterPtr & list_param,const py::object & node,const py::object & generator_node)3461 AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
3462 const py::object &node, const py::object &generator_node) {
3463 // Handle ifs attribute.
3464 py::list ifs_node = python_adapter::GetPyObjAttr(generator_node, "ifs");
3465 AnfNodePtr ifs_bool_node;
3466 if (ifs_node.empty()) {
3467 ifs_bool_node = NewValueNode(true);
3468 } else {
3469 ifs_bool_node = ProcessBoolOpValueList(list_body_block, ifs_node, AST_SUB_TYPE_AND);
3470 }
3471
3472 // Create if-true graph.
3473 FunctionBlockPtr if_true_block =
3474 MakeFunctionBlock(std::make_shared<TraceIfStmtTrueBranch>(list_body_block->func_graph()->debug_info()));
3475 if_true_block->AddPrevBlock(list_body_block);
3476 // Handle elt attribute in body block.
3477 py::object elt_obj = python_adapter::GetPyObjAttr(node, "elt");
3478 AnfNodePtr elt_node = ParseExprNode(list_body_block, elt_obj);
3479 // Append the element.
3480 std::vector<AnfNodePtr> list_vec;
3481 AnfNodePtr make_list_op = list_body_block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
3482 (void)list_vec.emplace_back(make_list_op);
3483 (void)list_vec.emplace_back(elt_node);
3484 CNodePtr list_app = list_body_block->func_graph()->NewCNodeInOrder(std::move(list_vec));
3485 std::string add_module_name = "mindspore.ops.composite.multitype_ops.add_impl";
3486 ValuePtr add_op = prim::GetPythonOps("add", add_module_name);
3487 CNodePtr new_list = list_body_block->func_graph()->NewCNodeInOrder({NewValueNode(add_op), list_param, list_app});
3488 // Return new list in true branch graph.
3489 if_true_block->func_graph()->set_output(new_list);
3490
3491 // Create if-false graph.
3492 FunctionBlockPtr if_false_block =
3493 MakeFunctionBlock(std::make_shared<TraceIfStmtFalseBranch>(list_body_block->func_graph()->debug_info()));
3494 if_false_block->AddPrevBlock(list_body_block);
3495 // Return original list in false branch graph.
3496 MS_EXCEPTION_IF_NULL(if_false_block->func_graph());
3497 if_false_block->func_graph()->set_output(list_param);
3498
3499 // We don't want to create a header graph, where to get and wrap the result of Switch().
3500 // So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky.
3501 (void)list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
3502 // Output is Switch() result, i.e. updated list.
3503 auto switch_apply_node = list_body_block->func_graph()->output();
3504 auto ifs_new_list = switch_apply_node;
3505 // Since we call ConditionalJump() above, to reset the Return as null before call Jump().
3506 list_body_block->func_graph()->set_return(nullptr);
3507 if_true_block->Mature();
3508 if_false_block->Mature();
3509 return ifs_new_list;
3510 }
3511
3512 // A ListComp contains: `elt` and `generators`.
3513 // `generators` contains: `target`, `iter` and `ifs`.
3514 // For example:
3515 // [x * x for x in range(0, 10) if x % 2 == 0]
3516 // It is compiled to be following statement:
3517 // list = []
3518 // for x in range(0, 10):
3519 // if x % 2 == 0:
3520 // list.append(x * x)
3521 // return list
ParseListComp(const FunctionBlockPtr & block,const py::object & node)3522 AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object &node) {
3523 MS_LOG(DEBUG) << "Process ast ListComp";
3524 MS_EXCEPTION_IF_NULL(block);
3525
3526 // Handle generators attribute.
3527 py::list generators_node = python_adapter::GetPyObjAttr(node, "generators");
3528 if (generators_node.size() != 1) {
3529 MS_EXCEPTION(TypeError) << "The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp, but got "
3530 << generators_node.size() << " comprehensions.";
3531 }
3532 py::object generator_node = generators_node[0];
3533 auto generator_node_type = ast_->GetNodeType(generator_node);
3534 auto generator_node_name = generator_node_type->node_name();
3535 constexpr auto comprehension_name = "comprehension";
3536 if (generator_node_name != comprehension_name) {
3537 MS_LOG(INTERNAL_EXCEPTION) << "Generator node name should be " << comprehension_name << ", but got "
3538 << generator_node_name;
3539 }
3540
3541 // Parse ListComp's `iter` and add `elt` in it.
3542 auto top_block = ParseListCompIter(block, node, generator_node);
3543
3544 // Call the top graph and return the list.
3545 auto call_function_node = NewValueNode(top_block->func_graph());
3546 std::vector<AnfNodePtr> func_call_nodes;
3547 func_call_nodes.push_back(call_function_node);
3548 MS_EXCEPTION_IF_NULL(block->func_graph());
3549 AnfNodePtr output = block->func_graph()->NewCNodeInOrder(std::move(func_call_nodes));
3550 return output;
3551 }
3552
// Build the loop graphs for the single `comprehension` (generator) of a DictComp.
// Four blocks are created: a top block (returned to the caller), a header block with the parameters
// (`iter`, `dict`), a body block and an after block. The `iter` expression is parsed in `block`; the header
// tests hasnext(iter) and jumps to the body or to the after block, and the after block outputs `dict`.
// `node` is the DictComp node itself; it is only forwarded to ParseDictCompIfs().
FunctionBlockPtr Parser::ParseDictCompIter(const FunctionBlockPtr &block, const py::object &node,
                                           const py::object &generator_node) {
  // Create the top block, entry of the comprehension.
  MS_EXCEPTION_IF_NULL(block->func_graph());
  FunctionBlockPtr top_block = MakeFunctionBlock(std::make_shared<TraceDictComp>(block->func_graph()->debug_info()));
  top_block->AddPrevBlock(block);

  // Parse the `iter` expression in the enclosing block and wrap it as iter(<expr>).
  py::object iter_obj = python_adapter::GetPyObjAttr(generator_node, "iter");
  AnfNodePtr iterable = ParseExprNode(block, iter_obj);
  MS_EXCEPTION_IF_NULL(iterable);
  AnfNodePtr iter_op = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  MS_EXCEPTION_IF_NULL(iter_op);
  AnfNodePtr iterator = block->func_graph()->NewCNodeInOrder({iter_op, iterable});
  MS_EXCEPTION_IF_NULL(iterator);

  // Header graph. Its first parameter is the iterator; the loop condition is hasnext(iter).
  FunctionBlockPtr header_block =
    MakeFunctionBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  AnfNodePtr hasnext_op = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  auto header_graph = header_block->func_graph();
  MS_EXCEPTION_IF_NULL(header_graph);
  ParameterPtr iter_param = header_graph->add_parameter();
  constexpr auto kIterParamName = "iter";
  iter_param->set_name(kIterParamName);
  MS_EXCEPTION_IF_NULL(iter_param->debug_info());
  iter_param->debug_info()->set_name(kIterParamName);
  CNodePtr has_next = header_graph->NewCNodeInOrder({hasnext_op, iter_param});

  // Second header parameter: the dict built so far.
  // The top block enters the header with make_dict((), ()), i.e. a dict made from empty key/value tuples.
  ParameterPtr dict_param = header_graph->add_parameter();
  constexpr auto kDictParamName = "dict";
  dict_param->set_name(kDictParamName);
  MS_EXCEPTION_IF_NULL(dict_param->debug_info());
  dict_param->debug_info()->set_name(kDictParamName);
  std::vector<ValuePtr> no_keys;
  AnfNodePtr keys_node = NewValueNode(std::make_shared<ValueTuple>(no_keys));
  std::vector<ValuePtr> no_values;
  AnfNodePtr values_node = NewValueNode(std::make_shared<ValueTuple>(no_values));
  auto make_dict_op = top_block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
  auto initial_dict = top_block->func_graph()->NewCNodeInOrder({make_dict_op, keys_node, values_node});
  top_block->Jump(header_block, {iterator, initial_dict});

  // Body graph: next(iter) yields a pair, item [0] is the element and item [1] is the advanced iterator.
  FunctionBlockPtr body_block = MakeFunctionBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  body_block->AddPrevBlock(header_block);
  AnfNodePtr next_op = top_block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  auto body_graph = body_block->func_graph();
  MS_EXCEPTION_IF_NULL(body_graph);
  CNodePtr next_result = body_graph->NewCNodeInOrder({next_op, iter_param});
  AnfNodePtr getitem_op = top_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  CNodePtr current_item =
    body_graph->NewCNodeInOrder({getitem_op, next_result, NewValueNode(static_cast<int64_t>(0))});
  CNodePtr advanced_iter =
    body_graph->NewCNodeInOrder({getitem_op, next_result, NewValueNode(static_cast<int64_t>(1))});

  // Bind the current element to the comprehension `target`.
  py::object target_obj = python_adapter::GetPyObjAttr(generator_node, "target");
  WriteAssignVars(body_block, target_obj, current_item);

  // Apply the `ifs` filters and `key`/`value`, then loop back to the header with the updated dict.
  auto updated_dict = ParseDictCompIfs(body_block, dict_param, node, generator_node);
  body_block->Jump(header_block, {advanced_iter, updated_dict});

  // After graph: returns the accumulated dict.
  FunctionBlockPtr after_block =
    MakeFunctionBlock(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  after_block->AddPrevBlock(header_block);
  MS_EXCEPTION_IF_NULL(after_block->func_graph());
  after_block->func_graph()->set_output(dict_param);

  // The header selects between the body and the after graph.
  (void)header_block->ConditionalJump(has_next, body_block, after_block);

  top_block->Mature();
  header_block->Mature();
  body_block->Mature();
  after_block->Mature();
  return top_block;
}
3631
ParseDictCompIfs(const FunctionBlockPtr & dict_body_block,const ParameterPtr & dict_param,const py::object & node,const py::object & generator_node)3632 AnfNodePtr Parser::ParseDictCompIfs(const FunctionBlockPtr &dict_body_block, const ParameterPtr &dict_param,
3633 const py::object &node, const py::object &generator_node) {
3634 // Handle ifs attribute.
3635 py::list ifs_node = python_adapter::GetPyObjAttr(generator_node, "ifs");
3636 AnfNodePtr ifs_bool_node;
3637 if (ifs_node.empty()) {
3638 ifs_bool_node = NewValueNode(true);
3639 } else {
3640 ifs_bool_node = ProcessBoolOpValueList(dict_body_block, ifs_node, AST_SUB_TYPE_AND);
3641 }
3642
3643 // Create if-true graph.
3644 MS_EXCEPTION_IF_NULL(dict_body_block->func_graph());
3645 FunctionBlockPtr if_true_block =
3646 MakeFunctionBlock(std::make_shared<TraceIfStmtTrueBranch>(dict_body_block->func_graph()->debug_info()));
3647 if_true_block->AddPrevBlock(dict_body_block);
3648 // Handle key, value attribute in body block.
3649 py::object key_obj = python_adapter::GetPyObjAttr(node, "key");
3650 AnfNodePtr key_node = ParseExprNode(dict_body_block, key_obj);
3651 py::object value_obj = python_adapter::GetPyObjAttr(node, "value");
3652 AnfNodePtr value_node = ParseExprNode(dict_body_block, value_obj);
3653 // update dict.
3654 std::vector<AnfNodePtr> key_vec;
3655 std::vector<AnfNodePtr> value_vec;
3656 std::vector<AnfNodePtr> dict_vec;
3657 AnfNodePtr make_dict_op = dict_body_block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
3658 AnfNodePtr make_tuple_op = dict_body_block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
3659 (void)key_vec.emplace_back(make_tuple_op);
3660 (void)key_vec.emplace_back(key_node);
3661 CNodePtr key_app = dict_body_block->func_graph()->NewCNodeInOrder(std::move(key_vec));
3662 (void)value_vec.emplace_back(make_tuple_op);
3663 (void)value_vec.emplace_back(value_node);
3664 CNodePtr value_app = dict_body_block->func_graph()->NewCNodeInOrder(std::move(value_vec));
3665 (void)dict_vec.emplace_back(make_dict_op);
3666 (void)dict_vec.emplace_back(key_app);
3667 (void)dict_vec.emplace_back(value_app);
3668 CNodePtr dict_app = dict_body_block->func_graph()->NewCNodeInOrder(std::move(dict_vec));
3669 std::string add_module_name = "mindspore.ops.composite.multitype_ops.add_impl";
3670 ValuePtr add_op = prim::GetPythonOps("add", add_module_name);
3671 CNodePtr new_dict = dict_body_block->func_graph()->NewCNodeInOrder({NewValueNode(add_op), dict_param, dict_app});
3672 // Return new dict in true branch graph.
3673 if_true_block->func_graph()->set_output(new_dict);
3674
3675 // Create if-false graph.
3676 FunctionBlockPtr if_false_block =
3677 MakeFunctionBlock(std::make_shared<TraceIfStmtFalseBranch>(dict_body_block->func_graph()->debug_info()));
3678 if_false_block->AddPrevBlock(dict_body_block);
3679 // Return original dict in false branch graph.
3680 MS_EXCEPTION_IF_NULL(if_false_block->func_graph());
3681 if_false_block->func_graph()->set_output(dict_param);
3682
3683 // We don't want to create a header graph, where to get and wrap the result of Switch().
3684 // So just call ConditionalJump() to set Switch() as output, and reset it later, as tricky.
3685 (void)dict_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
3686 // Output is Switch() result, i.e. updated dict.
3687 auto switch_apply_node = dict_body_block->func_graph()->output();
3688 auto ifs_new_dict = switch_apply_node;
3689 // Since we call ConditionalJump() above, to reset the Return as null before call Jump().
3690 dict_body_block->func_graph()->set_return(nullptr);
3691 if_true_block->Mature();
3692 if_false_block->Mature();
3693 return ifs_new_dict;
3694 }
3695
// A DictComp contains: `key`, `value` and `generators`.
3697 // `generators` contains: `target`, `iter` and `ifs`.
3698 // For example:
3699 // {x: y for y, x in some_dict.items() if x > 1}
3700 // It is compiled to be following statement:
3701 // dict = {}
3702 // for y, x in some_dict.items():
3703 // if x > 1:
3704 // dict[x] = y
3705 // return dict
ParseDictComp(const FunctionBlockPtr & block,const py::object & node)3706 AnfNodePtr Parser::ParseDictComp(const FunctionBlockPtr &block, const py::object &node) {
3707 MS_LOG(DEBUG) << "Process ast DictComp";
3708 MS_EXCEPTION_IF_NULL(block);
3709
3710 // Handle generators attribute.
3711 py::list generators_node = python_adapter::GetPyObjAttr(node, "generators");
3712 if (generators_node.size() != 1) {
3713 MS_EXCEPTION(TypeError) << "The 'generators' supports 1 'comprehension' in DictComp/GeneratorExp, but got "
3714 << generators_node.size() << " comprehensions.";
3715 }
3716 py::object generator_node = generators_node[0];
3717 auto generator_node_type = ast_->GetNodeType(generator_node);
3718 auto generator_node_name = generator_node_type->node_name();
3719 constexpr auto comprehension_name = "comprehension";
3720 if (generator_node_name != comprehension_name) {
3721 MS_LOG(INTERNAL_EXCEPTION) << "Generator node name should be " << comprehension_name << ", but got "
3722 << generator_node_name;
3723 }
3724
3725 // Parse DictComp's `iter` and add `key`, `value` in it.
3726 auto top_block = ParseDictCompIter(block, node, generator_node);
3727
3728 // Call the top graph and return the dict.
3729 auto call_function_node = NewValueNode(top_block->func_graph());
3730 std::vector<AnfNodePtr> func_call_nodes;
3731 func_call_nodes.push_back(call_function_node);
3732 MS_EXCEPTION_IF_NULL(block->func_graph());
3733 AnfNodePtr output = block->func_graph()->NewCNodeInOrder(std::move(func_call_nodes));
3734 return output;
3735 }
3736
ParseJoinedStr(const FunctionBlockPtr & block,const py::object & node)3737 AnfNodePtr Parser::ParseJoinedStr(const FunctionBlockPtr &block, const py::object &node) {
3738 MS_LOG(DEBUG) << "Process ast JoinedStr.";
3739 TraceGuard trace_guard(GetLocation(node));
3740 MS_EXCEPTION_IF_NULL(block);
3741 const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(node));
3742 py::list py_values = python_adapter::GetPyObjAttr(node, "values");
3743 std::vector<AnfNodePtr> value_nodes{NewValueNode(prim::kPrimJoinedStr)};
3744 bool has_interpret_node = false;
3745 for (size_t i = 0; i < py_values.size(); ++i) {
3746 AnfNodePtr str_value = ParseExprNode(block, py_values[i]);
3747 // If exist interpret node in JoinedStr, all object in py_values will convert to interpret node.
3748 // Need to parse all elements in py_values in order to put them in local param dict.
3749 if (!has_interpret_node && fallback::CheckInterpretInput(str_value)) {
3750 has_interpret_node = true;
3751 }
3752 (void)value_nodes.emplace_back(str_value);
3753 }
3754 // JoinedStr can not get expr_src for their separate element. So, we convert the whole str directly.
3755 if (has_interpret_node) {
3756 return MakeInterpretNode(block, value_nodes[1], script_text);
3757 }
3758 auto func_graph = block->func_graph();
3759 MS_EXCEPTION_IF_NULL(func_graph);
3760 AnfNodePtr output = func_graph->NewCNodeInOrder(std::move(value_nodes));
3761 return output;
3762 }
3763
ParseFormattedValue(const FunctionBlockPtr & block,const py::object & node)3764 AnfNodePtr Parser::ParseFormattedValue(const FunctionBlockPtr &block, const py::object &node) {
3765 MS_LOG(DEBUG) << "Process ast FormattedValue.";
3766 TraceGuard trace_guard(GetLocation(node));
3767 MS_EXCEPTION_IF_NULL(block);
3768 py::object value_object = python_adapter::GetPyObjAttr(node, "value");
3769 AnfNodePtr value_node = ParseExprNode(block, value_object);
3770 return value_node;
3771 }
3772
ParseStarred(const FunctionBlockPtr & block,const py::object & node)3773 AnfNodePtr Parser::ParseStarred(const FunctionBlockPtr &block, const py::object &node) {
3774 MS_LOG(DEBUG) << "Process ast Starred.";
3775 TraceGuard trace_guard(GetLocation(node));
3776 MS_EXCEPTION_IF_NULL(block);
3777 py::object value_object = python_adapter::GetPyObjAttr(node, "value");
3778 AnfNodePtr value_node = ParseExprNode(block, value_object);
3779 AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
3780 auto func = block->func_graph();
3781 MS_EXCEPTION_IF_NULL(func);
3782 CNodePtr iterated_node = func->NewCNodeInOrder({op_iter, value_node});
3783 auto prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
3784 CNodePtr unpack_node = func->NewCNodeInOrder({NewValueNode(prim), iterated_node});
3785 return unpack_node;
3786 }
3787
HandleAssignStarred(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node)3788 void Parser::HandleAssignStarred(const FunctionBlockPtr &block, const py::object &target,
3789 const AnfNodePtr &assigned_node) {
3790 MS_EXCEPTION_IF_NULL(block);
3791 MS_EXCEPTION_IF_NULL(assigned_node);
3792 py::object value_object = python_adapter::GetPyObjAttr(target, "value");
3793 py::str name = python_adapter::GetPyObjAttr(value_object, "id");
3794 std::string name_id = name;
3795 MS_EXCEPTION_IF_NULL(assigned_node->debug_info());
3796 assigned_node->debug_info()->set_name(name_id);
3797 MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
3798 block->WriteVariable(name_id, assigned_node);
3799 }
3800
HandleAssignName(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node) const3801 void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target,
3802 const AnfNodePtr &assigned_node) const {
3803 MS_EXCEPTION_IF_NULL(block);
3804 MS_EXCEPTION_IF_NULL(assigned_node);
3805 py::str name = python_adapter::GetPyObjAttr(target, "id");
3806 std::string name_id = name;
3807
3808 MS_EXCEPTION_IF_NULL(assigned_node->debug_info());
3809 assigned_node->debug_info()->set_name(name_id);
3810 // Set the debug name of the constant graph
3811 if (IsValueNode<FuncGraph>(assigned_node)) {
3812 // The value should be graph
3813 auto fg = GetValueNode<FuncGraphPtr>(assigned_node);
3814 MS_EXCEPTION_IF_NULL(fg->debug_info());
3815 if (fg->debug_info()->name().empty()) {
3816 fg->debug_info()->set_name(name_id);
3817 }
3818 }
3819 MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
3820 block->WriteVariable(name_id, assigned_node);
3821 }
3822
HandleAssignTupleWithStarredExpression(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node,const std::vector<int64_t> & positions)3823 void Parser::HandleAssignTupleWithStarredExpression(const FunctionBlockPtr &block, const py::object &target,
3824 const AnfNodePtr &assigned_node,
3825 const std::vector<int64_t> &positions) {
3826 // Process assigned_node
3827 auto func = block->func_graph();
3828 MS_EXCEPTION_IF_NULL(func);
3829 AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
3830 CNodePtr iterated_node = func->NewCNodeInOrder({op_iter, assigned_node});
3831 auto starred_unpack_prim = std::make_shared<prim::StarredUnpack>(NAMED_METAGRAPH_STARRED_UNPACK);
3832 CNodePtr unpack_node = func->NewCNodeInOrder({NewValueNode(starred_unpack_prim), iterated_node});
3833
3834 py::list items = python_adapter::GetPyObjAttr(target, "elts");
3835 for (size_t i = 0; i < items.size(); i++) {
3836 py::object elt = items[i];
3837 auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elt)));
3838 if (elt_type != AST_SUB_TYPE_STARRED) {
3839 std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
3840 ValuePtr op = prim::GetPythonOps("getitem", module_name);
3841 std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(op), unpack_node, NewValueNode(positions[i])};
3842 AnfNodePtr tuple_get_item = func->NewCNodeInOrder(tuple_get_item_inputs);
3843 MS_LOG(DEBUG) << "Assign name: `" << py::str(elt) << "` to node: " << tuple_get_item->DebugString();
3844 WriteAssignVars(block, elt, tuple_get_item);
3845 } else {
3846 auto starred_get_item_prim = std::make_shared<prim::StarredGetItem>(NAMED_METAGRAPH_STARRED_GET_ITEM);
3847 std::vector<AnfNodePtr> starred_get_item_inputs{NewValueNode(starred_get_item_prim), unpack_node,
3848 NewValueNode(positions[i]),
3849 NewValueNode(SizeToLong(items.size()))};
3850 AnfNodePtr starred_get_item = func->NewCNodeInOrder(starred_get_item_inputs);
3851 MS_LOG(DEBUG) << "Assign name: `" << py::str(elt) << "` to node: " << starred_get_item->DebugString();
3852 WriteAssignVars(block, elt, starred_get_item);
3853 }
3854 }
3855 }
3856
HandleAssignTupleOrList(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node)3857 void Parser::HandleAssignTupleOrList(const FunctionBlockPtr &block, const py::object &target,
3858 const AnfNodePtr &assigned_node) {
3859 MS_EXCEPTION_IF_NULL(block);
3860 AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
3861 py::list items = python_adapter::GetPyObjAttr(target, "elts");
3862
3863 // Record the position with targets.
3864 size_t target_starred_num = 0;
3865 size_t starred_pos = items.size();
3866 std::vector<int64_t> positions(items.size(), 0);
3867 for (size_t i = 0; i < items.size(); i++) {
3868 py::object elt = items[i];
3869 auto elt_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, elt)));
3870 if (elt_type == AST_SUB_TYPE_STARRED) {
3871 target_starred_num++;
3872 if (target_starred_num > 1) {
3873 MS_LOG(EXCEPTION) << "SyntaxError: " << target_starred_num << " starred expressions in assignment.";
3874 }
3875 starred_pos = i;
3876 positions[i] = i;
3877 } else {
3878 if (i > starred_pos) {
3879 positions[i] = i - items.size();
3880 } else {
3881 positions[i] = i;
3882 }
3883 }
3884 }
3885 auto func = block->func_graph();
3886 MS_EXCEPTION_IF_NULL(func);
3887
3888 if (target_starred_num == 0) {
3889 for (size_t i = 0; i < items.size(); i++) {
3890 // Use the Primitive replace the operation resolve node (getitem),
3891 // because the getitem will eventually be converted to Primitive node
3892 CNodePtr item_apply = func->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
3893 py::object elt = items[i];
3894 WriteAssignVars(block, elt, item_apply);
3895 }
3896 return;
3897 }
3898
3899 // Process AssignTuple with starred expression.
3900 // a, *b, c = x // targets(a, *b, c) = assign(x)
3901 HandleAssignTupleWithStarredExpression(block, target, assigned_node, positions);
3902 }
3903
IsClassParameterMember(const py::object & target_obj,const AnfNodePtr & target_node) const3904 bool Parser::IsClassParameterMember(const py::object &target_obj, const AnfNodePtr &target_node) const {
3905 auto attr_name = target_obj.attr("attr").cast<std::string>();
3906 if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
3907 return false;
3908 }
3909
3910 auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
3911 return (py::hasattr(obj, "__parameter__"));
3912 }
3913
HandleAssignClassParameterMember(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & value_node)3914 bool Parser::HandleAssignClassParameterMember(const FunctionBlockPtr &block, const py::object &target,
3915 const AnfNodePtr &value_node) {
3916 MS_EXCEPTION_IF_NULL(block);
3917 // Now only support the self.xx = xxxxx, can't support x.y = xxxx
3918 AnfNodePtr target_node = ParseExprNode(block, target);
3919 if (target_node == nullptr) {
3920 return false;
3921 }
3922
3923 if (!IsClassParameterMember(target, target_node)) {
3924 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
3925 if (!allow_fallback_runtime) {
3926 auto attr_name = target.attr("attr").cast<std::string>();
3927 std::string var_name = "self." + attr_name;
3928 auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
3929 auto obj_type = obj.attr("__class__").attr("__name__");
3930 MS_EXCEPTION(TypeError) << "In JIT strict mode, if need to modify a member attribute of a class with " << var_name
3931 << ", the member attribute must be of the Parameter type. But got '"
3932 << py::str(obj).cast<std::string>() << "' with type '"
3933 << py::str(obj_type).cast<std::string>()
3934 << "'. You can use os.environ['MS_DEV_JIT_SYNTAX_LEVEL'] = '2' "
3935 << "to enable the JIT lax mode to support the current syntax.\n\n"
3936 << trace::GetDebugInfoStr(target_node->debug_info());
3937 }
3938 MS_LOG(DEBUG) << "Erase unused node: " << target_node->DebugString();
3939 block->func_graph()->EraseUnusedNodeInOrder(target_node);
3940 return false;
3941 }
3942 block->SetStateAssign(target_node, value_node);
3943 return true;
3944 }
3945
MakeSetAttrNode(const FunctionBlockPtr & block,const AnfNodePtr & target_node,const AnfNodePtr & value_node,const std::string & target_id_str,const std::string & attr_str)3946 void Parser::MakeSetAttrNode(const FunctionBlockPtr &block, const AnfNodePtr &target_node, const AnfNodePtr &value_node,
3947 const std::string &target_id_str, const std::string &attr_str) {
3948 std::vector<AnfNodePtr> setattr_node_inputs{NewValueNode(prim::kPrimSetAttr)};
3949 (void)setattr_node_inputs.emplace_back(target_node);
3950 (void)setattr_node_inputs.emplace_back(NewValueNode(attr_str));
3951 (void)setattr_node_inputs.emplace_back(value_node);
3952 auto fg = block->func_graph();
3953 MS_EXCEPTION_IF_NULL(fg);
3954 auto setattr_node = fg->NewCNodeInOrder(setattr_node_inputs);
3955
3956 // Update setattr_nodes_map.
3957 auto iter = setattr_nodes_map_.find(target_id_str);
3958 if (iter == setattr_nodes_map_.end()) {
3959 auto attr_map = std::map<std::string, AnfNodePtr>();
3960 (void)attr_map.emplace(std::make_pair(attr_str, setattr_node));
3961 (void)setattr_nodes_map_.emplace(std::make_pair(target_id_str, attr_map));
3962 } else {
3963 // If found setattr node before, set it as new setattr node's input.
3964 auto iter_attr = iter->second.find(attr_str);
3965 if (iter_attr != iter->second.end()) {
3966 auto prev_node = iter_attr->second;
3967 MS_EXCEPTION_IF_NULL(prev_node);
3968 auto prev_node_fg = prev_node->func_graph();
3969 MS_EXCEPTION_IF_NULL(prev_node_fg);
3970 if (prev_node_fg == fg) {
3971 // Only add to new setattr node input if two nodes is in the same graph.
3972 setattr_node->add_input(iter_attr->second);
3973 }
3974 }
3975 // Force update the setattr node to keep the newest one.
3976 (void)iter->second.insert_or_assign(attr_str, setattr_node);
3977 }
3978 block->AddIsolatedNode(setattr_node);
3979 }
3980
HandleAssignClassMember(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & value_node)3981 void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target,
3982 const AnfNodePtr &value_node) {
3983 MS_EXCEPTION_IF_NULL(block);
3984 const py::object target_obj = python_adapter::GetPyObjAttr(target, "value");
3985 TraceGuard trace_guard(GetLocation(target_obj));
3986 std::string target_id_str;
3987 AnfNodePtr target_node = nullptr;
3988 auto node_type = ast()->GetNodeType(target_obj);
3989 const std::string &node_type_name = node_type->node_name();
3990 MS_LOG(DEBUG) << "node_type_name: " << node_type_name << ", target: " << py::str(target);
3991 if (node_type_name == "Attribute") {
3992 // Prepare for setattr with nested getattr target, parse getattr firstly.
3993 target_node = ParseExprNode(block, target_obj);
3994 target_id_str = GetLocation(target_obj)->expr_src();
3995 } else if (node_type_name == "Call") {
3996 // Prepare for setattr with nested 'getattr' call target, parse 'getattr' call firstly.
3997 target_node = ParseExprNode(block, target_obj);
3998 target_id_str = GetLocation(target_obj)->expr_src();
3999 } else if (node_type_name == "Name") {
4000 if (!py::hasattr(target_obj, "id")) {
4001 MS_LOG(INTERNAL_EXCEPTION) << "Wrong ast, target: " << target;
4002 }
4003 const py::object id_obj = python_adapter::GetPyObjAttr(target_obj, "id");
4004 target_id_str = id_obj.cast<std::string>();
4005 if (ast()->target_type() == PARSE_TARGET_OBJECT_INSTANCE && target_id_str == "self") {
4006 const auto &interpreted_obj = std::make_shared<InterpretedObject>(ast()->obj());
4007 target_node = NewValueNode(interpreted_obj);
4008 } else {
4009 py::object setattr_obj;
4010 try {
4011 py::tuple namespace_info = ast_->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, target_id_str);
4012 constexpr size_t value_index = 2;
4013 setattr_obj = namespace_info[value_index];
4014 } catch (const std::exception &e) {
4015 MS_LOG(DEBUG) << target_id_str << " is not supported in JIT Fallback. Original steps are processing instead.";
4016 setattr_obj = py::none();
4017 }
4018 // convert target node in setattr to convert getattr after it later.
4019 if (!py::isinstance<py::none>(setattr_obj)) {
4020 const auto &interpreted_obj = std::make_shared<InterpretedObject>(setattr_obj);
4021 target_node = NewValueNode(interpreted_obj);
4022 } else {
4023 target_node = ParseExprNode(block, target_obj);
4024 }
4025 }
4026 }
4027 if (target_node == nullptr) {
4028 MS_LOG(EXCEPTION) << "In graph mode, only attribute and name of class members can be assigned. But got "
4029 << node_type_name << ".";
4030 }
4031 const auto &attr_str = python_adapter::GetPyObjAttr(target, "attr").cast<std::string>();
4032 MS_LOG(DEBUG) << "target node: " << target_node->DebugString() << ", target name: " << target_id_str
4033 << ", attr: " << attr_str;
4034
4035 MakeSetAttrNode(block, target_node, value_node, target_id_str, attr_str);
4036 }
4037
MakeSetitemNode(const FunctionBlockPtr & block,const py::object & value_obj,const py::object & slice_obj,const AnfNodePtr & assigned_node,const AnfNodePtr & value_node)4038 CNodePtr Parser::MakeSetitemNode(const FunctionBlockPtr &block, const py::object &value_obj,
4039 const py::object &slice_obj, const AnfNodePtr &assigned_node,
4040 const AnfNodePtr &value_node) {
4041 AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
4042 auto value_id = python_adapter::GetPyObjAttr(value_obj, "id");
4043 AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
4044 auto str_setitem = std::make_shared<StringImm>("__setitem__");
4045 if (!py::isinstance<py::none>(value_id)) {
4046 py::object value_obj = GetValuePythonObject(value_id);
4047 if (!py::isinstance<py::none>(value_obj)) {
4048 AnfNodePtr setitem_node =
4049 block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), value_node, NewValueNode(str_setitem)});
4050 setitem_node->set_user_data<py::object>("__setitem__", std::make_shared<py::object>(value_obj));
4051 return block->func_graph()->NewCNodeInOrder({setitem_node, slice_node, assigned_node});
4052 }
4053 }
4054 return block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
4055 }
4056
HandleAssignSubscript(const FunctionBlockPtr & block,const py::object & target,const AnfNodePtr & assigned_node)4057 void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target,
4058 const AnfNodePtr &assigned_node) {
4059 MS_EXCEPTION_IF_NULL(block);
4060 py::object value_obj = python_adapter::GetPyObjAttr(target, "value");
4061 py::object slice_obj = python_adapter::GetPyObjAttr(target, "slice");
4062 AnfNodePtr value_node = ParseExprNode(block, value_obj);
4063 MS_EXCEPTION_IF_NULL(block->func_graph());
4064 auto setitem_app = MakeSetitemNode(block, value_obj, slice_obj, assigned_node, value_node);
4065 // Getitem apply should return the sequence data structure itself
4066 std::string var_name;
4067 if (ast_->IsClassMemberOfSelf(value_obj)) {
4068 auto attr_name = value_obj.attr("attr").cast<std::string>();
4069 var_name = "self." + attr_name;
4070 if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
4071 MS_EXCEPTION(TypeError) << "'" << var_name
4072 << "' should be initialized in the '__init__' function before subscript.\n\n"
4073 << trace::GetDebugInfoStr(value_node->debug_info());
4074 }
4075 auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
4076 if (py::hasattr(obj, PYTHON_CELL_AS_LIST) || py::hasattr(obj, PYTHON_CELL_AS_DICT)) {
4077 MS_EXCEPTION(TypeError) << "CellList or CellDict object " << py::str(obj).cast<std::string>()
4078 << " is not support to do assign subscript.";
4079 }
4080 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
4081 if (!allow_fallback_runtime) {
4082 if (!py::hasattr(obj, "__parameter__")) {
4083 auto obj_type = obj.attr("__class__").attr("__name__");
4084 MS_EXCEPTION(TypeError) << "When JIT_SYNTAX_LEVEL is not set to LAX" << var_name
4085 << " should be initialized as a 'Parameter' in the '__init__' function"
4086 << " to perform assign subscript, but got: " << py::str(obj).cast<std::string>()
4087 << "' with type '" << py::str(obj_type).cast<std::string>() << "'.\n\n"
4088 << trace::GetDebugInfoStr(value_node->debug_info());
4089 }
4090 }
4091 block->WriteVariable(var_name, setitem_app);
4092 return;
4093 }
4094 if (AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, value_obj))) ==
4095 AST_SUB_TYPE_SUBSCRIPT) {
4096 if (IsSubscriptReferenceType(value_obj)) {
4097 HandleAssignSubscript(block, value_obj, setitem_app);
4098 return;
4099 }
4100 }
4101 if (py::hasattr(value_obj, "id")) {
4102 var_name = value_obj.attr("id").cast<std::string>();
4103 }
4104 block->WriteVariable(var_name, setitem_app);
4105 }
4106
WriteAssignVars(const FunctionBlockPtr & block,const py::object & target_object,const AnfNodePtr & value_node)4107 void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object,
4108 const AnfNodePtr &value_node) {
4109 MS_EXCEPTION_IF_NULL(value_node);
4110 auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
4111 MS_LOG(DEBUG) << "target_object: " << target_object << ", value_node: " << value_node->DebugString()
4112 << ", ast_type: " << ast_type;
4113 if (ast_type == AST_SUB_TYPE_NAME) {
4114 HandleAssignName(block, target_object, value_node);
4115 } else if (ast_type == AST_SUB_TYPE_TUPLE || ast_type == AST_SUB_TYPE_LIST) {
4116 HandleAssignTupleOrList(block, target_object, value_node);
4117 } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
4118 HandleAssignSubscript(block, target_object, value_node);
4119 } else if (ast_->IsClassMemberOfSelf(target_object)) {
4120 if (HandleAssignClassParameterMember(block, target_object, value_node)) {
4121 return;
4122 }
4123 HandleAssignClassMember(block, target_object, value_node);
4124 } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
4125 HandleAssignClassMember(block, target_object, value_node);
4126 } else if (ast_type == AST_SUB_TYPE_STARRED) {
4127 HandleAssignStarred(block, target_object, value_node);
4128 } else {
4129 TraceGuard trace_guard(GetLocation(target_object));
4130 MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got "
4131 << target_object.get_type()
4132 << ".\nMore details please refer to syntax support at https://www.mindspore.cn";
4133 }
4134 }
4135
IsScriptInParams(const std::string & script_text,const py::dict & global_dict,const std::map<std::string,AnfNodePtr> & local_keys,const FuncGraphPtr & func_graph) const4136 bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
4137 const std::map<std::string, AnfNodePtr> &local_keys,
4138 const FuncGraphPtr &func_graph) const {
4139 MS_EXCEPTION_IF_NULL(func_graph);
4140 // Check global parameters.
4141 if (global_dict.contains(script_text)) {
4142 MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in global params.";
4143 return true;
4144 }
4145
4146 // Check local parameters.
4147 if (local_keys.find(script_text) != local_keys.end()) {
4148 MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in local params.";
4149 return true;
4150 }
4151 return false;
4152 }
4153
// Wrap `value_node` into an interpret node when it is flagged as `interpret()`; otherwise return it
// untouched. The script passed on is the source text of `value_object` as reported by the AST helper.
AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
                                   const py::object &value_object) {
  MS_EXCEPTION_IF_NULL(value_node);
  if (value_node->interpret()) {
    const auto source_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
    return MakeInterpretNode(block, value_node, source_text);
  }
  return value_node;
}
4163
CheckNeedConvertInterpret(const FunctionBlockPtr & block,const AnfNodePtr & node,const string & script_text) const4164 bool Parser::CheckNeedConvertInterpret(const FunctionBlockPtr &block, const AnfNodePtr &node,
4165 const string &script_text) const {
4166 MS_EXCEPTION_IF_NULL(block);
4167 MS_EXCEPTION_IF_NULL(node);
4168 // Check if script_text is in global/local params.
4169 const py::dict &global_dict = block->global_py_params();
4170 auto keys = std::get<0>(block->local_py_params());
4171 if (IsScriptInParams(script_text, global_dict, keys, block->func_graph())) {
4172 return false;
4173 }
4174 return true;
4175 }
4176
// Count the occurrences of `sub` in `script_text`. The search restarts one character after each
// match, so overlapping occurrences are all counted.
size_t GetSubStrNum(const std::string &script_text, const std::string &sub) {
  size_t occurrences = 0;
  for (size_t at = script_text.find(sub); at != std::string::npos; at = script_text.find(sub, at + 1)) {
    ++occurrences;
  }
  return occurrences;
}
4186
// Rewrite a multi-line script so that each line is joined to the next one explicitly.
// For every line terminated by '\n': "+\" is appended unless the line already ends with a backslash,
// and leading spaces are removed. The text after the last '\n' only has its leading spaces removed.
//
// Edge cases handled explicitly:
//  - an empty line is treated like a line made of spaces only (it becomes "+\"); its last character
//    is not inspected, since it has none;
//  - an empty or all-space tail (empty input, trailing '\n', trailing blank line) contributes nothing,
//    instead of calling substr() with npos.
std::string UpdateString(const std::string &str) {
  std::string result;
  std::string line;
  for (const char ch : str) {
    if (ch != '\n') {
      line += ch;
      continue;
    }
    if (line.empty() || line.back() != '\\') {
      line += "+\\";
    }
    // The line now ends with a backslash, so a non-space character always exists here.
    const auto first_non_space = line.find_first_not_of(" ");
    result += line.substr(first_non_space);
    result += '\n';
    line.clear();
  }
  const auto first_non_space = line.find_first_not_of(" ");
  if (first_non_space != std::string::npos) {
    result += line.substr(first_non_space);
  }
  return result;
}
4207
ProcessIndentationInScript(const string & script_text)4208 string ProcessIndentationInScript(const string &script_text) {
4209 const size_t f_string_num = 2;
4210 size_t num1 = GetSubStrNum(script_text, "f'");
4211 size_t num2 = GetSubStrNum(script_text, "f\"");
4212 if (script_text.find("\n") == string::npos || num1 + num2 < f_string_num) {
4213 return script_text;
4214 }
4215 return UpdateString(script_text);
4216 }
4217
// Build the node that evaluates `script_text` through the Python interpreter fallback.
// Returns `value_node` itself when CheckNeedConvertInterpret says no conversion is needed. Otherwise
// the script is passed through ProcessIndentationInScript, the block's global params plus the local
// params referenced by the script are collected, and the node created by `block->MakeInterpret` is
// returned. `value_node` is also pushed to the InterpretNodeRecorder.
AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
                                     const string &script_text) {
  MS_EXCEPTION_IF_NULL(block);
  MS_EXCEPTION_IF_NULL(value_node);
  const bool need_convert = CheckNeedConvertInterpret(block, value_node, script_text);
  if (!need_convert) {
    return value_node;
  }
  string script = ProcessIndentationInScript(script_text);

  // Global parameters: the whole global dict is handed over.
  PyObjectWrapperPtr global_dict_obj = std::make_shared<InterpretedObject>(block->global_py_params());
  auto global_dict_node = NewValueNode(global_dict_obj);

  // Local parameters: keep only those whose identifier appears in the script.
  auto [local_keys, local_values] = block->local_py_params();
  std::vector<AnfNodePtr> used_keys;
  std::vector<AnfNodePtr> used_values;
  try {
    const py::set &ids = data_converter::GetPythonScriptIdAttrs(py::str(script));
    for (const auto &id : ids) {
      const auto &id_str = py::str(id);
      const auto &found = local_values.find(id_str);
      if (found == local_values.end()) {
        continue;
      }
      (void)used_keys.emplace_back(local_keys[found->first]);
      auto &local_value = found->second;
      // '__py_interpret_local_value_flag__' is used by 'ConvertInterpretedObjForResolve' not to convert PyExecute.
      local_value->set_user_data("__py_interpret_local_value_flag__", std::make_shared<bool>(true));
      (void)used_values.emplace_back(local_value);
    }
  } catch (const std::exception &e) {
    MS_LOG(INTERNAL_EXCEPTION) << "GetPythonScriptIdAttrs failed, script: " << py::str(script) << ".\n"
                               << e.what();
  }

  // When the script text contains "self" and no local key has that name, bind "self" to the member
  // namespace of the object being parsed.
  constexpr auto self_text = "self";
  const bool self_is_local = (local_keys.find(self_text) != local_keys.end());
  if (!self_is_local && script.find(self_text) != std::string::npos) {
    py::object self_namespace = ast()->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast()->obj());
    auto self_value = std::make_shared<InterpretedObject>(self_namespace);
    (void)used_keys.emplace_back(NewValueNode(MakeValue(self_text)));
    (void)used_values.emplace_back(NewValueNode(self_value));
  }

  auto local_dict_node = ParseDictByKeysAndValues(block, used_keys, used_values);
  // Create the interpret node that stands in for the value node.
  constexpr int recursive_level = 2;
  MS_EXCEPTION_IF_NULL(block->func_graph());
  AnfNodePtr interpreted_node = block->MakeInterpret(script, global_dict_node, local_dict_node, value_node);
  MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: `" << script
               << "`,\nvalue_node: " << value_node->DebugString(recursive_level)
               << ",\nglobal_dict_node: " << global_dict_node->ToString()
               << ",\nlocal_dict_node: " << local_dict_node->DebugString(recursive_level)
               << ",\ninterpreted_node: " << interpreted_node->DebugString(recursive_level);

  // Give the user a hint about the code that falls back to the Python interpreter.
  auto line_info = trace::GetDebugInfoStr(value_node->debug_info());
  MS_LOG(INFO) << "Found unsupported syntax in graph mode, those codes would be fallen back to Python interpreter:"
               << "\n\n"
               << line_info;
  InterpretNodeRecorder::GetInstance().PushPyInterpretNode(value_node);
  return interpreted_node;
}
4278
IsPopOperation(const AnfNodePtr & node) const4279 bool Parser::IsPopOperation(const AnfNodePtr &node) const {
4280 auto cnode = node->cast<CNodePtr>();
4281 if (cnode == nullptr) {
4282 return false;
4283 }
4284 auto attr_node = cnode->input(0);
4285 if (IsPrimitiveCNode(attr_node, prim::kPrimGetAttr)) {
4286 auto attr_cnode = attr_node->cast<CNodePtr>();
4287 MS_EXCEPTION_IF_NULL(attr_cnode);
4288 constexpr size_t attr_cnode_size = 3;
4289 constexpr size_t member_index = 2;
4290 if (attr_cnode->size() < attr_cnode_size) {
4291 MS_LOG(EXCEPTION) << "The attr operate has wrong input.";
4292 }
4293 auto member_node = attr_cnode->input(member_index);
4294 if (IsValueNode<StringImm>(member_node)) {
4295 auto attr_name = GetValue<std::string>(GetValueNode(member_node));
4296 if (attr_name == "pop") {
4297 return true;
4298 }
4299 }
4300 }
4301 return false;
4302 }
4303
ProcessPopOperation(const FunctionBlockPtr & block,const AnfNodePtr & value_node,const py::object & target_object)4304 void Parser::ProcessPopOperation(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
4305 const py::object &target_object) {
4306 auto func_graph = block->func_graph();
4307 MS_EXCEPTION_IF_NULL(func_graph);
4308 auto new_list =
4309 func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), value_node, NewValueNode(SizeToLong(0))});
4310 auto pop_node =
4311 func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), value_node, NewValueNode(SizeToLong(1))});
4312 if (!(ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast_->IsClassMemberOfSelf(list_pop_target_obj_))) {
4313 WriteAssignVars(block, list_pop_target_obj_, new_list);
4314 }
4315 WriteAssignVars(block, target_object, pop_node);
4316 }
4317
ProcessPopOperationInAugAssign(const FunctionBlockPtr & block,const AnfNodePtr & value_node,const AnfNodePtr & target_node,const AnfNodePtr & op_node,const py::object & target_object)4318 void Parser::ProcessPopOperationInAugAssign(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
4319 const AnfNodePtr &target_node, const AnfNodePtr &op_node,
4320 const py::object &target_object) {
4321 auto func_graph = block->func_graph();
4322 MS_EXCEPTION_IF_NULL(func_graph);
4323 auto new_list =
4324 func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), value_node, NewValueNode(SizeToLong(0))});
4325 auto pop_node =
4326 func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), value_node, NewValueNode(SizeToLong(1))});
4327 if (!(ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE && ast_->IsClassMemberOfSelf(list_pop_target_obj_))) {
4328 WriteAssignVars(block, list_pop_target_obj_, new_list);
4329 }
4330 AnfNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, pop_node});
4331 WriteAssignVars(block, target_object, augassign_app);
4332 }
4333
4334 // Process a assign statement, such as a = b, a, b = tuple(xx, xx)
ParseAssign(const FunctionBlockPtr & block,const py::object & node)4335 FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
4336 MS_LOG(DEBUG) << "Process ast assign";
4337 py::object value_object = python_adapter::GetPyObjAttr(node, "value");
4338 py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
4339
4340 AnfNodePtr value_node = ParseExprNode(block, value_object);
4341
4342 py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
4343 size_t count = LongToSize(pcount);
4344 MS_LOG(DEBUG) << "The nodes count is " << count;
4345
4346 // b = list_x.pop(a)
4347 // --> list_x, b = list_x.pop(a) need renew the list_x.
4348 if (IsPopOperation(value_node)) {
4349 auto pop_obj = py::cast<py::list>(targets_object)[0];
4350 ProcessPopOperation(block, value_node, pop_obj);
4351 return block;
4352 }
4353 for (size_t i = 0; i < count; i++) {
4354 auto target_node = py::cast<py::list>(targets_object)[i];
4355 WriteAssignVars(block, target_node, value_node);
4356 }
4357
4358 return block;
4359 }
4360
4361 // Process a annassign statement, such as a:int = 1
4362 // target may be one of Name, Attribute, Subscript.
ParseAnnAssign(const FunctionBlockPtr & block,const py::object & node)4363 FunctionBlockPtr Parser::ParseAnnAssign(const FunctionBlockPtr &block, const py::object &node) {
4364 MS_LOG(DEBUG) << "Process ast annassign";
4365 py::object value_object = python_adapter::GetPyObjAttr(node, "value");
4366 py::object target_object = python_adapter::GetPyObjAttr(node, "target");
4367 AnfNodePtr value_node = ParseExprNode(block, value_object);
4368 // b: int = list_x.pop(a)
4369 // --> list_x, b = list_x.pop(a) need renew the list_x.
4370 if (IsPopOperation(value_node)) {
4371 ProcessPopOperation(block, value_node, target_object);
4372 return block;
4373 }
4374 WriteAssignVars(block, target_object, value_node);
4375 return block;
4376 }
4377
// Process a 'break' statement: jump from `block` to the end block of the innermost loop, creating
// that end block on first use. Raises when no loop context is active.
FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
  if (loops_.empty()) {
    // A 'break' without an enclosing loop context is an internal error.
    MS_LOG(INTERNAL_EXCEPTION) << "Unexpected 'break'.";
  }
  // The innermost loop is on top of the stack.
  Loop &current_loop = loops_.top();
  const bool end_block_missing = (current_loop.end == nullptr);
  if (end_block_missing) {
    // Lazily create the end block, traced as the loop end of the current graph.
    MS_EXCEPTION_IF_NULL(block->func_graph());
    TraceGuard trace_guard(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
    current_loop.end = MakeFunctionBlock();
  }
  block->set_break_continue_statement_inside();
  MS_LOG(DEBUG) << "Inside the block has break statement, block: " << block->ToString();

  // Leave the loop through its end block.
  block->Jump(current_loop.end, {});
  return block;
}
4398
// Process a 'continue' statement: jump from `block` back to the header of the innermost loop,
// passing the loop iterator along when the loop has one. Raises when no loop context is active.
FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
  if (loops_.empty()) {
    // A 'continue' without an enclosing loop context is an internal error.
    MS_LOG(INTERNAL_EXCEPTION) << "Unexpected 'continue'.";
  }
  Loop &current_loop = loops_.top();
  // The header receives the iterator as its only argument, if the loop has an iterator.
  std::vector<AnfNodePtr> header_args;
  const bool has_iterator = (current_loop.iterator != nullptr);
  if (has_iterator) {
    (void)header_args.emplace_back(current_loop.iterator);
  }
  block->set_break_continue_statement_inside();
  MS_LOG(DEBUG) << "Inside the block has continue statement, block: " << block->ToString();

  block->Jump(current_loop.header, header_args);
  return block;
}
4416
// Process a 'pass' statement: nothing is emitted and parsing goes on in the same block.
FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
  // No node is generated for 'pass'.
  return block;
}
4421
// Process a 'raise' statement and install the resulting node as the return of the block's graph.
// A bare `raise` installs the Raise node directly; `raise Exc` / `raise Exc(msg)` builds a Raise node
// from the inputs produced by ParseRaiseCall and wraps it in a Return node.
FunctionBlockPtr Parser::ParseRaise(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process raise statement";
  TraceGuard trace_guard(GetLocation(node));
  MS_EXCEPTION_IF_NULL(block);
  auto fg = block->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  py::object exc_ast = python_adapter::GetPyObjAttr(node, "exc");

  // raise
  const bool bare_raise = py::isinstance<py::none>(exc_ast);
  if (bare_raise) {
    CNodePtr raise_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimRaise)});
    fg->set_return(raise_node);
    return block;
  }

  // raise ExceptionType or raise ExceptionType(ExceptionString)
  auto exc_inputs = ParseRaiseCall(block, exc_ast);
  std::vector<AnfNodePtr> raise_inputs{NewValueNode(prim::kPrimRaise)};
  (void)raise_inputs.insert(raise_inputs.end(), exc_inputs.begin(), exc_inputs.end());
  CNodePtr raise_node = fg->NewCNodeInOrder(raise_inputs);
  CNodePtr return_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), raise_node});
  fg->set_return(return_node);
  return block;
}
4444
// Turn `block` into the failing branch of an assert: its graph returns
// `raise("AssertionError", [msg,] "None")`, where `msg` is the parsed assert message if one is given.
FunctionBlockPtr Parser::MakeAssertErrorBlock(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process make AssertError block";
  MS_EXCEPTION_IF_NULL(block);
  const std::string kAssertionError = "AssertionError";
  std::vector<AnfNodePtr> raise_inputs{NewValueNode(prim::kPrimRaise), NewValueNode(kAssertionError)};

  // The optional assert message goes right after the exception type.
  py::object msg_ast = python_adapter::GetPyObjAttr(node, "msg");
  const bool has_msg = !py::isinstance<py::none>(msg_ast);
  if (has_msg) {
    AnfNodePtr msg_node = ParseExprNode(block, msg_ast);
    (void)raise_inputs.emplace_back(msg_node);
  }
  // The input list is always closed by the string "None".
  auto str_none = std::make_shared<StringImm>("None");
  (void)raise_inputs.emplace_back(NewValueNode(str_none));

  auto fg = block->func_graph();
  CNodePtr raise_node = fg->NewCNodeInOrder(raise_inputs);
  CNodePtr return_node = fg->NewCNodeInOrder({NewValueNode(prim::kPrimReturn), raise_node});
  fg->set_return(return_node);
  return block;
}
4465
4466 // assert expression [, arguments]
4467 // =>
4468 // if not expression:
4469 // raise AssertionError(arguments)
ParseAssert(const FunctionBlockPtr & block,const py::object & node)4470 FunctionBlockPtr Parser::ParseAssert(const FunctionBlockPtr &block, const py::object &node) {
4471 MS_LOG(DEBUG) << "Process ast Assert";
4472 MS_EXCEPTION_IF_NULL(block);
4473 py::object test_node = python_adapter::GetPyObjAttr(node, "test");
4474 AnfNodePtr condition_node = ParseExprNode(block, test_node);
4475
4476 AnfNodePtr bool_node = block->ForceToCondNode(condition_node);
4477 TraceGuard guard(std::make_shared<TraceAssert>(block->func_graph()->debug_info()));
4478 TraceGuard location_guard(GetLocation(node));
4479 FunctionBlockPtr true_block = MakeFunctionBlock();
4480 FunctionBlockPtr false_block = MakeFunctionBlock();
4481 FunctionBlockPtr after_block = MakeFunctionBlock();
4482 MakeConditionBlocks(block, true_block, false_block);
4483
4484 true_block->Jump(after_block, {});
4485 false_block = MakeAssertErrorBlock(false_block, node);
4486 (void)block->ConditionalJump(bool_node, true_block, false_block);
4487
4488 after_block->Mature();
4489 return after_block;
4490 }
4491
ParseWithitem(const FunctionBlockPtr & block,const py::object & node,const AnfNodePtr & context_expr_node)4492 AnfNodePtr Parser::ParseWithitem(const FunctionBlockPtr &block, const py::object &node,
4493 const AnfNodePtr &context_expr_node) {
4494 MS_LOG(DEBUG) << "Process ast Withitem";
4495 MS_EXCEPTION_IF_NULL(block);
4496 // Handle __enter__(self)
4497 std::vector<AnfNodePtr> enter_inputs{NewValueNode(prim::kPrimWithEnter), context_expr_node};
4498 auto func_graph = block->func_graph();
4499 MS_EXCEPTION_IF_NULL(func_graph);
4500 AnfNodePtr enter_node = func_graph->NewCNodeInOrder(enter_inputs);
4501 py::object optional_vars_obj = python_adapter::GetPyObjAttr(node, "optional_vars");
4502 if (!py::isinstance<py::none>(optional_vars_obj)) {
4503 // with Sample() as sample: mean that sample = Sample()
4504 WriteAssignVars(block, optional_vars_obj, enter_node);
4505 }
4506 return enter_node;
4507 }
4508
4509 // with expression [as variable]:
4510 // with-block
// Parse an ast.With statement.
// Every with-item is processed in source order: its context expression is parsed and the item is
// entered through ParseWithitem. After the body statements are parsed, one WithExit node per item is
// created in reverse order of entry and attached to 'block' as an isolated node.
// Returns a new, matured block that follows the 'with' statement; the body block jumps to it unless
// the body already has a return.
FunctionBlockPtr Parser::ParseWith(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast With";
  // Validate the block before it is handed to the parse helpers below.
  MS_EXCEPTION_IF_NULL(block);
  py::list items_objs = python_adapter::GetPyObjAttr(node, "items");
  if (items_objs.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Unexpected 'with'.";
  }
  // Each entry pairs a context expression node with its entered node; the stack yields them in
  // reverse order of entry when the exits are built.
  std::stack<std::pair<AnfNodePtr, AnfNodePtr>> entered_contexts;
  for (size_t i = 0; i < items_objs.size(); ++i) {
    auto items_obj = items_objs[i];
    // with Sample() as sample:
    // mean context_expr is Sample(), sample is optional_vars
    py::object context_expr_obj = python_adapter::GetPyObjAttr(items_obj, "context_expr");
    AnfNodePtr context_expr_node = ParseExprNode(block, context_expr_obj);
    auto enter_node = ParseWithitem(block, items_obj, context_expr_node);
    entered_contexts.emplace(context_expr_node, enter_node);
  }
  auto func_graph = block->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr body_block = ParseStatements(block, body_node);
  MS_EXCEPTION_IF_NULL(body_block);
  auto body_func = body_block->func_graph();
  MS_EXCEPTION_IF_NULL(body_func);

  while (!entered_contexts.empty()) {
    const auto [context_expr_node, entered_node] = entered_contexts.top();
    entered_contexts.pop();
    // Use the depend node to ensure the execution order of enter and exit node.
    std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), context_expr_node, entered_node};
    AnfNodePtr ordered_context_node = func_graph->NewCNodeInOrder(depend_inputs);
    // Handle __exit__(self, type, value, trace)
    std::vector<AnfNodePtr> exit_inputs{NewValueNode(prim::kPrimWithExit), ordered_context_node};
    AnfNodePtr exit_node = func_graph->NewCNodeInOrder(exit_inputs);
    block->AddIsolatedNode(exit_node);
  }
  FunctionBlockPtr after_block = MakeFunctionBlock();
  // Only link the body to the following block when the body did not already return.
  if (body_func->get_return() == nullptr) {
    body_block->Jump(after_block, {});
  }
  after_block->Mature();
  return after_block;
}
4557
PrintPhiArgMaps(const std::map<ParameterPtr,std::set<AnfNodePtr>> & phi_to_args,const std::map<AnfNodePtr,std::set<ParameterPtr>> & arg_to_phis)4558 void Parser::PrintPhiArgMaps(const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_to_args,
4559 const std::map<AnfNodePtr, std::set<ParameterPtr>> &arg_to_phis) {
4560 if (!IS_OUTPUT_ON(mindspore::kDebug)) {
4561 return;
4562 }
4563 std::ostringstream oss;
4564 oss << "==============================Start=============================="
4565 << "\n";
4566 size_t m = 0;
4567 for (const auto &[phi, args] : phi_to_args) {
4568 MS_EXCEPTION_IF_NULL(phi);
4569 oss << "phi[" << m++ << "]: " << phi->DebugString() << "\n";
4570 size_t n = 0;
4571 for (auto &arg : args) {
4572 MS_EXCEPTION_IF_NULL(arg);
4573 oss << " args[" << n++ << "]: " << arg->DebugString() << "\n";
4574 }
4575 }
4576
4577 m = 0;
4578 for (const auto &[arg, phis] : arg_to_phis) {
4579 MS_EXCEPTION_IF_NULL(arg);
4580 oss << "arg[" << m++ << "]: " << arg->DebugString() << "\n";
4581 size_t n = 0;
4582 for (auto &phi : phis) {
4583 MS_EXCEPTION_IF_NULL(phi);
4584 oss << " phis[" << n++ << "]: " << phi->DebugString() << "\n";
4585 }
4586 }
4587 oss << "===============================End==============================="
4588 << "\n";
4589 MS_LOG(DEBUG) << "\n" << oss.str();
4590 }
4591
4592 namespace {
UpdatePhiArgMaps(std::map<ParameterPtr,std::set<AnfNodePtr>> * phi_to_args,std::map<AnfNodePtr,std::set<ParameterPtr>> * arg_to_phis)4593 bool UpdatePhiArgMaps(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args,
4594 std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis) {
4595 MS_EXCEPTION_IF_NULL(phi_to_args);
4596 MS_EXCEPTION_IF_NULL(arg_to_phis);
4597 bool phi_arg_updated = false;
4598 auto copy_phi_to_args = *phi_to_args;
4599 for (const auto &[phi, args] : copy_phi_to_args) {
4600 // The phi node has only one arg can be replaced as arg.
4601 if (args.size() != 1) {
4602 continue;
4603 }
4604 auto phi_arg = *args.begin();
4605 MS_EXCEPTION_IF_NULL(phi_arg);
4606 MS_LOG(DEBUG) << "phi: " << phi->DebugString() << ", get one arg: " << phi_arg->DebugString();
4607 // If this phi is a arg of other phi.
4608 auto arg_to_phi_it = arg_to_phis->find(phi);
4609 if (arg_to_phi_it == arg_to_phis->end()) {
4610 continue;
4611 }
4612 // Use the new phi arg as the arg of other phi's arg. Usually other phi is a deeper subgraph's phi node.
4613 auto other_phis = arg_to_phi_it->second;
4614 MS_LOG(DEBUG) << "Find phi as arg of other phi, other phis num: " << other_phis.size();
4615 // Update all other phis' arg from phi to phi_arg.
4616 for (auto &other_phi : other_phis) {
4617 MS_EXCEPTION_IF_NULL(other_phi);
4618 MS_LOG(DEBUG) << "other phi: " << other_phi->DebugString();
4619 phi_arg_updated = true;
4620 // The phi will not be arg of any other phis.Erase map1.
4621 (void)(*phi_to_args)[other_phi].erase(phi);
4622 // If arg is same to the parameter phi, ignore the arg, keep maps don't have self arg.
4623 if (phi_arg == other_phi) {
4624 MS_LOG(DEBUG) << "Get phi arg of phi self.";
4625 continue;
4626 }
4627 MS_LOG(DEBUG) << "phi arg: " << phi_arg->DebugString()
4628 << " as new arg of other phi: " << other_phi->DebugString();
4629 // Replace other phi's arg as this phi's arg, instead of phi. (other_phi , phi) -> (other_phi, phi_arg)
4630 (void)(*phi_to_args)[other_phi].insert(phi_arg);
4631 // Add other phi to the phi_arg's phis set. (phi_arg, {phi_x, }) -> (phi_arg, {phi_x, other_phi})
4632 (void)(*arg_to_phis)[phi_arg].insert(other_phi);
4633 }
4634 MS_LOG(DEBUG) << "Remove phi type arg: " << phi;
4635 // The phi will not be arg of any other phis.Erase map2.
4636 (void)(*arg_to_phis).erase(phi);
4637 }
4638 return phi_arg_updated;
4639 }
4640 } // namespace
4641
UpdatePhiArgMapsRepeatedly(std::map<ParameterPtr,std::set<AnfNodePtr>> * phi_to_args,std::map<AnfNodePtr,std::set<ParameterPtr>> * arg_to_phis)4642 void Parser::UpdatePhiArgMapsRepeatedly(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args,
4643 std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis) {
4644 bool phi_arg_updated = true;
4645 size_t loop_count = 0;
4646 while (phi_arg_updated) {
4647 MS_LOG(DEBUG) << "update loop count: " << loop_count++;
4648 PrintPhiArgMaps(*phi_to_args, *arg_to_phis);
4649 phi_arg_updated = UpdatePhiArgMaps(phi_to_args, arg_to_phis);
4650 }
4651 }
4652
CreatePhiArgMaps(std::map<ParameterPtr,std::set<AnfNodePtr>> * phi_to_args,std::map<AnfNodePtr,std::set<ParameterPtr>> * arg_to_phis)4653 void Parser::CreatePhiArgMaps(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args,
4654 std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis) {
4655 MS_EXCEPTION_IF_NULL(phi_to_args);
4656 MS_EXCEPTION_IF_NULL(arg_to_phis);
4657 for (FunctionBlockPtr &block : func_block_list_) {
4658 MS_EXCEPTION_IF_NULL(block);
4659 for (const auto &[phi, args] : block->phi_args()) {
4660 // Filtered args exclude the arg pointer equals to phi pointer.
4661 for (const auto &arg : args) {
4662 if (phi == arg) {
4663 continue;
4664 }
4665 (void)(*phi_to_args)[phi].insert(arg);
4666 (void)(*arg_to_phis)[arg].insert(phi);
4667 }
4668 }
4669 }
4670 }
4671
CollectRemovablePhiArgs(const std::map<ParameterPtr,std::set<AnfNodePtr>> & phi_to_args)4672 std::shared_ptr<std::map<ParameterPtr, AnfNodePtr>> Parser::CollectRemovablePhiArgs(
4673 const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_to_args) {
4674 auto need_remove_phi_args = std::make_shared<std::map<ParameterPtr, AnfNodePtr>>();
4675 for (const auto &[phi, args] : phi_to_args) {
4676 if (args.empty()) {
4677 // phi's arg is phi self.
4678 (*need_remove_phi_args)[phi] = nullptr;
4679 continue;
4680 }
4681 if (args.size() == 1) {
4682 (*need_remove_phi_args)[phi] = *(args.begin());
4683 }
4684 }
4685 if (IS_OUTPUT_ON(mindspore::kDebug)) {
4686 size_t m = 0;
4687 std::ostringstream oss;
4688 oss << "=====================Need removed phis and args====================="
4689 << "\n";
4690 for (const auto &[phi, arg] : *need_remove_phi_args) {
4691 MS_EXCEPTION_IF_NULL(phi);
4692 oss << "phi[" << m << "]: " << phi->DebugString() << "\n";
4693 oss << "arg[" << m++ << "]: " << arg->DebugString() << "\n";
4694 }
4695 MS_LOG(DEBUG) << "\n" << oss.str();
4696 }
4697 return need_remove_phi_args;
4698 }
4699
CalRemovablePhis()4700 std::shared_ptr<std::map<ParameterPtr, AnfNodePtr>> Parser::CalRemovablePhis() {
4701 std::map<ParameterPtr, std::set<AnfNodePtr>> phi_to_args;
4702 std::map<AnfNodePtr, std::set<ParameterPtr>> arg_to_phis;
4703 CreatePhiArgMaps(&phi_to_args, &arg_to_phis);
4704 // Update phi arg maps by phi arg map relations, some phi can be replaced as arg.
4705 UpdatePhiArgMapsRepeatedly(&phi_to_args, &arg_to_phis);
4706 // Collect all one arg phis.
4707 return CollectRemovablePhiArgs(phi_to_args);
4708 }
4709
ReplacePhiAsArg(const std::map<ParameterPtr,AnfNodePtr> & removable_phis,const FuncGraphManagerPtr & manager)4710 void ReplacePhiAsArg(const std::map<ParameterPtr, AnfNodePtr> &removable_phis, const FuncGraphManagerPtr &manager) {
4711 MS_LOG(DEBUG) << "Removable phi size: " << removable_phis.size();
4712 for (const auto &[phi, arg] : removable_phis) {
4713 MS_LOG(DEBUG) << "Removable phi: " << phi->DebugString()
4714 << ", arg: " << (arg == nullptr ? "null" : arg->DebugString());
4715 if (arg != nullptr) {
4716 (void)manager->Replace(phi, arg);
4717 }
4718 }
4719 }
4720
4721 // Remove the removable phi parameter and get the corresponding index.
RemovePhiParametersAndGetRemoveIndex(const FunctionBlockPtr & block,const std::map<ParameterPtr,AnfNodePtr> & removable_phis)4722 HashSet<size_t> RemovePhiParametersAndGetRemoveIndex(const FunctionBlockPtr &block,
4723 const std::map<ParameterPtr, AnfNodePtr> &removable_phis) {
4724 MS_EXCEPTION_IF_NULL(block);
4725 auto func_graph = block->func_graph();
4726 MS_EXCEPTION_IF_NULL(func_graph);
4727 MS_LOG(DEBUG) << "Check removable parameters of block: " << block->ToString();
4728 const auto ¶meters = func_graph->parameters();
4729 std::vector<AnfNodePtr> new_parameters;
4730 // Remove the unnecessary phi parameters.
4731 HashSet<size_t> need_removed_indexes;
4732 for (size_t i = 0; i < parameters.size(); ++i) {
4733 auto parameter_i = parameters[i];
4734 MS_EXCEPTION_IF_NULL(parameter_i);
4735 if (removable_phis.find(parameter_i->cast<ParameterPtr>()) == removable_phis.end()) {
4736 new_parameters.push_back(parameter_i);
4737 continue;
4738 }
4739 // Record all removed indexes.
4740 (void)need_removed_indexes.insert(i);
4741 }
4742 MS_LOG(DEBUG) << "parameters.size(): " << parameters.size()
4743 << ", need_removed_indexes.size(): " << need_removed_indexes.size();
4744 // Only if need_removed_indexes not empty, parameters need be updated.
4745 if (!need_removed_indexes.empty()) {
4746 func_graph->set_parameters(new_parameters);
4747 }
4748 return need_removed_indexes;
4749 }
4750
4751 // If phi parameter is removable, then the corresponding arg should be removed.
RemoveJumpNodeArgs(const FunctionBlockPtr & block,const HashSet<size_t> & need_removed_indexes,const FuncGraphManagerPtr & manager)4752 void RemoveJumpNodeArgs(const FunctionBlockPtr &block, const HashSet<size_t> &need_removed_indexes,
4753 const FuncGraphManagerPtr &manager) {
4754 MS_EXCEPTION_IF_NULL(block);
4755 if (need_removed_indexes.empty()) {
4756 return;
4757 }
4758 for (const auto &prev_block : block->prev_blocks()) {
4759 MS_EXCEPTION_IF_NULL(prev_block);
4760 const auto &jump_node = prev_block->GetJumpNode(block.get());
4761 // Switch call has no jump node.
4762 if (jump_node == nullptr) {
4763 continue;
4764 }
4765 std::vector<AnfNodePtr> new_inputs = {jump_node->input(0)};
4766 for (size_t arg_index = 0; arg_index < jump_node->size() - 1; ++arg_index) {
4767 if (need_removed_indexes.find(arg_index) == need_removed_indexes.end()) {
4768 new_inputs.push_back(jump_node->input(arg_index + 1));
4769 }
4770 }
4771 MS_EXCEPTION_IF_NULL(prev_block->func_graph());
4772 const auto &new_jump_node = prev_block->func_graph()->NewCNodeInOrder(new_inputs);
4773 MS_LOG(DEBUG) << "Replace old jump node: " << jump_node->DebugString()
4774 << " as new jump node: " << new_jump_node->DebugString()
4775 << ", jump node block: " << prev_block->ToString();
4776 (void)manager->Replace(jump_node, new_jump_node);
4777 }
4778 }
4779
RemoveUnnecessaryPhis(const FuncGraphManagerPtr & manager)4780 void Parser::RemoveUnnecessaryPhis(const FuncGraphManagerPtr &manager) {
4781 // Merge all removable phis to one map;
4782 const auto &removable_phis = CalRemovablePhis();
4783 if (removable_phis->empty()) {
4784 return;
4785 }
4786 MS_EXCEPTION_IF_NULL(manager);
4787 // Replace all phi node as arg.
4788 ReplacePhiAsArg(*removable_phis, manager);
4789 // Remove the unnecessary phi parameters.
4790 for (const auto &block : func_block_list_) {
4791 MS_EXCEPTION_IF_NULL(block);
4792 MS_LOG(DEBUG) << "Start remove phi of block: " << block->ToString();
4793 // Remove the unnecessary phi parameters.
4794 const auto &need_removed_indexes = RemovePhiParametersAndGetRemoveIndex(block, *removable_phis);
4795 // Remove all block->prev_blocks()'s jump node corresponding args.
4796 RemoveJumpNodeArgs(block, need_removed_indexes, manager);
4797 }
4798 }
4799
4800 // ParseFunctionAst class code
InitParseAstInfo(const std::string & python_mod_get_parse_method)4801 bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
4802 // Init the type
4803 target_type_ = PARSE_TARGET_UNKNOW;
4804
4805 // Call python parse, get the parser fn
4806 module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
4807 py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_PARSE_METHOD);
4808
4809 // Get the obj type
4810 auto type = data_converter::GetObjType(obj_);
4811 if (type == RESOLVE_TYPE_FUNCTION) {
4812 target_type_ = PARSE_TARGET_FUNCTION;
4813 function_ = obj_;
4814 } else if (type == RESOLVE_TYPE_METHOD) {
4815 // Process the method ,need get the method's self obj
4816 target_type_ = PARSE_TARGET_METHOD;
4817 py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
4818 if (py::isinstance<py::none>(method_object)) {
4819 MS_LOG(ERROR) << "Get method's self object instance failed.";
4820 return false;
4821 }
4822 target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
4823 function_ = obj_;
4824 obj_ = method_object;
4825 } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
4826 // 'obj' is class instance, get the method to parse.
4827 function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
4828 if (py::isinstance<py::none>(function_)) {
4829 MS_LOG(ERROR) << "Get obj method function failed.";
4830 return false;
4831 }
4832 target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
4833 // Check the fn is method
4834 auto obj_type = data_converter::GetObjType(function_);
4835 if (obj_type != RESOLVE_TYPE_METHOD) {
4836 MS_LOG(WARNING) << "Parse method function is invalid.";
4837 return false;
4838 }
4839 } else {
4840 MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type: " << type;
4841 return false;
4842 }
4843
4844 // Call python parse get ast tree
4845 parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
4846 py::tuple ast_info = python_adapter::CallPyObjMethod(parser_, "parse");
4847 const size_t ast_info_size = 2;
4848 if (ast_info.size() != ast_info_size) {
4849 MS_INTERNAL_EXCEPTION(NameError) << "ast info size is not equal to 2.";
4850 }
4851 ast_tokens_ = ast_info[0];
4852 ast_tree_ = ast_info[1];
4853
4854 // Get fn name and module
4855 function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
4856 function_name_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_name"));
4857 function_filename_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "filename"));
4858 function_line_offset_ = py::cast<int64_t>(python_adapter::GetPyObjAttr(parser_, "line_offset"));
4859
4860 return true;
4861 }
4862
4863 // Get ast tree node : is the tree bode list[0]
GetAstNode()4864 py::object ParseFunctionAst::GetAstNode() {
4865 py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
4866 py::object ast_node = tree_body[0];
4867 return ast_node;
4868 }
4869
4870 // Get ast tokens node text.
GetAstNodeText(const py::object & node_obj)4871 py::str ParseFunctionAst::GetAstNodeText(const py::object &node_obj) {
4872 return python_adapter::CallPyObjMethod(ast_tokens_, "get_text", node_obj);
4873 }
4874
// Call the PYTHON_PARSE_GET_ARGS function of the parse module on the function node and return its
// result as a list.
py::list ParseFunctionAst::GetArgs(const py::object &func_node) {
  return python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS, func_node);
}
4879
// Call the PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES function of the parse module on the function node and
// return its result as a list.
py::list ParseFunctionAst::GetArgsDefaultValues(const py::object &func_node) {
  return python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
}
4884
GetNodeType(const py::object & node)4885 AstNodeTypePtr ParseFunctionAst::GetNodeType(const py::object &node) {
4886 py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node);
4887 const size_t list_value_size = 2;
4888 if (list_value.size() < list_value_size) {
4889 MS_LOG(INTERNAL_EXCEPTION) << "The node of python method must has 2 values.";
4890 }
4891 auto node_name = py::cast<std::string>(list_value[0]);
4892 auto type = AstMainType(py::cast<int32_t>(list_value[1]));
4893 return std::make_shared<AstNodeType>(node, node_name, type);
4894 }
4895
GetOpType(const py::object & node)4896 AstSubType ParseFunctionAst::GetOpType(const py::object &node) {
4897 auto op_type = AstSubType(python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
4898 return op_type;
4899 }
4900
IsClassMemberOfSelf(const py::object & node)4901 bool ParseFunctionAst::IsClassMemberOfSelf(const py::object &node) {
4902 py::object res = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER_OF_SELF, node);
4903 if (!py::isinstance<py::bool_>(res)) {
4904 MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
4905 return false;
4906 }
4907 return res.cast<bool>();
4908 }
4909
IsClassMemberRecursive(const py::object & node)4910 bool ParseFunctionAst::IsClassMemberRecursive(const py::object &node) {
4911 py::object res = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER_RECURSIVE, node);
4912 if (!py::isinstance<py::bool_>(res)) {
4913 MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
4914 return false;
4915 }
4916 return res.cast<bool>();
4917 }
4918
SetMixedPrecisionFlag(const py::object & obj,const FuncGraphPtr & func_graph)4919 void SetMixedPrecisionFlag(const py::object &obj, const FuncGraphPtr &func_graph) {
4920 MS_EXCEPTION_IF_NULL(func_graph);
4921 if (!py::isinstance<Cell>(obj)) {
4922 return;
4923 }
4924 auto cell = py::cast<CellPtr>(obj);
4925 MS_EXCEPTION_IF_NULL(cell);
4926 auto mixed_type = cell->GetMixedPrecisionType();
4927 if (mixed_type != MixedPrecisionType::kNotSet) {
4928 func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_FP16, mixed_type == MixedPrecisionType::kFP16);
4929 func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_FP32, mixed_type == MixedPrecisionType::kFP32);
4930 func_graph->set_flag(GRAPH_FLAG_MIX_PRECISION_BF16, mixed_type == MixedPrecisionType::kBF16);
4931 }
4932 }
4933
UpdateFuncGraphFlags(const py::object & obj,const FuncGraphPtr & func_graph,bool is_construct_function)4934 bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph, bool is_construct_function) {
4935 if (func_graph == nullptr) {
4936 MS_LOG(ERROR) << "FuncGraph is null";
4937 return false;
4938 }
4939
4940 SetMixedPrecisionFlag(obj, func_graph);
4941
4942 if (!py::hasattr(obj, PYTHON_FUNC_GRAPH_FLAGS)) {
4943 MS_LOG(DEBUG) << "No flags";
4944 return true;
4945 }
4946 py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_FUNC_GRAPH_FLAGS);
4947 for (auto &item : flags) {
4948 if (!py::isinstance<py::str>(item.first)) {
4949 MS_LOG(ERROR) << "Type error in flags dict convert";
4950 return false;
4951 }
4952 auto name = py::cast<std::string>(item.first);
4953 if (py::isinstance<py::bool_>(item.second)) {
4954 auto value = py::cast<bool>(item.second);
4955 MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
4956 if (!is_construct_function && name == FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) {
4957 continue;
4958 }
4959 func_graph->set_flag(name, value);
4960 } else if (py::isinstance<py::str>(item.second)) {
4961 auto value = py::cast<std::string>(item.second);
4962 MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
4963 func_graph->set_attr(name, MakeValue(value));
4964 } else {
4965 MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
4966 return false;
4967 }
4968 }
4969 return true;
4970 }
4971
UpdateRecomputeScope(const FuncGraphPtr & func_graph)4972 void UpdateRecomputeScope(const FuncGraphPtr &func_graph) {
4973 MS_EXCEPTION_IF_NULL(func_graph);
4974 auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple);
4975
4976 for (const auto &node : nodes) {
4977 MS_EXCEPTION_IF_NULL(node);
4978 const auto &origin_scope_name = node->scope()->name();
4979 if (node->isa<CNode>() && origin_scope_name.compare(0, strlen(kAttrRecompute), kAttrRecompute) != 0) {
4980 std::stringstream scope_name_buffer;
4981 scope_name_buffer << kAttrRecompute << "_" << origin_scope_name;
4982 node->set_scope(std::make_shared<Scope>(scope_name_buffer.str()));
4983 }
4984 }
4985 }
4986
IsSubscriptReferenceType(const py::object & obj)4987 bool Parser::IsSubscriptReferenceType(const py::object &obj) {
4988 py::object slice_node = python_adapter::GetPyObjAttr(obj, "slice");
4989 auto node_type = ast_->GetNodeType(slice_node);
4990 auto node_name = node_type->node_name();
4991 return node_name != "Slice";
4992 }
4993
4994 struct CompileConfigCollectRegister {
CompileConfigCollectRegistermindspore::parse::CompileConfigCollectRegister4995 CompileConfigCollectRegister() noexcept {
4996 CompileConfigManager::set_collect_func([]() {
4997 std::map<std::string, std::string> compile_config;
4998 const auto module_name = "mindspore._extends.parse.compile_config";
4999 py::list config_list = py::cast<py::list>(python_adapter::GetPyFn(module_name, "__all__"));
5000 for (size_t i = 0; i < config_list.size(); ++i) {
5001 auto config_name = config_list[i].cast<std::string>();
5002 auto config = python_adapter::GetPyFn(module_name, config_name);
5003 if (py::isinstance<py::none>(config)) {
5004 MS_LOG(INTERNAL_EXCEPTION) << config_name << " not found in " << module_name << ".";
5005 }
5006 if (py::isinstance<py::int_>(config)) {
5007 compile_config[config_name] = std::to_string(py::cast<int64_t>(config));
5008 } else {
5009 compile_config[config_name] = config.cast<std::string>();
5010 }
5011 }
5012 return compile_config;
5013 });
5014 }
5015 ~CompileConfigCollectRegister() = default;
5016 } compile_config_collect_register;
5017 } // namespace parse
5018 } // namespace mindspore
5019