1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2021 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "pipeline/jit/parse/parse.h"
20
21 #include <utility>
22 #include <string>
23 #include <memory>
24 #include <unordered_map>
25 #include <sstream>
26 #include <algorithm>
27 #include "pybind_api/pybind_patch.h"
28 #include "pipeline/jit/parse/resolve.h"
29 #include "pipeline/jit/parse/data_converter.h"
30 #include "frontend/operator/ops.h"
31 #include "frontend/operator/composite/composite.h"
32 #include "utils/ms_context.h"
33 #include "debug/trace.h"
34
35 namespace mindspore {
36 namespace parse {
37
ParsePythonCode(const py::object & obj,const std::string & python_mod_get_parse_method)38 FuncGraphPtr ParsePythonCode(const py::object &obj, const std::string &python_mod_get_parse_method) {
39 (void)python_adapter::set_python_scoped();
40
41 if (!obj || py::isinstance<py::none>(obj)) {
42 MS_LOG(ERROR) << "Parse the python code failed, obj is nullptr or none";
43 return nullptr;
44 }
45
46 auto ast = std::make_shared<ParseFunctionAst>(obj);
47 bool success = ast->InitParseAstInfo(python_mod_get_parse_method);
48 if (!success) {
49 MS_LOG(ERROR) << "Parse code to ast tree failed.";
50 return nullptr;
51 }
52
53 auto parser = std::make_shared<Parser>(ast);
54
55 FuncGraphPtr func_graph = parser->ParseFuncGraph();
56 if (func_graph == nullptr) {
57 MS_LOG(ERROR) << "Parse python code failed, errcode = " << parser->errcode();
58 return nullptr;
59 }
60
61 return func_graph;
62 }
63
GetMixedPrecisionTargetType(const FuncGraphPtr & func_graph)64 TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph) {
65 MS_EXCEPTION_IF_NULL(func_graph);
66 if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
67 return kFloat32;
68 } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
69 return kFloat16;
70 } else {
71 return nullptr;
72 }
73 }
74
75 // If any mixed precision flag add a cast node after the parameter node.
GetMixedPrecisionCastHelp(const FuncGraphPtr & func_graph,const AnfNodePtr & param)76 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m) {
77 MS_EXCEPTION_IF_NULL(func_graph);
78 TypePtr dst_type;
79 if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
80 dst_type = kFloat32;
81 } else if (func_graph->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
82 dst_type = kFloat16;
83 } else {
84 return param;
85 }
86 auto cast_helper = prim::kPrimMixedPrecisionCast;
87 auto cast = func_graph->NewCNodeAfter(param, {NewValueNode(cast_helper), NewValueNode(dst_type), param});
88 return cast;
89 }
90
91 FuncGraphWeakPtr Parser::top_func_graph_ = FuncGraphWeakPtr();
92
Parser(const std::shared_ptr<ParseFunctionAst> & ast)93 Parser::Parser(const std::shared_ptr<ParseFunctionAst> &ast) : ast_(ast) {
94 max_for_loop_count_str_ = common::GetEnv("ENV_FOR_TO_WHILE_LOOP");
95 support_fallback_ = common::GetEnv("ENV_SUPPORT_FALLBACK");
96 errcode_ = PARSE_SUCCESS;
97 BuildMethodMap();
98 }
99
BuildMethodMap()100 void Parser::BuildMethodMap() {
101 stmt_method_map_["Return"] = &Parser::ParseReturn;
102 stmt_method_map_["Expr"] = &Parser::ParseExpr;
103 stmt_method_map_["If"] = &Parser::ParseIf;
104 stmt_method_map_["Assign"] = &Parser::ParseAssign;
105 stmt_method_map_["While"] = &Parser::ParseWhile;
106 stmt_method_map_["For"] = &Parser::ParseFor;
107 stmt_method_map_["FunctionDef"] = &Parser::ParseFunctionDef;
108 stmt_method_map_["AugAssign"] = &Parser::ParseAugAssign;
109 stmt_method_map_["Global"] = &Parser::ParseGlobal;
110 stmt_method_map_["Break"] = &Parser::ParseBreak;
111 stmt_method_map_["Continue"] = &Parser::ParseContinue;
112 stmt_method_map_["Pass"] = &Parser::ParsePass;
113 expr_method_map_["NoneType"] = &Parser::ParseNone;
114 expr_method_map_["BinOp"] = &Parser::ParseBinOp;
115 expr_method_map_["Name"] = &Parser::ParseName;
116 expr_method_map_["Num"] = &Parser::ParseNum;
117 expr_method_map_["Str"] = &Parser::ParseStr;
118 expr_method_map_["Constant"] = &Parser::ParseConstant;
119 expr_method_map_["NameConstant"] = &Parser::ParseNameConstant;
120 expr_method_map_["Call"] = &Parser::ParseCall;
121 expr_method_map_["IfExp"] = &Parser::ParseIfExp;
122 expr_method_map_["Attribute"] = &Parser::ParseAttribute;
123 expr_method_map_["Compare"] = &Parser::ParseCompare;
124 expr_method_map_["BoolOp"] = &Parser::ParseBoolOp;
125 expr_method_map_["Lambda"] = &Parser::ParseLambda;
126 expr_method_map_["Tuple"] = &Parser::ParseTuple;
127 expr_method_map_["List"] = &Parser::ParseList;
128 expr_method_map_["Subscript"] = &Parser::ParseSubscript;
129 expr_method_map_["Slice"] = &Parser::ParseSlice;
130 expr_method_map_["ExtSlice"] = &Parser::ParseExtSlice;
131 expr_method_map_["Index"] = &Parser::ParseIndex;
132 expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp;
133 expr_method_map_["Dict"] = &Parser::ParseDict;
134 expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis;
135 expr_method_map_["ListComp"] = &Parser::ParseListComp;
136 expr_method_map_["GeneratorExp"] = &Parser::ParseListComp; // We treat 'GeneratorExp' the same as 'ListComp'.
137 }
138
UpdateTopFuncGraph(const FuncGraphPtr & func_graph)139 void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); }
140
InitParserEnvironment(const py::object & obj)141 void Parser::InitParserEnvironment(const py::object &obj) {
142 Parser::top_func_graph_ = FuncGraphWeakPtr();
143 ScopeManager::GetInstance().ClearScope();
144 (void)python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GENERATE_SCOPE, obj);
145 }
146
CleanParserResource()147 void Parser::CleanParserResource() {
148 Parser::top_func_graph_ = FuncGraphWeakPtr();
149 ScopeManager::GetInstance().ClearScope();
150 }
151
CheckFuncReturn(const FuncGraphPtr & fn,const std::shared_ptr<ParseFunctionAst> & ast)152 void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseFunctionAst> &ast) {
153 // Check whether the functions referred by this function and itself are missing 'return' statement
154 auto mng = Manage(fn, false);
155 MS_EXCEPTION_IF_NULL(ast);
156 for (const auto &func_graph : mng->func_graphs()) {
157 MS_EXCEPTION_IF_NULL(func_graph);
158 if (func_graph->get_return() != nullptr) {
159 continue;
160 }
161 py::object node = ast->GetAstNode();
162 py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
163 constexpr auto min_list_size = 2;
164 if (ret.size() < min_list_size) {
165 MS_LOG(EXCEPTION) << "list size:" << ret.size() << " is less than 2.";
166 }
167 py::str desc =
168 python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
169 MS_EXCEPTION(TypeError) << "Function must has 'return' statement, but missing in " << desc.cast<std::string>()
170 << ". FuncGraph: " << func_graph->ToString();
171 }
172 }
173
ParseFuncGraph()174 FuncGraphPtr Parser::ParseFuncGraph() {
175 // Get ast FunctionDef node
176 py::object node = ast_->GetAstNode();
177 FunctionBlockPtr fn_block = ParseFunction(node);
178 if (errcode() != PARSE_SUCCESS) {
179 MS_LOG(ERROR) << "Parse function error, code is " << errcode();
180 return nullptr;
181 }
182 RemoveUnnecessaryPhis();
183 MS_EXCEPTION_IF_NULL(fn_block);
184 CheckFuncReturn(fn_block->func_graph(), ast_);
185 return fn_block->func_graph();
186 }
187
GenerateArgsNodeForFunction(const FunctionBlockPtr & block,const py::object & fn_node)188 void Parser::GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
189 py::object func_args = python_adapter::GetPyObjAttr(fn_node, "args");
190 py::object var_arg_node = python_adapter::GetPyObjAttr(func_args, "vararg");
191 MS_EXCEPTION_IF_NULL(block);
192 auto block_fg = block->func_graph();
193 block_fg->set_has_vararg(!py::isinstance<py::none>(var_arg_node));
194
195 py::object kw_arg_node = python_adapter::GetPyObjAttr(func_args, "kwarg");
196 block_fg->set_has_kwarg(!py::isinstance<py::none>(kw_arg_node));
197
198 py::list kwonly_args = python_adapter::GetPyObjAttr(func_args, "kwonlyargs");
199 block_fg->set_kwonlyargs_count(SizeToInt(kwonly_args.size()));
200
201 MS_EXCEPTION_IF_NULL(ast_);
202 py::list args = ast_->GetArgs(fn_node);
203 for (std::size_t i = 0; i < args.size(); i++) {
204 std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
205 if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
206 if (arg_name == "self") {
207 continue;
208 }
209 }
210 TraceGuard guard(GetLocation(args[i]));
211 auto para_node = std::make_shared<Parameter>(block_fg);
212 MS_EXCEPTION_IF_NULL(para_node);
213 para_node->set_name(arg_name);
214 para_node->debug_info()->set_name(arg_name);
215 block_fg->add_parameter(para_node);
216 AnfNodePtr para_after_cast = GetMixedPrecisionCastHelp(block_fg, para_node);
217 block->WriteVariable(arg_name, para_after_cast);
218 MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg_name;
219 }
220 }
221
GenerateArgsDefaultValueForFunction(const FunctionBlockPtr & block,const py::object & fn_node)222 void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node) {
223 MS_EXCEPTION_IF_NULL(block);
224 py::list defaults = ast_->GetArgsDefaultValues(fn_node);
225 py::list args = ast_->GetArgs(fn_node);
226 std::vector<std::string> namelist_for_default_value;
227 std::vector<AnfNodePtr> default_values;
228 for (std::size_t i = 0; i < args.size(); i++) {
229 std::string arg_name = py::cast<std::string>(args[i].attr("arg"));
230 if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
231 if (arg_name == "self") {
232 continue;
233 }
234 }
235
236 namelist_for_default_value.push_back(arg_name);
237 if (i >= defaults.size()) {
238 MS_LOG(EXCEPTION) << "Index:" << i << " out of range:" << defaults.size();
239 }
240 if (py::isinstance<py::none>(defaults[i])) {
241 default_values.push_back(NewValueNode(kNull));
242 } else {
243 default_values.push_back(ParseExprNode(block, defaults[i]));
244 }
245 }
246 block->func_graph()->SetDefaultValues(namelist_for_default_value, default_values);
247 }
248
GetScopeForParseFunction()249 ScopePtr Parser::GetScopeForParseFunction() {
250 ScopePtr scope = ScopeManager::GetInstance().GetCurrentScope();
251 if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
252 py::object scope_str = python_adapter::CallPyFn(PYTHON_MOD_PARSE_MODULE, PYTHON_PARSE_GET_SCOPE_NAME, ast_->obj());
253 if (!py::isinstance<py::none>(scope_str)) {
254 auto scope_name = py::cast<std::string>(scope_str);
255 scope = std::make_shared<Scope>(scope_name);
256 }
257 }
258 return scope;
259 }
260
ParseFunction(const py::object & node,const FunctionBlockPtr & block)261 FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlockPtr &block) {
262 ScopePtr scope = GetScopeForParseFunction();
263 // The node created in the parsefunction context, will inherit the scope created using scope_guard
264 ScopeGuard scope_guard(scope);
265 TraceGuard trace_guard(data_converter::GetObjKey(ast_->obj())[0], GetLocation(node));
266 FunctionBlockPtr func_block = MakeFunctionBlock(*this);
267 if (block != nullptr) {
268 func_block->AddPrevBlock(block);
269 } else {
270 func_graph_ = func_block->func_graph();
271 }
272 func_block->Mature();
273 auto current_fg = func_block->func_graph();
274 auto function_name = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "name"));
275 MS_LOG(DEBUG) << "The function name is " << function_name;
276 current_fg->debug_info()->set_name(function_name);
277 MS_EXCEPTION_IF_NULL(ast_);
278 py::list deco_list = node.attr("decorator_list");
279 if (!deco_list.empty()) {
280 current_fg->debug_info()->set_deco_location(GetLocation(deco_list));
281 }
282 bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg);
283 if (!ast_->obj().is(ast_->function())) {
284 set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg);
285 }
286
287 if (!set_flag) {
288 MS_LOG(ERROR) << "Set flags failed";
289 return nullptr;
290 }
291 GenerateArgsNodeForFunction(func_block, node);
292
293 // When parsing the top graph of construct, save the top graph
294 if (GetTopFuncGraph() == nullptr) {
295 UpdateTopFuncGraph(func_block->func_graph());
296 }
297
298 // Save the function node to block
299 func_block->WriteVariable(function_name, NewValueNode(current_fg));
300
301 py::object funcObj = python_adapter::GetPyObjAttr(node, "body");
302 (void)ParseStatements(func_block, funcObj);
303
304 // Add unused variables as isolate nodes.
305 for (auto &func_block_item : func_block_list_) {
306 MS_EXCEPTION_IF_NULL(func_block_item);
307 if (func_block_item->func_graph()->get_return() != nullptr) {
308 // Find unused variables.
309 func_block_item->FindIsolatedNodes();
310 // Attach all isolated nodes.
311 func_block_item->AttachIsolatedNodesBeforeReturn();
312 }
313 }
314
315 if (current_fg->get_return() == nullptr) {
316 py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
317 py::str desc = python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, node, ret[0], ret[1]);
318 MS_EXCEPTION(TypeError) << "Function must has 'return' statement, but missing in " << desc.cast<std::string>()
319 << ".";
320 }
321 GenerateArgsDefaultValueForFunction(func_block, node);
322 return func_block;
323 }
324
ParseStatements(FunctionBlockPtr block,const py::object & nodes)325 FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr block, const py::object &nodes) {
326 auto node_list = py::cast<py::list>(nodes);
327 size_t count = py::len(node_list);
328 MS_LOG(DEBUG) << "The nodes count is " << count;
329 for (size_t i = 0; i < count; ++i) {
330 MS_LOG(DEBUG) << "Start parse statement[" << i << "]: " << py::str(node_list[i]);
331 auto node = node_list[i];
332 block = ParseStatement(block, node);
333 MS_EXCEPTION_IF_NULL(block);
334 // Insert appropriate depended items for the function block if it has a return node
335 if (block->func_graph()->get_return() != nullptr || block->is_dead_block()) {
336 // If break is not the last expr.
337 if (i != count - 1) {
338 TraceGuard trace_guard(GetLocation(node_list[i + 1]));
339 MS_LOG(EXCEPTION) << "Dead code exist, please remove it.";
340 }
341 // Skip statements after 'return' (or 'break', 'continue').
342 break;
343 }
344 }
345 return block;
346 }
347
// Dispatches a single statement AST node to its handler in stmt_method_map_ and returns the resulting
// block. Nodes whose main type is not a statement are logged at INFO and ignored (`block` is returned
// unchanged). Unknown statement kinds set errcode_ to PARSE_NODE_METHOD_UNSUPPORTED and raise.
FunctionBlockPtr Parser::ParseStatement(const FunctionBlockPtr &block, const py::object &node) {
  TraceGuard trace_guard(GetLocation(node));
  auto node_type = ast_->GetNodeType(node);
  MS_EXCEPTION_IF_NULL(node_type);

  // Check the node type
  AstMainType node_main_type = node_type->main_type();
  if (node_main_type != AST_MAIN_TYPE_STMT) {
    MS_LOG(INFO) << "Node type is error : " << node_main_type;
    return block;
  }
  // Call the process function; a single lookup serves both the existence check and the dispatch.
  std::string node_name = node_type->node_name();
  MS_LOG(DEBUG) << "Ast node is " << node_name;
  auto method_iter = stmt_method_map_.find(node_name);
  if (method_iter == stmt_method_map_.end()) {
    errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
    MS_LOG(EXCEPTION) << "Unsupported statement '" << node_name
                      << "'.\nMore details please refer to syntax support at https://www.mindspore.cn";
  }
  auto stmt_block = (this->*(method_iter->second))(block, node);
  TraceManager::ClearParseOrResolveDebugInfo();
  return stmt_block;
}
371
ParseExprNode(const FunctionBlockPtr & block,const py::object & node)372 AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object &node) {
373 MS_LOG(DEBUG) << "Process ast expr.";
374 TraceGuard trace_guard(GetLocation(node));
375 auto node_type = ast_->GetNodeType(node);
376 // Check the node type
377 AstMainType node_main_type = node_type->main_type();
378 if (node_main_type != AST_MAIN_TYPE_EXPR) {
379 errcode_ = PARSE_NODE_TYPE_NO_MATCH;
380 MS_LOG(EXCEPTION) << "Node type is error : " << node_main_type;
381 }
382 // Call the process function
383 std::string node_name = node_type->node_name();
384 MS_LOG(DEBUG) << "Ast node is " << node_name;
385 if (expr_method_map_.count(node_name)) {
386 auto expr_node = (this->*expr_method_map_[node_name])(block, node);
387 TraceManager::ClearParseOrResolveDebugInfo();
388 return expr_node;
389 } else {
390 errcode_ = PARSE_NODE_METHOD_UNSUPPORTED;
391 MS_LOG(EXCEPTION) << "Unsupported expression '" << node_name
392 << "'.\nMore details please refer to syntax support at https://www.mindspore.cn";
393 }
394 }
395
396 // Process the expr statement and expand it
ParseExpr(const FunctionBlockPtr & block,const py::object & node)397 FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) {
398 MS_LOG(DEBUG) << "Process ast Expr";
399 // Expr only have value, no target
400 py::tuple expand_info = ast_->CallParseModFunction(PYTHON_PARSE_EXPAND_EXPR_STATEMENT, node);
401
402 // Refer python function expand_expr_statement, expand_info is one of the following:
403 // True, expr.value, x
404 // True, expr.value
405 // False, None, None
406 //
407 // Check the expand info result
408 if (expand_info.empty()) {
409 MS_LOG(EXCEPTION) << "Empty expand_info.";
410 }
411 auto is_expand = py::cast<bool>(expand_info[0]);
412 if (is_expand) {
413 // Process the expr statement
414 constexpr size_t expect_size = 2;
415 if (expand_info.size() < expect_size) {
416 MS_LOG(EXCEPTION) << "expand_info size:" << expand_info.size() << " less than " << expect_size << ".";
417 }
418 py::object value_object = expand_info[1];
419 // Make a Expr CNode.
420 AnfNodePtr call_node = ParseExprNode(block, value_object);
421 if (py::len(expand_info) == 2) {
422 // Expression that not assigned to any variable.
423 // This is usually a call with side effects.
424 // e.g.: print(x)
425 // We save it as an isolated node.
426 auto &no_return_node = call_node;
427 MS_LOG(INFO) << "Isolated node found(NoReturn), no_return_node: " << no_return_node->DebugString(2)
428 << ", block: " << block << "/"
429 << (block->func_graph() ? block->func_graph()->ToString() : "FG(Null)")
430 << ", Line: " << trace::GetDebugInfo(no_return_node->debug_info(), "", kSourceLineTipDiscard);
431 block->AddIsolatedNode(no_return_node);
432 } else {
433 // Expand the assign statement,
434 // e.g.: x.append(y) -> x = x.append(y)
435 py::object target_node = expand_info[2];
436 WriteAssignVars(block, target_node, call_node);
437 }
438 }
439 return block;
440 }
441
GetLocation(const py::object & node) const442 LocationPtr Parser::GetLocation(const py::object &node) const {
443 MS_EXCEPTION_IF_NULL(ast_);
444 py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
445 constexpr size_t list_size = 5;
446 if (ret.size() < list_size) {
447 MS_LOG(EXCEPTION) << "List size should not be less than 5.";
448 }
449 const size_t file_name_index = 0;
450 const size_t line_index = 1;
451 const size_t column_index = 2;
452 const size_t line_end_index = 3;
453 const size_t column_end_index = 4;
454 // Refer to Location::Location() for each member of ret: line, column, line_end, column_end.
455 auto location = std::make_shared<Location>(ret[file_name_index].cast<std::string>(), ret[line_index].cast<int64_t>(),
456 ret[column_index].cast<int64_t>(), ret[line_end_index].cast<int64_t>(),
457 ret[column_end_index].cast<int64_t>());
458 return location;
459 }
460
MakeConditionBlocks(const FunctionBlockPtr & pre_block,const FunctionBlockPtr & true_block,const FunctionBlockPtr & false_block)461 void Parser::MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
462 const FunctionBlockPtr &false_block) {
463 MS_EXCEPTION_IF_NULL(true_block);
464 MS_EXCEPTION_IF_NULL(false_block);
465 true_block->AddPrevBlock(pre_block);
466 true_block->Mature();
467
468 false_block->AddPrevBlock(pre_block);
469 false_block->Mature();
470 }
471
// Handles a `return <value>` statement. It parses the returned expression and installs
// `Return(value)` as the return node of block's graph. Returns `block`.
FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast return";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr return_prim = NewValueNode(prim::kPrimReturn);
  py::object value_ast = python_adapter::GetPyObjAttr(node, "value");
  AnfNodePtr returned_value = ParseExprNode(block, value_ast);
  auto block_fg = block->func_graph();
  block_fg->set_return(block_fg->NewCNodeInOrder({return_prim, returned_value}));
  return block;
}
486
487 // Process binary operators,eg: `a + b`, `a | b`, etc.
ParseBinOp(const FunctionBlockPtr & block,const py::object & node)488 AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &node) {
489 MS_LOG(DEBUG) << "Process ast BinOP";
490
491 py::object left = python_adapter::GetPyObjAttr(node, "left");
492 py::object right = python_adapter::GetPyObjAttr(node, "right");
493 py::object op = python_adapter::GetPyObjAttr(node, "op");
494 // Create left and right ANF node
495 AnfNodePtr left_node = ParseExprNode(block, left);
496 if (left_node == nullptr) {
497 MS_LOG(EXCEPTION) << "DoBinOp process left node failed: " << errcode();
498 }
499 AnfNodePtr right_node = ParseExprNode(block, right);
500 if (right_node == nullptr) {
501 MS_LOG(EXCEPTION) << "DoBinOp process right node failed:" << errcode();
502 }
503 // Resolve the op
504 MS_EXCEPTION_IF_NULL(block);
505 AnfNodePtr op_node = block->MakeResolveAstOp(op);
506 // Create apply node
507 MS_EXCEPTION_IF_NULL(block->func_graph());
508 return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
509 }
510
ParseName(const FunctionBlockPtr & block,const py::object & node)511 AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) {
512 MS_LOG(DEBUG) << "Process ast Name";
513 auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "id"));
514 MS_LOG(DEBUG) << "The Name id is " << name_id;
515 MS_EXCEPTION_IF_NULL(block);
516 if (block->IsGlobalVar(name_id)) {
517 MS_LOG(DEBUG) << "name_id: " << name_id;
518 return block->MakeResolveSymbol(name_id);
519 }
520 return block->ReadVariable(name_id);
521 }
522
ParseNone(const FunctionBlockPtr &,const py::object &)523 AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) {
524 MS_LOG(DEBUG) << "Process ast NoneType";
525 return NewValueNode(kNone);
526 }
527
ParseEllipsis(const FunctionBlockPtr &,const py::object &)528 AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) {
529 MS_LOG(DEBUG) << "Process ast Ellipsis";
530 return NewValueNode(kEllipsis);
531 }
532
ParseNum(const FunctionBlockPtr &,const py::object & node)533 AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) {
534 MS_LOG(DEBUG) << "Process ast Num";
535 py::object obj = python_adapter::GetPyObjAttr(node, "n");
536 if (py::isinstance<py::int_>(obj)) {
537 MS_LOG(INFO) << "The Num is int64_t:" << (std::string)py::str(obj);
538 auto data = py::cast<int64_t>(obj);
539 return NewValueNode(data);
540 } else if (py::isinstance<py::float_>(obj)) {
541 MS_LOG(INFO) << "The Num is float:" << (std::string)py::str(obj);
542 auto data = py::cast<float>(obj);
543 return NewValueNode(data);
544 } else {
545 // no else actually
546 errcode_ = PARSE_NODE_TYPE_UNKNOWN;
547 MS_EXCEPTION(TypeError) << "Only support 'Number' type of 'int` and 'float', but got type: " << obj.get_type()
548 << " Value:" << py::str(obj);
549 }
550 }
551
ParseStr(const FunctionBlockPtr &,const py::object & node)552 AnfNodePtr Parser::ParseStr(const FunctionBlockPtr &, const py::object &node) {
553 MS_LOG(DEBUG) << "Process ast Str";
554 auto str_s = py::cast<std::string>(python_adapter::GetPyObjAttr(node, "s"));
555 return NewValueNode(str_s);
556 }
557
ParseConstant(const FunctionBlockPtr &,const py::object & node)558 AnfNodePtr Parser::ParseConstant(const FunctionBlockPtr &, const py::object &node) {
559 MS_LOG(DEBUG) << "Process ast Constant";
560 py::object obj = python_adapter::GetPyObjAttr(node, "value");
561 if (py::isinstance<py::bool_>(obj)) {
562 MS_LOG(INFO) << "The Constant is bool:" << (std::string)py::str(obj);
563 return NewValueNode(py::cast<bool>(obj));
564 } else if (py::isinstance<py::int_>(obj)) {
565 MS_LOG(INFO) << "The Constant is int64_t:" << (std::string)py::str(obj);
566 return NewValueNode(py::cast<int64_t>(obj));
567 } else if (py::isinstance<py::float_>(obj)) {
568 MS_LOG(INFO) << "The Constant is float:" << (std::string)py::str(obj);
569 return NewValueNode(py::cast<float>(obj));
570 } else if (py::isinstance<py::str>(obj)) {
571 MS_LOG(INFO) << "The Constant is string:" << (std::string)py::str(obj);
572 return NewValueNode(py::cast<std::string>(obj));
573 } else if (py::isinstance<py::none>(obj)) {
574 MS_LOG(INFO) << "The Constant is none:" << (std::string)py::str(obj);
575 return NewValueNode(kNone);
576 } else if (py::isinstance<py::ellipsis>(obj)) {
577 MS_LOG(INFO) << "The Constance is ellipsis:" << (std::string)py::str(obj);
578 return NewValueNode(kEllipsis);
579 } else {
580 // no else actually
581 MS_EXCEPTION(TypeError) << "Unsupported Constant type : " << (std::string)py::str(obj);
582 }
583 }
584
ParseNameConstant(const FunctionBlockPtr &,const py::object & node)585 AnfNodePtr Parser::ParseNameConstant(const FunctionBlockPtr &, const py::object &node) {
586 MS_LOG(DEBUG) << "Process ast NameConstant";
587 py::object obj = python_adapter::GetPyObjAttr(node, "value");
588 if (py::isinstance<py::bool_>(obj)) {
589 MS_LOG(INFO) << "The NameConstant is bool:" << (std::string)py::str(obj);
590 auto data = py::cast<bool>(obj);
591 return NewValueNode(data);
592 } else if (py::isinstance<py::none>(obj)) {
593 MS_LOG(INFO) << "The NameConstant is none:" << (std::string)py::str(obj);
594 return NewValueNode(kNone);
595 }
596 // no else actually
597 errcode_ = PARSE_NODE_TYPE_UNKNOWN;
598 MS_LOG(EXCEPTION) << "Unsupported NameConstant type: " << (std::string)py::str(obj);
599 }
600
GenerateMakeTuple(const FunctionBlockPtr & block,const std::vector<AnfNodePtr> & element_nodes)601 AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes) {
602 MS_EXCEPTION_IF_NULL(block);
603 AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
604 std::vector<AnfNodePtr> make_tuple_nodes;
605 make_tuple_nodes.push_back(make_tuple_op);
606 (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes),
607 [](AnfNodePtr arg) -> AnfNodePtr { return arg; });
608 return block->func_graph()->NewCNodeInOrder(make_tuple_nodes);
609 }
610
ParseSuper(const FunctionBlockPtr & block,const py::list & args)611 AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) {
612 MS_EXCEPTION_IF_NULL(block);
613 py::object father_class;
614 const size_t expect_args_size = 2;
615 if (args.empty()) {
616 father_class = py::none();
617 } else if (args.size() == expect_args_size) {
618 father_class = args[0];
619 auto arg_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[1])));
620 if (arg_type != AST_SUB_TYPE_NAME || py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) != "self") {
621 MS_EXCEPTION(ArgumentError) << "Argument 2 of 'super()' must be 'self', but got '"
622 << py::cast<std::string>(python_adapter::GetPyObjAttr(args[1], "id")) << "'.";
623 }
624 } else {
625 MS_EXCEPTION(ArgumentError) << "Arguments number of 'super()' should be 0 or 2, but got " << args.size() << ".";
626 }
627 py::object target_class_instance = ast_->CallParserObjMethod(PYTHON_PARSE_ANALYZE_SUPER, father_class, ast_->obj());
628 py::object namespace_var = ast_->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, target_class_instance);
629 NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
630 SymbolPtr symbol = std::make_shared<Symbol>("namespace");
631 MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
632 return block->MakeResolve(name_space, symbol);
633 }
634
635 // Process function call, eg : f1(x, y) ...
ParseCall(const FunctionBlockPtr & block,const py::object & node)636 AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &node) {
637 MS_LOG(DEBUG) << "Process ast Call";
638 // Process function call
639 py::object function_ast_node = python_adapter::GetPyObjAttr(node, "func");
640 py::list args = python_adapter::GetPyObjAttr(node, "args");
641
642 auto arg_type =
643 AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, function_ast_node)));
644 if (arg_type == AST_SUB_TYPE_NAME) {
645 auto name_id = py::cast<std::string>(python_adapter::GetPyObjAttr(function_ast_node, "id"));
646 if (name_id == "super") {
647 return ParseSuper(block, args);
648 }
649 }
650
651 AnfNodePtr call_function_node = ParseExprNode(block, function_ast_node);
652 // Function call arguments should be passed in as groups and unpacked later using unpack call
653 std::vector<AnfNodePtr> packed_arguments;
654 std::vector<AnfNodePtr> group_arguments;
655
656 bool need_unpack_args = ParseArgsInCall(block, args, &packed_arguments, &group_arguments);
657 bool need_unpack_keywords = ParseKeywordsInCall(block, node, &packed_arguments);
658 // If there is stared or keyword argument, unpack may be needed
659 bool need_unpack = need_unpack_args || need_unpack_keywords;
660
661 auto call_cnode = GenerateAnfNodeForCall(block, call_function_node, packed_arguments, group_arguments, need_unpack);
662 if (call_function_node->interpret()) {
663 call_cnode->set_interpret(true);
664 }
665 return call_cnode;
666 }
667
MakeUnpackCall(const FuncGraphPtr & func_graph,const AnfNodePtr & call_function_node,const std::vector<AnfNodePtr> & packed_arguments)668 CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_node,
669 const std::vector<AnfNodePtr> &packed_arguments) {
670 MS_EXCEPTION_IF_NULL(func_graph);
671 std::vector<AnfNodePtr> unpack_call_nodes;
672 auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
673 unpack_call_nodes.push_back(unpack_call_op);
674 unpack_call_nodes.push_back(call_function_node);
675 (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
676 [](AnfNodePtr node) -> AnfNodePtr { return node; });
677 CNodePtr unpack_call = func_graph->NewCNodeInOrder(unpack_call_nodes);
678 return unpack_call;
679 }
680
// Builds the call node for a parsed call expression in block's graph. When `need_unpack` is true (there
// are starred or keyword arguments), an UnpackCall over `packed_arguments` is emitted. Otherwise a plain
// call `call_function_node(group_arguments...)` is emitted.
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node,
                                          const std::vector<AnfNodePtr> &packed_arguments,
                                          const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
  MS_EXCEPTION_IF_NULL(block);
  if (need_unpack) {
    return MakeUnpackCall(block->func_graph(), call_function_node, packed_arguments);
  }
  // No keyword arguments and no starred arguments: parse as normal arguments without unpacking.
  auto func_graph = block->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> func_call_nodes;
  func_call_nodes.reserve(group_arguments.size() + 1);
  func_call_nodes.push_back(call_function_node);
  func_call_nodes.insert(func_call_nodes.end(), group_arguments.begin(), group_arguments.end());
  return func_graph->NewCNodeInOrder(func_call_nodes);
}
697
ParseArgsInCall(const FunctionBlockPtr & block,const py::list & args,std::vector<AnfNodePtr> * packed_arguments,std::vector<AnfNodePtr> * group_arguments)698 bool Parser::ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args,
699 std::vector<AnfNodePtr> *packed_arguments, std::vector<AnfNodePtr> *group_arguments) {
700 MS_EXCEPTION_IF_NULL(packed_arguments);
701 MS_EXCEPTION_IF_NULL(group_arguments);
702 bool need_unpack = false;
703 for (size_t i = 0; i < args.size(); i++) {
704 auto arg_node = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, args[i])));
705 if (arg_node == AST_SUB_TYPE_STARRED) {
706 if (!group_arguments->empty()) {
707 packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
708 }
709 packed_arguments->push_back(ParseExprNode(block, python_adapter::GetPyObjAttr(args[i], "value")));
710 group_arguments->clear();
711 need_unpack = true;
712 } else {
713 auto node = ParseExprNode(block, args[i]);
714 node = HandleInterpret(block, node, args[i]);
715 group_arguments->push_back(node);
716 }
717 }
718 if (!group_arguments->empty()) {
719 packed_arguments->push_back(GenerateMakeTuple(block, *group_arguments));
720 }
721 return need_unpack;
722 }
723
ParseKeywordsInCall(const FunctionBlockPtr & block,const py::object & node,std::vector<AnfNodePtr> * packed_arguments)724 bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
725 std::vector<AnfNodePtr> *packed_arguments) {
726 bool need_unpack = false;
727 py::list keywords = python_adapter::GetPyObjAttr(node, "keywords");
728 if (!keywords.empty()) {
729 MS_EXCEPTION_IF_NULL(block);
730 MS_EXCEPTION_IF_NULL(packed_arguments);
731 need_unpack = true;
732 std::vector<AnfNodePtr> keys;
733 std::vector<AnfNodePtr> values;
734 for (size_t index = 0; index < keywords.size(); index++) {
735 auto kw_key = python_adapter::GetPyObjAttr(keywords[index], "arg");
736 auto kw_value = python_adapter::GetPyObjAttr(keywords[index], "value");
737 if (py::isinstance<py::none>(kw_key)) {
738 packed_arguments->push_back(ParseExprNode(block, kw_value));
739 } else {
740 auto kw_key_c = kw_key.cast<std::string>();
741 keys.push_back(NewValueNode(kw_key_c));
742 auto ret_node = ParseExprNode(block, kw_value);
743 ret_node = HandleInterpret(block, ret_node, kw_value);
744 values.push_back(ret_node);
745 }
746 }
747 auto keys_tuple = GenerateMakeTuple(block, keys);
748 auto values_tuple = GenerateMakeTuple(block, values);
749 auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
750 std::vector<AnfNodePtr> make_dict_nodes;
751 make_dict_nodes.push_back(make_dict_op);
752 make_dict_nodes.push_back(keys_tuple);
753 make_dict_nodes.push_back(values_tuple);
754 packed_arguments->push_back(block->func_graph()->NewCNodeInOrder(make_dict_nodes));
755 }
756 return need_unpack;
757 }
758
759 // Process call attributes of class type define, eg: x.y()
ParseAttribute(const FunctionBlockPtr & block,const py::object & node)760 AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::object &node) {
761 MS_LOG(DEBUG) << "Process ast Attribute";
762 // Process class value, eg: self.xx
763 if (ast_->target_type() == PARSE_TARGET_OBJECT_INSTANCE) {
764 if (ast_->IsClassMember(node)) {
765 std::string var_name = "self.";
766 std::string attr_name = node.attr("attr").cast<std::string>();
767 (void)var_name.append(attr_name);
768 auto attr_obj = ast()->obj().attr(attr_name.c_str());
769 MS_EXCEPTION_IF_NULL(block);
770 if (py::hasattr(ast()->obj(), attr_name.c_str()) &&
771 (py::hasattr(attr_obj, PYTHON_PRIMITIVE_FLAG) || py::isinstance<py::int_>(attr_obj) ||
772 py::isinstance<py::float_>(attr_obj) || py::isinstance<py::bool_>(attr_obj) ||
773 py::isinstance<py::str>(attr_obj) || data_converter::IsCellInstance(attr_obj))) {
774 MS_LOG(DEBUG) << "var_name: " << var_name;
775 return block->MakeResolveSymbol(var_name);
776 } else {
777 return block->ReadVariable(var_name);
778 }
779 }
780 }
781
782 // Process the get attr
783 // Use the Primitive replace the operation resolve node (getattr),
784 // because the getattr will eventually be converted to Primitive node
785 AnfNodePtr op_node = NewValueNode(prim::kPrimGetAttr);
786
787 // Process the attr body
788 py::object value_body = python_adapter::GetPyObjAttr(node, "value");
789 AnfNodePtr value_node = ParseExprNode(block, value_body);
790 if (value_node == nullptr) {
791 MS_LOG(EXCEPTION) << "Parse attribute failed";
792 }
793
794 // Process the node attr
795 auto attr_str = python_adapter::GetPyObjAttr(node, "attr").cast<std::string>();
796 MS_LOG(DEBUG) << "Attr = " << attr_str;
797 AnfNodePtr attr_node = nullptr;
798 {
799 TraceGuard guard(GetLocation(python_adapter::GetPyObjAttr(node, "attr")));
800 attr_node = NewValueNode(attr_str);
801 }
802
803 // Create the apply node
804 auto attr_cnode = block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node});
805 if (value_node->interpret()) {
806 attr_cnode->set_interpret(true);
807 }
808 return attr_cnode;
809 }
810
811 // Process comparison expression : a == b. a > b etc.
ParseCompare(const FunctionBlockPtr & block,const py::object & node)812 AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object &node) {
813 MS_LOG(DEBUG) << "Process ast Compare";
814 TraceGuard guard(GetLocation(node));
815
816 // For python comparison ,there may be if x>y>5 ,
817 // Which there is two ops , but we only support one now
818 py::list ops = python_adapter::GetPyObjAttr(node, "ops");
819 if (ops.size() != MAX_COMPARISON_OPS_SUPPORTED) {
820 MS_EXCEPTION(NotSupportError) << "Only support comparison with 1 operator, but got " << ops.size() << ", which is "
821 << py::str(ops);
822 }
823
824 py::object left = python_adapter::GetPyObjAttr(node, "left");
825 py::list comparators = python_adapter::GetPyObjAttr(node, "comparators");
826 if (comparators.empty()) {
827 MS_LOG(EXCEPTION) << "Comparators can't be empty.";
828 }
829 AnfNodePtr left_node = ParseExprNode(block, left);
830 AnfNodePtr right_node = ParseExprNode(block, comparators[0]);
831
832 MS_EXCEPTION_IF_NULL(block);
833 AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]);
834 return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node});
835 }
836
ProcessBoolOpValueList(const FunctionBlockPtr & block,const py::list & value_list,AstSubType mode)837 AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) {
838 // If there is only one bool op now
839 MS_EXCEPTION_IF_NULL(block);
840 if (value_list.empty()) {
841 MS_LOG(EXCEPTION) << "value list is empty.";
842 }
843 if (value_list.size() == 1) {
844 AnfNodePtr first_node = ParseExprNode(block, value_list[0]);
845 return first_node;
846 } else {
847 py::object first = value_list[0];
848 py::list rest;
849 for (size_t i = 1; i < value_list.size(); i++) {
850 rest.append(value_list[i]);
851 }
852 FunctionBlockPtr true_block = nullptr;
853 FunctionBlockPtr false_block = nullptr;
854 auto block_fg = block->func_graph();
855 {
856 TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block_fg->debug_info()));
857 true_block = MakeFunctionBlock(*this);
858 }
859 {
860 TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block_fg->debug_info()));
861 false_block = MakeFunctionBlock(*this);
862 }
863 MakeConditionBlocks(block, true_block, false_block);
864 FunctionBlockPtr b1, b2;
865
866 // If it is and, we need to process the rest nodes;
867 // If it is or, we continue to next
868 if (mode == AST_SUB_TYPE_AND) {
869 b1 = true_block;
870 b2 = false_block;
871 } else if (mode == AST_SUB_TYPE_OR) {
872 b2 = true_block;
873 b1 = false_block;
874 } else {
875 MS_LOG(ERROR) << "Not supported mode: " << mode;
876 return nullptr;
877 }
878 AnfNodePtr test_node = ParseExprNode(block, first);
879 AnfNodePtr rest_node = ProcessBoolOpValueList(b1, rest, mode);
880 b1->func_graph()->set_output(rest_node);
881 b2->func_graph()->set_output(test_node);
882
883 auto cond_node = block->ForceToBoolNode(test_node);
884 auto switch_app =
885 block_fg->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()),
886 NewValueNode(false_block->func_graph())});
887
888 std::vector<AnfNodePtr> call_graph_nodes{switch_app};
889 auto switch_app_call = block_fg->NewCNodeInOrder(call_graph_nodes);
890 return switch_app_call;
891 }
892 }
893
894 // Process comparison expression : a and b. a or b .
ParseBoolOp(const FunctionBlockPtr & block,const py::object & node)895 AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object &node) {
896 MS_LOG(DEBUG) << "Process ast BoolOp";
897 py::object op_node = python_adapter::GetPyObjAttr(node, "op");
898 AstSubType op_type = ast_->GetOpType(op_node);
899 if (op_type == AST_SUB_TYPE_UNKNOWN) {
900 MS_LOG(EXCEPTION) << "ProcessBoolOp, got unknown op type";
901 }
902 py::list op_values = python_adapter::GetPyObjAttr(node, "values");
903 return ProcessBoolOpValueList(block, op_values, op_type);
904 }
905
906 // Process a function def
ParseFunctionDef(const FunctionBlockPtr & block,const py::object & node)907 FunctionBlockPtr Parser::ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node) {
908 MS_LOG(DEBUG) << "Process ast FunctionDef";
909 FunctionBlockPtr function_block = ParseFunction(node, block);
910 MS_EXCEPTION_IF_NULL(function_block);
911
912 // Get function name
913 py::str name = python_adapter::GetPyObjAttr(node, "name");
914 std::string function_name = name;
915 ValueNodePtr valuenode_graph = NewValueNode(function_block->func_graph());
916 block->WriteVariable(function_name, valuenode_graph);
917 return block;
918 }
919
920 // Process a lambda expression . like lambda x,y: x + y
ParseLambda(const FunctionBlockPtr & block,const py::object & node)921 AnfNodePtr Parser::ParseLambda(const FunctionBlockPtr &block, const py::object &node) {
922 MS_LOG(DEBUG) << "Process ast Lambda";
923 FunctionBlockPtr func_block = MakeFunctionBlock(*this);
924 func_block->AddPrevBlock(block);
925 func_block->Mature();
926
927 // Get lambda args
928 py::list args = ast_->GetArgs(node);
929 auto block_fg = func_block->func_graph();
930 for (std::size_t i = 0; i < args.size(); i++) {
931 std::string arg = py::cast<std::string>(args[i].attr("arg"));
932 TraceGuard guard(GetLocation(args[i]));
933 auto para_node = std::make_shared<Parameter>(block_fg);
934 para_node->debug_info()->set_name(arg);
935 block_fg->add_parameter(para_node);
936 func_block->WriteVariable(arg, para_node);
937 MS_LOG(DEBUG) << "The arg[" << i << "] is " << arg;
938 }
939
940 py::object body_node = python_adapter::GetPyObjAttr(node, "body");
941 AnfNodePtr lambda_body_node = ParseExprNode(func_block, body_node);
942 block_fg->set_output(lambda_body_node);
943 ValueNodePtr const_graph = NewValueNode(block_fg);
944 return const_graph;
945 }
946
947 // Process a tuple
ParseTuple(const FunctionBlockPtr & block,const py::object & node)948 AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &node) {
949 MS_LOG(DEBUG) << "Process ast Tuple";
950 MS_EXCEPTION_IF_NULL(block);
951 py::tuple elts = python_adapter::GetPyObjAttr(node, "elts");
952 if (elts.empty()) {
953 auto empty_tuple = std::vector<ValuePtr>();
954 return NewValueNode(std::make_shared<ValueTuple>(empty_tuple));
955 }
956
957 std::vector<AnfNodePtr> tuple_vec;
958 AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
959 tuple_vec.emplace_back(make_tuple_op);
960 for (size_t i = 0; i < elts.size(); i++) {
961 AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
962 tuple_vec.emplace_back(node_ptr);
963 }
964 CNodePtr tuple_app = block->func_graph()->NewCNodeInOrder(tuple_vec);
965 return tuple_app;
966 }
967
968 // Process a list
ParseList(const FunctionBlockPtr & block,const py::object & node)969 AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &node) {
970 MS_LOG(DEBUG) << "Process ast List";
971 MS_EXCEPTION_IF_NULL(block);
972 py::list elts = python_adapter::GetPyObjAttr(node, "elts");
973 if (elts.empty()) {
974 auto empty_list = std::vector<ValuePtr>();
975 return NewValueNode(std::make_shared<ValueList>(empty_list));
976 }
977
978 std::vector<AnfNodePtr> list_vec;
979 AnfNodePtr make_list_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKELIST);
980 list_vec.emplace_back(make_list_op);
981 for (size_t i = 0; i < elts.size(); i++) {
982 AnfNodePtr node_ptr = ParseExprNode(block, elts[i]);
983 list_vec.emplace_back(node_ptr);
984 }
985 CNodePtr list_app = block->func_graph()->NewCNodeInOrder(list_vec);
986 return list_app;
987 }
988
989 // Process a subscript, such as x[y] , node expressed as value[slice]
ParseSubscript(const FunctionBlockPtr & block,const py::object & node)990 AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::object &node) {
991 MS_LOG(DEBUG) << "Process ast Subscript";
992 MS_EXCEPTION_IF_NULL(block);
993 AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
994 py::object value_node = python_adapter::GetPyObjAttr(node, "value");
995 py::object slice_node = python_adapter::GetPyObjAttr(node, "slice");
996 AnfNodePtr value = ParseExprNode(block, value_node);
997 AnfNodePtr slice = ParseExprNode(block, slice_node);
998 return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice});
999 }
1000
1001 // Process a slice, get the slice value
ParseSlice(const FunctionBlockPtr & block,const py::object & node)1002 AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &node) {
1003 MS_LOG(DEBUG) << "Process ast Slice";
1004 MS_EXCEPTION_IF_NULL(block);
1005 AnfNodePtr op_makeslice = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKESLICE);
1006 py::object start = python_adapter::GetPyObjAttr(node, "lower");
1007 py::object stop = python_adapter::GetPyObjAttr(node, "upper");
1008 py::object step = python_adapter::GetPyObjAttr(node, "step");
1009 AnfNodePtr start_node = ParseExprNode(block, start);
1010 AnfNodePtr stop_node = ParseExprNode(block, stop);
1011 AnfNodePtr step_node = ParseExprNode(block, step);
1012 return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node});
1013 }
1014
1015 // Process a extslice
ParseExtSlice(const FunctionBlockPtr & block,const py::object & node)1016 AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object &node) {
1017 MS_LOG(DEBUG) << "Process ast ExtSlice";
1018 MS_EXCEPTION_IF_NULL(block);
1019 AnfNodePtr make_tuple_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKETUPLE);
1020 py::tuple slice_tuple = python_adapter::GetPyObjAttr(node, "dims");
1021
1022 std::vector<AnfNodePtr> node_vec;
1023 node_vec.emplace_back(make_tuple_op);
1024 for (size_t i = 0; i < slice_tuple.size(); i++) {
1025 AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]);
1026 node_vec.emplace_back(node_ptr);
1027 }
1028 CNodePtr tuple_conde = block->func_graph()->NewCNodeInOrder(node_vec);
1029 return tuple_conde;
1030 }
1031
1032 // Process a index, get the index number
ParseIndex(const FunctionBlockPtr & block,const py::object & node)1033 AnfNodePtr Parser::ParseIndex(const FunctionBlockPtr &block, const py::object &node) {
1034 MS_LOG(DEBUG) << "Process ast Index";
1035 py::object value_node = python_adapter::GetPyObjAttr(node, "value");
1036 return ParseExprNode(block, value_node);
1037 }
1038
1039 // Process a UnaryOp, +a, -b
ParseUnaryOp(const FunctionBlockPtr & block,const py::object & node)1040 AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node) {
1041 MS_LOG(DEBUG) << "Process ast UnaryOp";
1042 py::object op = python_adapter::GetPyObjAttr(node, "op");
1043
1044 MS_EXCEPTION_IF_NULL(block);
1045 // Resolve the op
1046 AnfNodePtr op_node = block->MakeResolveAstOp(op);
1047
1048 py::object operand = python_adapter::GetPyObjAttr(node, "operand");
1049 AnfNodePtr operand_node = ParseExprNode(block, operand);
1050 return block->func_graph()->NewCNodeInOrder({op_node, operand_node});
1051 }
1052
1053 // Process a dict ast node expression
ParseDictByKeysAndValues(const FunctionBlockPtr & block,const std::vector<AnfNodePtr> & key_nodes,const std::vector<AnfNodePtr> & value_nodes)1054 AnfNodePtr Parser::ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
1055 const std::vector<AnfNodePtr> &value_nodes) {
1056 auto keys_tuple = GenerateMakeTuple(block, key_nodes);
1057 auto values_tuple = GenerateMakeTuple(block, value_nodes);
1058 MS_EXCEPTION_IF_NULL(block);
1059 auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT);
1060 return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple});
1061 }
1062
ParseDict(const FunctionBlockPtr & block,const py::object & node)1063 AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &node) {
1064 MS_LOG(DEBUG) << "Process ast Dict";
1065 py::list keys = node.attr("keys");
1066 py::list values = node.attr("values");
1067 std::vector<AnfNodePtr> key_nodes;
1068 std::vector<AnfNodePtr> value_nodes;
1069 for (size_t i = 0; i < keys.size(); i++) {
1070 key_nodes.push_back(ParseExprNode(block, keys[i]));
1071 value_nodes.push_back(ParseExprNode(block, values[i]));
1072 }
1073 return ParseDictByKeysAndValues(block, key_nodes, value_nodes);
1074 }
1075
1076 // Process a augment assign such as a += b or mat[stride_slice] += b.
ParseAugAssign(const FunctionBlockPtr & block,const py::object & node)1077 FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py::object &node) {
1078 MS_LOG(DEBUG) << "Process ast AugAssign";
1079 MS_EXCEPTION_IF_NULL(block);
1080 MS_EXCEPTION_IF_NULL(ast_);
1081
1082 py::object target_object = python_adapter::GetPyObjAttr(node, "target");
1083 py::object op_object = python_adapter::GetPyObjAttr(node, "op");
1084 py::object value_object = python_adapter::GetPyObjAttr(node, "value");
1085 AnfNodePtr target_node = nullptr;
1086 AnfNodePtr op_node = block->MakeResolveAstOp(op_object);
1087 AnfNodePtr value_node = ParseExprNode(block, value_object);
1088 auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
1089
1090 if (ast_type == AST_SUB_TYPE_NAME) {
1091 target_node = ParseName(block, target_object);
1092 } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
1093 target_node = ParseSubscript(block, target_object);
1094 } else if (ast_->IsClassMember(target_object)) {
1095 target_node = ParseAttribute(block, target_object);
1096 } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
1097 TraceGuard(GetLocation(target_object));
1098 MS_EXCEPTION(TypeError) << "Only support augassign to attribute of self, but got attribute of "
1099 << py::str(target_object.attr("value").attr("id")) << ".\n"
1100 << "More details please refer to syntax support at https://www.mindspore.cn";
1101 } else {
1102 TraceGuard(GetLocation(target_object));
1103 MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got "
1104 << target_object.get_type()
1105 << ".\nMore details please refer to syntax support at https://www.mindspore.cn";
1106 }
1107
1108 if (target_node == nullptr) {
1109 MS_LOG(EXCEPTION) << "Can not get target node ";
1110 }
1111 CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node});
1112 WriteAssignVars(block, target_object, augassign_app);
1113 return block;
1114 }
1115 // Process global declaration such as 'global x';
ParseGlobal(const FunctionBlockPtr & block,const py::object & node)1116 FunctionBlockPtr Parser::ParseGlobal(const FunctionBlockPtr &block, const py::object &node) {
1117 MS_LOG(DEBUG) << "Process ast Global";
1118 MS_EXCEPTION_IF_NULL(block);
1119 py::list vars = python_adapter::GetPyObjAttr(node, "names");
1120 for (auto &item : vars) {
1121 block->AddGlobalVar(py::cast<std::string>(item));
1122 }
1123 return block;
1124 }
1125
1126 // Process a if statement
ParseIf(const FunctionBlockPtr & block,const py::object & node)1127 FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object &node) {
1128 MS_LOG(DEBUG) << "Process ast If";
1129 py::object test_node = python_adapter::GetPyObjAttr(node, "test");
1130 AnfNodePtr condition_node = ParseExprNode(block, test_node);
1131 MS_EXCEPTION_IF_NULL(block);
1132 CNodePtr bool_node = block->ForceToBoolNode(condition_node);
1133
1134 FunctionBlockPtr true_block = nullptr;
1135 FunctionBlockPtr false_block = nullptr;
1136 auto block_fg = block->func_graph();
1137 {
1138 TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block_fg->debug_info()));
1139 true_block = MakeFunctionBlock(*this);
1140 }
1141 {
1142 TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block_fg->debug_info()));
1143 false_block = MakeFunctionBlock(*this);
1144 }
1145
1146 MakeConditionBlocks(block, true_block, false_block);
1147
1148 FunctionBlockPtr after_block = nullptr;
1149 {
1150 TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block_fg->debug_info()));
1151 after_block = MakeFunctionBlock(*this);
1152 }
1153
1154 if (MsContext::GetInstance()->backend_policy() != "ge") {
1155 // For backends excludes 'ge', it can handle multi graph call, use this flag to
1156 // generate call not inline `after_block` graph to reduce if by if switch expansion.
1157 after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
1158 }
1159
1160 // Process the if-true branch
1161 py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
1162 FunctionBlockPtr true_end = ParseStatements(true_block, bodyNode);
1163
1164 // If the return_ is set, it has its own continuation block
1165 if (true_end->func_graph()->get_return() == nullptr) {
1166 MS_LOG(DEBUG) << "true end jump to after.";
1167 true_end->Jump(after_block, {});
1168 }
1169
1170 // Process the orelse branch
1171 py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
1172 FunctionBlockPtr false_end = ParseStatements(false_block, orelseNode);
1173
1174 // If the return_ is set, it has its own continuation block
1175 if (false_end->func_graph()->get_return() == nullptr) {
1176 MS_LOG(DEBUG) << "false_end jump to after.";
1177 false_end->Jump(after_block, {});
1178 }
1179
1180 block->ConditionalJump(bool_node, true_block, false_block);
1181 if (after_block->prev_blocks().empty()) {
1182 after_block->SetAsDeadBlock();
1183 }
1184 after_block->Mature();
1185 return after_block;
1186 }
1187
// Process a while loop.
// Three blocks are created: a header that evaluates the condition, a body and an after-block.
// The current block jumps to the header. The header conditionally jumps to the body or the
// after-block. A body that does not end with a return jumps back to the header. The body is parsed
// inside a LoopContext registered on loops_ (presumably consulted by break/continue handling).
// Returns the loop's end block if a `break` created one (the after-block then jumps to it),
// otherwise the after-block.
FunctionBlockPtr Parser::ParseWhile(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast While";
  MS_EXCEPTION_IF_NULL(block);
  FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceWhileHeader>(block->func_graph()->debug_info()));
  FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceWhileBody>(block->func_graph()->debug_info()));
  FunctionBlockPtr after_block = GenerateBlock(std::make_shared<TraceWhileAfter>(block->func_graph()->debug_info()));

  // The body and the after-block are both entered from the header.
  body_block->AddPrevBlock(header_block);
  after_block->AddPrevBlock(header_block);
  block->Jump(header_block, {});

  // Header: evaluate the loop condition and branch.
  py::object test_obj = python_adapter::GetPyObjAttr(node, "test");
  AnfNodePtr loop_cond = header_block->ForceToWhileCond(ParseExprNode(header_block, test_obj));
  body_block->Mature();
  header_block->ConditionalJump(loop_cond, body_block, after_block);

  // Parse loop body statements with loop context.
  LoopContext loop_context{&loops_, header_block, nullptr};
  py::object body_obj = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr body_end = ParseStatements(body_block, body_obj);
  if (body_end->func_graph()->get_return() == nullptr) {
    body_end->Jump(header_block, {});
  }
  header_block->Mature();
  after_block->Mature();

  auto &end_block = loop_context.EndBlock();
  if (!end_block) {
    // No 'break' in the body, so no end block was created.
    return after_block;
  }
  after_block->Jump(end_block, {});
  end_block->Mature();
  return end_block;
}
1236
GenerateIteratorInFor(const FunctionBlockPtr & block,const py::object & node,const AnfNodePtr & op_iter)1237 CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py::object &node,
1238 const AnfNodePtr &op_iter) {
1239 py::object iter_node = python_adapter::GetPyObjAttr(node, "iter");
1240 AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
1241 return block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});
1242 }
1243
GenerateCondInFor(const ParameterPtr & iter_param,const FunctionBlockPtr & header_block,const AnfNodePtr & op_hasnext)1244 CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
1245 const AnfNodePtr &op_hasnext) {
1246 MS_EXCEPTION_IF_NULL(header_block);
1247 return header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});
1248 }
1249
GenerateBlock(const TraceInfoPtr & trace_info)1250 FunctionBlockPtr Parser::GenerateBlock(const TraceInfoPtr &trace_info) {
1251 TraceGuard trace_guard(trace_info);
1252 FunctionBlockPtr block = MakeFunctionBlock(*this);
1253 MS_EXCEPTION_IF_NULL(block);
1254 return block;
1255 }
1256
GetForTransToWhileLoop()1257 int64_t Parser::GetForTransToWhileLoop() {
1258 // int64 support 63bits positive num mostly.
1259 constexpr auto max_num_length = 10;
1260 if (max_for_loop_count_str_.size() > max_num_length || max_for_loop_count_str_.empty()) {
1261 return MAX_FOR_LOOP_COUNT;
1262 }
1263 if (std::any_of(max_for_loop_count_str_.begin(), max_for_loop_count_str_.end(),
1264 [](char c) { return c < '0' || c > '9'; })) {
1265 return MAX_FOR_LOOP_COUNT;
1266 }
1267 int64_t loop_count;
1268 std::stringstream ss;
1269 ss << max_for_loop_count_str_;
1270 ss >> loop_count;
1271 return loop_count;
1272 }
1273
1274 // A for loop will generate 3 functions :the test, the body, and the continuation
1275 // for x in xs:
1276 // body
1277 // It is compiled to be following statement
1278 // if len(xs) < max_loop_cnt, ParseForIter. Use iter to implement for loop, which always unroll loop
1279 // else, ParseForLoop. Use loop var to implement for loop, which always sink loop
ParseFor(const FunctionBlockPtr & block,const py::object & node)1280 FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::object &node) {
1281 MS_LOG(DEBUG) << "Process ast For, create an if else statement";
1282 MS_EXCEPTION_IF_NULL(block);
1283 // Create statement 'len(xs) < MAX_FOR_LOOP_COUNT'
1284 AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
1285 py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER);
1286 AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
1287 CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
1288 CNodePtr bool_node = block->func_graph()->NewCNodeInOrder(
1289 {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())});
1290
1291 // Create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop'
1292 FunctionBlockPtr true_block = nullptr;
1293 FunctionBlockPtr false_block = nullptr;
1294 {
1295 TraceGuard guard(std::make_shared<TraceIfStmtTrueBranch>(block->func_graph()->debug_info()));
1296 true_block = MakeFunctionBlock(*this);
1297 }
1298 {
1299 TraceGuard guard(std::make_shared<TraceIfStmtFalseBranch>(block->func_graph()->debug_info()));
1300 false_block = MakeFunctionBlock(*this);
1301 }
1302
1303 MakeConditionBlocks(block, true_block, false_block);
1304
1305 FunctionBlockPtr after_block = nullptr;
1306 {
1307 TraceGuard guard(std::make_shared<TraceIfStmtAfterBranch>(block->func_graph()->debug_info()));
1308 after_block = MakeFunctionBlock(*this);
1309 }
1310
1311 FunctionBlockPtr true_end = ParseForIter(true_block, node);
1312 true_end->Jump(after_block, {});
1313
1314 FunctionBlockPtr false_end = ParseForLoop(false_block, node);
1315 false_end->Jump(after_block, {});
1316
1317 block->ConditionalJump(bool_node, true_block, false_block);
1318 after_block->Mature();
1319 return after_block;
1320 }
1321
1322 // A for loop will generate 3 functions :the test, the body, and the continuation
1323 // for x in xs:
1324 // body
1325 // It is compiled to be following statement
1326 // it = iter(xs)
1327 // while hastnext(it)
1328 // x, it = next(it)
1329 // body
ParseForIter(const FunctionBlockPtr & block,const py::object & node)1330 FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::object &node) {
1331 MS_LOG(DEBUG) << "Process ast For";
1332 MS_EXCEPTION_IF_NULL(block);
1333 AnfNodePtr op_iter = block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
1334 AnfNodePtr op_next = block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
1335 AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
1336 AnfNodePtr op_hasnext = block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
1337 // Generate the iterator apply
1338 CNodePtr iter_apply = GenerateIteratorInFor(block, node, op_iter);
1339 MS_EXCEPTION_IF_NULL(iter_apply);
1340 FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
1341 MS_EXCEPTION_IF_NULL(header_block);
1342 // Generate the hasnext apply which is a condition
1343 ParameterPtr iter_param = header_block->func_graph()->add_parameter();
1344 CNodePtr cond_apply = GenerateCondInFor(iter_param, header_block, op_hasnext);
1345 // Generate the body of the for statement
1346 FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
1347 MS_EXCEPTION_IF_NULL(body_block);
1348 body_block->AddPrevBlock(header_block);
1349 // Generate the iterator next apply
1350 // Process as following: `app = next(it); target = app[0]; it = app[1];`
1351 CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
1352 CNodePtr target_app =
1353 body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(0))});
1354 py::object target_node = python_adapter::GetPyObjAttr(node, "target");
1355
1356 CNodePtr iter2_app =
1357 body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast<int64_t>(1))});
1358 WriteAssignVars(body_block, target_node, target_app);
1359
1360 // Link the variable name with the target
1361 auto it_info = std::make_shared<TraceIterator>(target_app->debug_info());
1362 iter_param->debug_info()->set_trace_info(it_info);
1363 iter2_app->debug_info()->set_trace_info(it_info);
1364 iter_apply->debug_info()->set_trace_info(it_info);
1365
1366 FunctionBlockPtr after_block = nullptr;
1367 {
1368 TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
1369 after_block = MakeFunctionBlock(*this);
1370 }
1371 MS_EXCEPTION_IF_NULL(after_block);
1372 after_block->AddPrevBlock(header_block);
1373
1374 block->Jump(header_block, {iter_apply});
1375 body_block->Mature();
1376 header_block->ConditionalJump(cond_apply, body_block, after_block);
1377
1378 // Parse loop body statements with loop context.
1379 LoopContext loop_context{&loops_, header_block, iter2_app};
1380 py::object body_node = python_adapter::GetPyObjAttr(node, "body");
1381 FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
1382 if (after_body_block->func_graph()->get_return() == nullptr) {
1383 after_body_block->Jump(header_block, {iter2_app});
1384 }
1385
1386 header_block->Mature();
1387 after_block->Mature();
1388 auto &end_block = loop_context.EndBlock();
1389 if (end_block) {
1390 // end_block exists if we encounter 'break' in loop body.
1391 after_block->Jump(end_block, {});
1392 end_block->Mature();
1393 return end_block;
1394 }
1395 // No 'break', no end_block.
1396 return after_block;
1397 }
1398
// Parse a `for` statement by lowering it to a loop driven by an index variable (instead of iter/next).
// A for loop generates 3 function graphs: the header (the loop test), the body, and the after block
// (the continuation):
//   for x in xs:
//     body
// is compiled to the equivalent of:
//   i = 0
//   while i < len(xs):
//     x = xs[i]
//     i = i + 1
//     body
// Both `len(xs)` and the counter are converted to tensors with ScalarToTensor, and the test uses the `Less`
// operation and the increment uses `Add`; per the comment below this is done so the loop is not unrolled.
// @param block The block in which the `for` statement appears.
// @param node  The Python ast `For` node.
// @return The block in which parsing continues after the loop: the loop's end block when the body contains a
//         `break` (that end block is created by ParseBreak), otherwise the after block.
FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::object &node) {
  MS_LOG(DEBUG) << "Process ast For by loop variable";
  MS_EXCEPTION_IF_NULL(block);
  AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN);
  AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);

  // Get variable name of 'x' in statement 'for x in xs'
  py::object target_node = python_adapter::GetPyObjAttr(node, "target");

  // Create statement 'len(xs)'
  py::object iter_obj = python_adapter::GetPyObjAttr(node, "iter");
  AnfNodePtr iter_node = ParseExprNode(block, iter_obj);
  MS_EXCEPTION_IF_NULL(iter_node);
  // Generate node for loop count and convert it to tensor, to make the loop not unroll
  CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node});
  auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations");
  auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});

  // len_iter lives in the enclosing graph and is used by the header graph as a free variable.
  CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len});

  FunctionBlockPtr header_block = GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(header_block);
  // Create loop variable 'i'
  // The counter is a parameter of the header graph; each jump into the header supplies its current value.
  ParameterPtr loop_var = header_block->func_graph()->add_parameter();
  // Create loop condition 'i < len(xs)'
  auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations");
  auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)});
  CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter});

  // Generate the body of the for statement
  FunctionBlockPtr body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  MS_EXCEPTION_IF_NULL(body_block);
  body_block->AddPrevBlock(header_block);
  // Create 'x = xs[i]'
  auto body_func_graph = body_block->func_graph();
  CNodePtr target_var = body_func_graph->NewCNodeInOrder({op_getitem, iter_node, loop_var});
  WriteAssignVars(body_block, target_node, target_var);
  // Create 'i = i + 1'
  auto prim_add = prim::GetPythonOps("Add", "mindspore.ops.operations");
  auto add_node = body_func_graph->NewCNodeInOrder({NewValueNode(prim_add)});
  // A separate ScalarToTensor instance is created in the body graph for the constant 1.
  auto body_scalar_to_tensor_node = body_func_graph->NewCNodeInOrder({NewValueNode(scalar_to_tensor)});
  auto add_tensor_node =
    body_func_graph->NewCNodeInOrder({body_scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(1))});
  CNodePtr loop_var_inc = body_func_graph->NewCNodeInOrder({add_node, loop_var, add_tensor_node});
  // NOTE(review): loop_var is never given a name in this function, so this relies on add_parameter()
  // assigning a usable default name — confirm.
  body_block->WriteVariable(loop_var->name(), loop_var_inc);

  // Link the variable name with the target
  auto it_info = std::make_shared<TraceIterator>(loop_var_inc->debug_info());
  loop_var->debug_info()->set_trace_info(it_info);
  len_iter->debug_info()->set_trace_info(it_info);

  FunctionBlockPtr after_block = nullptr;
  {
    TraceGuard guard(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
    after_block = MakeFunctionBlock(*this);
  }
  MS_EXCEPTION_IF_NULL(after_block);
  after_block->AddPrevBlock(header_block);

  // Enter the header with i = 0 (as a tensor).
  CNodePtr zero_tensor =
    block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, NewValueNode(static_cast<int64_t>(0))});
  block->Jump(header_block, {zero_tensor});
  // header_block, registered above, is body_block's only predecessor; the body is matured before its statements are parsed.
  body_block->Mature();

  // NOTE(review): the trailing `false` argument is interpreted by FunctionBlock::ConditionalJump — confirm its
  // meaning there.
  header_block->ConditionalJump(cond_node, body_block, after_block, false);

  // Parse loop body statements with loop context.
  // loop_context presumably registers this loop in loops_ for its lifetime; ParseBreak/ParseContinue read
  // loops_.top(), and `continue` passes loop_var_inc back to the header.
  LoopContext loop_context{&loops_, header_block, loop_var_inc};
  py::object body_node = python_adapter::GetPyObjAttr(node, "body");
  FunctionBlockPtr after_body_block = ParseStatements(body_block, body_node);
  // Close the back edge unless the body already ended with a return.
  if (after_body_block->func_graph()->get_return() == nullptr) {
    after_body_block->Jump(header_block, {loop_var_inc});
  }

  // The header is matured only now, after all jumps back into it (body end, `continue`) have been created.
  header_block->Mature();
  after_block->Mature();
  auto &end_block = loop_context.EndBlock();
  if (end_block) {
    // end_block exists if we encounter 'break' in loop body.
    after_block->Jump(end_block, {});
    end_block->Mature();
    return end_block;
  }
  // No 'break', no end_block.
  return after_block;
}
1494
ParseIfExp(const FunctionBlockPtr & block,const py::object & node)1495 AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &node) {
1496 MS_LOG(DEBUG) << "Process ast IfExp";
1497 MS_EXCEPTION_IF_NULL(block);
1498 py::object test_node = python_adapter::GetPyObjAttr(node, "test");
1499 AnfNodePtr condition_node = ParseExprNode(block, test_node);
1500 CNodePtr bool_node = block->ForceToBoolNode(condition_node);
1501
1502 FunctionBlockPtr true_block = nullptr;
1503 FunctionBlockPtr false_block = nullptr;
1504 {
1505 TraceGuard guard(std::make_shared<TraceIfExpTrueBranch>(block->func_graph()->debug_info()));
1506 true_block = MakeFunctionBlock(*this);
1507 }
1508 {
1509 TraceGuard guard(std::make_shared<TraceIfExpFalseBranch>(block->func_graph()->debug_info()));
1510 false_block = MakeFunctionBlock(*this);
1511 }
1512
1513 MakeConditionBlocks(block, true_block, false_block);
1514
1515 // Process the if-true branch
1516 py::object bodyNode = python_adapter::GetPyObjAttr(node, "body");
1517 true_block->func_graph()->debug_info()->set_location(GetLocation(bodyNode));
1518 AnfNodePtr true_node = ParseExprNode(true_block, bodyNode);
1519
1520 // Process the orelse branch
1521 py::object orelseNode = python_adapter::GetPyObjAttr(node, "orelse");
1522 false_block->func_graph()->debug_info()->set_location(GetLocation(orelseNode));
1523 AnfNodePtr false_node = ParseExprNode(false_block, orelseNode);
1524
1525 true_block->func_graph()->set_output(true_node);
1526 false_block->func_graph()->set_output(false_node);
1527
1528 // Use the Primitive replace the operation resolve node (switch),
1529 // because the switch will eventually be converted to Primitive node
1530 CNodePtr switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), bool_node,
1531 NewValueNode(true_block->func_graph()),
1532 NewValueNode(false_block->func_graph())});
1533
1534 std::vector<AnfNodePtr> call_graph_nodes{switch_app};
1535 CNodePtr switch_app_call = block->func_graph()->NewCNodeInOrder(call_graph_nodes);
1536 return switch_app_call;
1537 }
1538
// Build the function graphs for the single generator of a list comprehension and return the top block.
// The graphs are:
//  - top:    evaluates `iter(<iter expr>)` and jumps to the header with (iterator, []);
//  - header: parameters (iter, list); tests `hasnext(iter)` and branches to body or after;
//  - body:   `next(iter)` -> (item, new_iter); binds the comprehension target to item, applies `ifs`/`elt`
//            through ParseListCompIfs, and jumps back to the header with (new_iter, updated list);
//  - after:  returns `list`.
// Calling the top block's graph with no arguments therefore yields the resulting list.
// @param block          The block containing the comprehension; the `iter` expression is parsed in it.
// @param node           The Python ast `ListComp` node (its `elt` is used by ParseListCompIfs).
// @param generator_node The single `comprehension` node (`target`, `iter`, `ifs`).
// @return The top block.
FunctionBlockPtr Parser::ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
                                           const py::object &generator_node) {
  // Create the top (entry) block of the comprehension.
  FunctionBlockPtr top_block = GenerateBlock(std::make_shared<TraceListComp>(block->func_graph()->debug_info()));
  // Handle iter attribute.
  // The iterable expression itself is parsed in the enclosing block; only `iter(...)` is applied in the top graph.
  py::object iter_node = python_adapter::GetPyObjAttr(generator_node, "iter");
  AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node);
  AnfNodePtr op_iter = top_block->MakeResolveOperation(NAMED_PRIMITIVE_ITER);
  CNodePtr iter_apply = top_block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node});

  // Create header graph.
  FunctionBlockPtr list_header_block =
    GenerateBlock(std::make_shared<TraceForHeader>(block->func_graph()->debug_info()));
  list_header_block->AddPrevBlock(top_block);

  // Create hasNext apply.
  AnfNodePtr op_hasnext = top_block->MakeResolveOperation(NAMED_PRIMITIVE_HASNEXT);
  ParameterPtr iter_param = list_header_block->func_graph()->add_parameter();
  constexpr auto iter_param_name = "iter";
  iter_param->set_name(iter_param_name);
  iter_param->debug_info()->set_name(iter_param_name);
  CNodePtr cond_apply = list_header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param});

  // Call the header graph with iter.
  // Second header parameter: the list accumulated so far, starting as an empty list.
  ParameterPtr list_param = list_header_block->func_graph()->add_parameter();
  constexpr auto list_param_name = "list";
  list_param->set_name(list_param_name);
  list_param->debug_info()->set_name(list_param_name);
  auto empty_list = std::vector<ValuePtr>();
  AnfNodePtr empty_list_node = NewValueNode(std::make_shared<ValueList>(empty_list));
  top_block->Jump(list_header_block, {iter_apply, empty_list_node});

  // Create body graph.
  FunctionBlockPtr list_body_block = GenerateBlock(std::make_shared<TraceForBody>(block->func_graph()->debug_info()));
  list_body_block->AddPrevBlock(list_header_block);
  AnfNodePtr op_next = top_block->MakeResolveOperation(NAMED_PRIMITIVE_NEXT);
  CNodePtr next_apply = list_body_block->func_graph()->NewCNodeInOrder({op_next, iter_param});
  AnfNodePtr op_getitem = top_block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
  // next(iter) is indexed as a pair: [0] is the item, [1] is the advanced iterator.
  CNodePtr item_apply =
    list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(0))});
  CNodePtr new_iter =
    list_body_block->func_graph()->NewCNodeInOrder({op_getitem, next_apply, NewValueNode(static_cast<int64_t>(1))});

  // Save the `target` in a variable.
  py::object gen_target_node = python_adapter::GetPyObjAttr(generator_node, "target");
  WriteAssignVars(list_body_block, gen_target_node, item_apply);

  // The list after this iteration (elt appended when all `ifs` hold), then loop back to the header.
  auto ifs_new_list = ParseListCompIfs(list_body_block, list_param, node, generator_node);
  list_body_block->Jump(list_header_block, {new_iter, ifs_new_list});

  // Create after graph.
  FunctionBlockPtr list_after_block = GenerateBlock(std::make_shared<TraceForAfter>(block->func_graph()->debug_info()));
  list_after_block->AddPrevBlock(list_header_block);
  // Return the list in after graph.
  list_after_block->func_graph()->set_output(list_param);

  // Run the branches.
  list_header_block->ConditionalJump(cond_apply, list_body_block, list_after_block);

  // All predecessors are known only now, so every block is matured at the end.
  top_block->Mature();
  list_header_block->Mature();
  list_body_block->Mature();
  list_after_block->Mature();
  return top_block;
}
1604
// Build the `ifs` filter and the `elt` append for one iteration of a list comprehension.
// The `ifs` conditions are combined with `and`; with no conditions the filter is the constant `true`.
// The result selects, via Switch, between `list_param` with `elt` appended (all conditions hold) and
// `list_param` unchanged.
// @param list_body_block The loop body block of the comprehension; nodes are created in its graph.
// @param list_param      The header parameter holding the list accumulated so far.
// @param node            The Python ast `ListComp` node (provides `elt`).
// @param generator_node  The `comprehension` node (provides `ifs`).
// @return The Switch node producing the list after this iteration. On return, list_body_block's graph has no
//         Return node again, so the caller can still Jump from it.
AnfNodePtr Parser::ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
                                    const py::object &node, const py::object &generator_node) {
  // Handle ifs attribute.
  py::list ifs_node = python_adapter::GetPyObjAttr(generator_node, "ifs");
  AnfNodePtr ifs_bool_node;
  if (ifs_node.empty()) {
    ifs_bool_node = NewValueNode(true);
  } else {
    ifs_bool_node = ProcessBoolOpValueList(list_body_block, ifs_node, AST_SUB_TYPE_AND);
  }

  // Create if-true graph.
  FunctionBlockPtr if_true_block =
    GenerateBlock(std::make_shared<TraceIfStmtTrueBranch>(list_body_block->func_graph()->debug_info()));
  if_true_block->AddPrevBlock(list_body_block);
  // Handle elt attribute in body block.
  // NOTE(review): `elt` and the append below are built in the body graph, not in if_true_block (which only
  // outputs the result). They therefore appear to be computed regardless of the `ifs` outcome, with only the
  // selected list kept. Confirm this is intended when `elt` has side effects or can fail for filtered-out items.
  py::object elt_obj = python_adapter::GetPyObjAttr(node, "elt");
  AnfNodePtr elt_node = ParseExprNode(list_body_block, elt_obj);
  // Append the element.
  auto list_append_op = prim::kPrimListAppend;
  auto new_list = list_body_block->func_graph()->NewCNodeInOrder({NewValueNode(list_append_op), list_param, elt_node});
  // Return new list in true branch graph.
  if_true_block->func_graph()->set_output(new_list);

  // Create if-false graph.
  FunctionBlockPtr if_false_block =
    GenerateBlock(std::make_shared<TraceIfStmtFalseBranch>(list_body_block->func_graph()->debug_info()));
  if_false_block->AddPrevBlock(list_body_block);
  // Return original list in false branch graph.
  if_false_block->func_graph()->set_output(list_param);

  // We don't want to create a separate header graph just to receive and wrap the result of Switch().
  // Instead, call ConditionalJump() to set Switch() as the body graph's output, and reset it afterwards.
  list_body_block->ConditionalJump(ifs_bool_node, if_true_block, if_false_block);
  // Output is Switch() result, i.e. updated list.
  auto switch_apply_node = list_body_block->func_graph()->output();
  auto ifs_new_list = switch_apply_node;
  // ConditionalJump() above installed a Return; reset it to null so the caller can call Jump() afterwards.
  list_body_block->func_graph()->set_return(nullptr);
  if_true_block->Mature();
  if_false_block->Mature();
  return ifs_new_list;
}
1648
1649 // A ListComp contains: `elt` and `generators`.
1650 // `generators` contains: `target`, `iter` and `ifs`.
1651 // For example:
1652 // [x * x for x in range(0, 10) if x % 2 == 0]
1653 // It is compiled to be following statement:
1654 // list = []
1655 // for x in range(0, 10):
1656 // if x % 2 == 0:
1657 // list.append(x * x)
1658 // return list
ParseListComp(const FunctionBlockPtr & block,const py::object & node)1659 AnfNodePtr Parser::ParseListComp(const FunctionBlockPtr &block, const py::object &node) {
1660 MS_LOG(DEBUG) << "Process ast ListComp";
1661 MS_EXCEPTION_IF_NULL(block);
1662
1663 // Handle generators attribute.
1664 py::list generators_node = python_adapter::GetPyObjAttr(node, "generators");
1665 if (generators_node.size() != 1) {
1666 MS_EXCEPTION(TypeError) << "The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp, but got "
1667 << generators_node.size() << " comprehensions.";
1668 }
1669 py::object generator_node = generators_node[0];
1670 auto generator_node_type = ast_->GetNodeType(generator_node);
1671 auto generator_node_name = generator_node_type->node_name();
1672 constexpr auto comprehension_name = "comprehension";
1673 if (generator_node_name != comprehension_name) {
1674 MS_LOG(EXCEPTION) << "Generator node name should be " << comprehension_name << ", but got " << generator_node_name;
1675 }
1676
1677 // Parse ListComp's `iter` and add `elt` in it.
1678 auto top_block = ParseListCompIter(block, node, generator_node);
1679
1680 // Call the top graph and return the list.
1681 auto call_function_node = NewValueNode(top_block->func_graph());
1682 std::vector<AnfNodePtr> func_call_nodes;
1683 func_call_nodes.push_back(call_function_node);
1684 AnfNodePtr output = block->func_graph()->NewCNodeInOrder(func_call_nodes);
1685 return output;
1686 }
1687
HandleAssignName(const FunctionBlockPtr & block,const py::object & target_object,const AnfNodePtr & assigned_node)1688 void Parser::HandleAssignName(const FunctionBlockPtr &block, const py::object &target_object,
1689 const AnfNodePtr &assigned_node) {
1690 MS_EXCEPTION_IF_NULL(block);
1691 MS_EXCEPTION_IF_NULL(assigned_node);
1692 py::str name = python_adapter::GetPyObjAttr(target_object, "id");
1693 std::string name_id = name;
1694 assigned_node->debug_info()->set_name(name_id);
1695 // Set the debug name of the constant graph
1696 if (IsValueNode<FuncGraph>(assigned_node)) {
1697 // The value should be graph
1698 auto fg = GetValueNode<FuncGraphPtr>(assigned_node);
1699 if (fg->debug_info()->name().empty()) {
1700 fg->debug_info()->set_name(name_id);
1701 }
1702 }
1703 MS_LOG(DEBUG) << "Assign name: `" << name_id << "` to node: " << assigned_node->DebugString();
1704 block->AddLocalPyParam(name_id, assigned_node);
1705 block->WriteVariable(name_id, assigned_node);
1706 }
1707
HandleAssignTuple(const FunctionBlockPtr & block,const py::object & target_object,const AnfNodePtr & assigned_node)1708 void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object &target_object,
1709 const AnfNodePtr &assigned_node) {
1710 MS_EXCEPTION_IF_NULL(block);
1711 AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM);
1712 py::list items = python_adapter::GetPyObjAttr(target_object, "elts");
1713 for (size_t i = 0; i < items.size(); i++) {
1714 // Use the Primitive replace the operation resolve node (getitem),
1715 // because the getitem will eventually be converted to Primitive node
1716 CNodePtr item_apply =
1717 block->func_graph()->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast<int64_t>(i))});
1718
1719 py::object elt = items[i];
1720 WriteAssignVars(block, elt, item_apply);
1721 }
1722 }
1723
HandleAssignClassMember(const FunctionBlockPtr & block,const py::object & target_object,const AnfNodePtr & assigned_node)1724 void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target_object,
1725 const AnfNodePtr &assigned_node) {
1726 // Now only support the self.xx = xxxxx, can't support x.y = xxxx
1727 AnfNodePtr target_node = ParseExprNode(block, target_object);
1728 MS_EXCEPTION_IF_NULL(target_node);
1729
1730 auto attr_name = target_object.attr("attr").cast<std::string>();
1731 std::string var_name = "self." + attr_name;
1732
1733 // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type
1734 if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
1735 MS_EXCEPTION(TypeError)
1736 << "'" << var_name << "' should be initialized as a 'Parameter' in the '__init__' function before assigning.\n\n"
1737 << trace::GetDebugInfo(target_node->debug_info());
1738 }
1739 auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
1740 auto obj_type = obj.attr("__class__").attr("__name__");
1741 if (!py::hasattr(obj, "__parameter__")) {
1742 MS_EXCEPTION(TypeError) << "'" << var_name
1743 << "' should be initialized as a 'Parameter' type in the '__init__' function, but got '"
1744 << py::str(obj).cast<std::string>() << "' with type '"
1745 << py::str(obj_type).cast<std::string>() << ".\n\n"
1746 << trace::GetDebugInfo(target_node->debug_info());
1747 }
1748
1749 MS_EXCEPTION_IF_NULL(block);
1750 MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString();
1751 block->SetStateAssign(target_node, assigned_node);
1752 }
1753
HandleAssignSubscript(const FunctionBlockPtr & block,const py::object & target_object,const AnfNodePtr & assigned_node)1754 void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target_object,
1755 const AnfNodePtr &assigned_node) {
1756 MS_EXCEPTION_IF_NULL(block);
1757 AnfNodePtr op_setitem = block->MakeResolveOperation(NAMED_PRIMITIVE_SETITEM);
1758 py::object value_obj = python_adapter::GetPyObjAttr(target_object, "value");
1759 py::object slice_obj = python_adapter::GetPyObjAttr(target_object, "slice");
1760 AnfNodePtr value_node = ParseExprNode(block, value_obj);
1761 AnfNodePtr slice_node = ParseExprNode(block, slice_obj);
1762 CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node});
1763 // Getitem apply should return the sequence data structure itself
1764 std::string var_name;
1765 if (ast_->IsClassMember(value_obj)) {
1766 auto attr_name = value_obj.attr("attr").cast<std::string>();
1767 var_name = "self." + attr_name;
1768 if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) {
1769 MS_EXCEPTION(TypeError)
1770 << "'" << var_name
1771 << "' should be initialized as a 'Parameter' in the '__init__' function before assigning.\n\n"
1772 << trace::GetDebugInfo(value_node->debug_info());
1773 }
1774 auto obj = ast()->obj().attr(common::SafeCStr(attr_name));
1775 auto obj_type = obj.attr("__class__").attr("__name__");
1776 if (!py::hasattr(obj, "__parameter__")) {
1777 MS_EXCEPTION(TypeError) << "'" << var_name
1778 << "' should be initialized as a 'Parameter' in the '__init__' function, but got '"
1779 << py::str(obj).cast<std::string>() << "' with type '"
1780 << py::str(obj_type).cast<std::string>() << "'.\n\n"
1781 << trace::GetDebugInfo(value_node->debug_info());
1782 }
1783 block->WriteVariable(var_name, setitem_app);
1784 return;
1785 }
1786 if (AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, value_obj))) ==
1787 AST_SUB_TYPE_SUBSCRIPT) {
1788 HandleAssignSubscript(block, value_obj, setitem_app);
1789 return;
1790 }
1791 if (!py::hasattr(value_obj, "id")) {
1792 MS_EXCEPTION(TypeError) << "Attribute id not found in " << py::str(value_obj).cast<std::string>() << "\n\n"
1793 << trace::GetDebugInfo(value_node->debug_info());
1794 }
1795 var_name = value_obj.attr("id").cast<std::string>();
1796 block->WriteVariable(var_name, setitem_app);
1797 }
1798
WriteAssignVars(const FunctionBlockPtr & block,const py::object & target_object,const AnfNodePtr & value_node)1799 void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object,
1800 const AnfNodePtr &value_node) {
1801 MS_EXCEPTION_IF_NULL(value_node);
1802 MS_LOG(DEBUG) << "Process WriteAssignVars";
1803 auto ast_type = AstSubType(py::cast<int32_t>(ast_->CallParseModFunction(PYTHON_PARSE_GET_AST_TYPE, target_object)));
1804 if (ast_type == AST_SUB_TYPE_NAME) {
1805 HandleAssignName(block, target_object, value_node);
1806 } else if (ast_type == AST_SUB_TYPE_TUPLE) {
1807 HandleAssignTuple(block, target_object, value_node);
1808 } else if (ast_type == AST_SUB_TYPE_SUBSCRIPT) {
1809 HandleAssignSubscript(block, target_object, value_node);
1810 } else if (ast_->IsClassMember(target_object)) {
1811 HandleAssignClassMember(block, target_object, value_node);
1812 } else if (ast_type == AST_SUB_TYPE_ATTRIBUTE) {
1813 TraceGuard(GetLocation(target_object));
1814 MS_EXCEPTION(TypeError) << "Only support assign to attribute of self, but got attribute of "
1815 << py::str(target_object.attr("value").attr("id")) << ".\n"
1816 << "More details please refer to syntax support at https://www.mindspore.cn";
1817 } else {
1818 TraceGuard(GetLocation(target_object));
1819 MS_EXCEPTION(TypeError) << "Only supported augassign to attribute of self, variable and index value, but got "
1820 << target_object.get_type()
1821 << ".\nMore details please refer to syntax support at https://www.mindspore.cn";
1822 }
1823 }
1824
// Apply the JIT fallback to `value_node` when it is marked for interpretation.
// When the fallback is enabled and `value_node->interpret()` is set, returns the node built by
// FunctionBlock::MakeInterpret from three pieces:
//  - the source text of `value_object`;
//  - the block's global Python parameters, converted to a value node;
//  - a dict node of the block's local Python parameters.
// Otherwise `value_node` is returned unchanged.
// The fallback switch is read once per process (function-local static); later changes are ignored.
AnfNodePtr Parser::HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
                                   const py::object &value_object) {
  static const auto use_fallback = (support_fallback() == "1");
  if (!use_fallback || !value_node->interpret()) {
    return value_node;
  }

  const auto script_text = py::cast<std::string>(ast()->GetAstNodeText(value_object));
  py::dict global_dict = block->global_py_params();
  constexpr int recursive_level = 3;
  MS_LOG(INFO) << "[" << block->func_graph()->ToString() << "] script_text: " << script_text
               << ", value_node: " << value_node->DebugString(recursive_level)
               << ", global_dict: " << py::str(global_dict);

  // Globals become a single value node.
  ValuePtr converted_globals = nullptr;
  if (!ConvertData(global_dict, &converted_globals)) {
    MS_LOG(EXCEPTION) << "Convert data failed";
  }
  auto globals_node = NewValueNode(converted_globals);

  // Locals become a dict node built from the block's keys and values.
  auto [local_keys, local_values] = block->local_py_params();
  auto locals_node = ParseDictByKeysAndValues(block, local_keys, local_values);
  AnfNodePtr interpret_node = block->MakeInterpret(script_text, globals_node, locals_node, value_node);

  // Tell the user which code falls back to the Python interpreter.
  MS_LOG(ERROR) << "Found unsupported syntax in Graph mode, those codes would be fell back to Python interpreter:"
                << "\n\n"
                << trace::GetDebugInfo(value_node->debug_info());
  return interpret_node;
}
1861
1862 // Process a assign statement, such as a = b, a, b = tuple(xx, xx)
ParseAssign(const FunctionBlockPtr & block,const py::object & node)1863 FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) {
1864 MS_LOG(DEBUG) << "Process ast assign";
1865 py::object value_object = python_adapter::GetPyObjAttr(node, "value");
1866 AnfNodePtr value_node = ParseExprNode(block, value_object);
1867 value_node = HandleInterpret(block, value_node, value_object);
1868
1869 py::object targets_object = python_adapter::GetPyObjAttr(node, "targets");
1870 py::int_ pcount = python_adapter::CallPyObjMethod(targets_object, "__len__");
1871 size_t count = LongToSize(pcount);
1872 MS_LOG(DEBUG) << "The nodes count is " << count;
1873 for (size_t i = 0; i < count; i++) {
1874 auto target_node = py::cast<py::list>(targets_object)[i];
1875 WriteAssignVars(block, target_node, value_node);
1876 }
1877
1878 return block;
1879 }
1880
// Parse a `break` statement: jump from `block` to the end block of the innermost loop.
// The end block is created on the first `break` of a loop; the loop parser then continues from it.
// @throws if `break` appears outside any loop.
FunctionBlockPtr Parser::ParseBreak(const FunctionBlockPtr &block, const py::object &node) {
  if (loops_.empty()) {
    // No enclosing loop context for this 'break'.
    MS_LOG(EXCEPTION) << "Unexpected 'break'.";
  }
  auto &loop_end = loops_.top().end;
  if (!loop_end) {
    TraceGuard trace_guard(std::make_shared<TraceLoopEnd>(block->func_graph()->debug_info()));
    loop_end = MakeFunctionBlock(*this);
  }
  block->Jump(loop_end, {});
  return block;
}
1897
// Parse a `continue` statement: jump from `block` back to the header of the innermost loop.
// The loop's registered iterator/counter node, if any, is passed as the header argument.
// @throws if `continue` appears outside any loop.
FunctionBlockPtr Parser::ParseContinue(const FunctionBlockPtr &block, const py::object &node) {
  MS_EXCEPTION_IF_NULL(block);
  if (loops_.empty()) {
    // Report error if loop context not set for the 'continue' statement.
    MS_LOG(EXCEPTION) << "Unexpected 'continue'.";
  }
  // Jump to the header of the loop with iterator called.
  Loop &loop = loops_.top();
  std::vector<AnfNodePtr> args;
  if (loop.iterator != nullptr) {
    args.emplace_back(loop.iterator);
  }
  block->Jump(loop.header, args);
  return block;
}
1912
// Parse a `pass` statement: nothing is emitted, and parsing continues in the same block.
FunctionBlockPtr Parser::ParsePass(const FunctionBlockPtr &block, const py::object &node) {
  (void)node;  // `pass` has nothing to parse.
  return block;
}
1917
// Follow the chain of removable phi parameters starting at `node`.
// Returns the first node on the chain that is not itself a key of `removable_phis`, which is `node` itself
// when it is not a removable phi. Every hop is null-checked.
AnfNodePtr FindPhis(const std::unordered_map<ParameterPtr, AnfNodePtr> &removable_phis, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto found = removable_phis.find(node->cast<ParameterPtr>());
  return found == removable_phis.end() ? node : FindPhis(removable_phis, found->second);
}
1927
RemoveUnnecessaryPhis()1928 void Parser::RemoveUnnecessaryPhis() {
1929 // Merge all removable phis to one map;
1930 std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis;
1931 std::vector<ParameterPtr> phis;
1932 for (FunctionBlockPtr &block : func_block_list_) {
1933 MS_EXCEPTION_IF_NULL(block);
1934 removable_phis.insert(block->removable_phis().begin(), block->removable_phis().end());
1935 std::transform(block->removable_phis().begin(), block->removable_phis().end(), std::back_inserter(phis),
1936 [](const std::pair<ParameterPtr, AnfNodePtr> &pair) { return pair.first; });
1937 }
1938 if (removable_phis.empty()) {
1939 return;
1940 }
1941 auto mng = Manage(func_graph_, false);
1942 // Replace the nodes
1943 // Remove from inside to outside
1944 for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) {
1945 auto phi = phis[LongToSize(idx)];
1946 auto new_node = FindPhis(removable_phis, phi);
1947 mng->Replace(phi, new_node);
1948 }
1949 // Remove the parameter
1950 for (FunctionBlockPtr &block : func_block_list_) {
1951 MS_EXCEPTION_IF_NULL(block);
1952 auto &local_removable_phis = block->removable_phis();
1953 if (local_removable_phis.empty()) {
1954 continue;
1955 }
1956 auto func_graph = block->func_graph();
1957 auto ¶meters = func_graph->parameters();
1958 std::vector<AnfNodePtr> new_parameters(parameters.size());
1959 auto it = std::copy_if(
1960 parameters.begin(), parameters.end(), new_parameters.begin(), [&local_removable_phis](const AnfNodePtr ¶m) {
1961 MS_EXCEPTION_IF_NULL(param);
1962 return local_removable_phis.find(param->cast<ParameterPtr>()) == local_removable_phis.end();
1963 });
1964
1965 // Shrink container to new size
1966 new_parameters.resize(static_cast<size_t>(std::distance(new_parameters.begin(), it)));
1967 func_graph->set_parameters(new_parameters);
1968 }
1969 }
1970
1971 // ParseFunctionAst class code
InitParseAstInfo(const std::string & python_mod_get_parse_method)1972 bool ParseFunctionAst::InitParseAstInfo(const std::string &python_mod_get_parse_method) {
1973 // Init the type
1974 target_type_ = PARSE_TARGET_UNKNOW;
1975
1976 // Call python parse, get the parser fn
1977 module_ = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
1978 py::object parse_method = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_PARSE_METHOD);
1979
1980 // Get the obj type
1981 auto type = data_converter::GetObjType(obj_);
1982 if (type == RESOLVE_TYPE_FUNCTION) {
1983 target_type_ = PARSE_TARGET_FUNCTION;
1984 function_ = obj_;
1985 } else if (type == RESOLVE_TYPE_METHOD) {
1986 // Process the method ,need get the method's self obj
1987 target_type_ = PARSE_TARGET_METHOD;
1988 py::object method_object = python_adapter::GetPyObjAttr(obj_, PYTHON_GET_METHOD_SELF_CLASS);
1989 if (py::isinstance<py::none>(method_object)) {
1990 MS_LOG(ERROR) << "Get method's self object instance failed.";
1991 return false;
1992 }
1993 target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
1994 function_ = obj_;
1995 obj_ = method_object;
1996 } else if (type == RESOLVE_TYPE_CLASS_INSTANCE) {
1997 // obj is class instance, get the method to parse.
1998 function_ = python_adapter::CallPyModFn(module_, python_mod_get_parse_method, obj_, parse_method);
1999 if (py::isinstance<py::none>(function_)) {
2000 MS_LOG(ERROR) << "Get obj method function failed.";
2001 return false;
2002 }
2003 target_type_ = PARSE_TARGET_OBJECT_INSTANCE;
2004 // Check the fn is method
2005 auto obj_type = data_converter::GetObjType(function_);
2006 if (obj_type != RESOLVE_TYPE_METHOD) {
2007 MS_LOG(WARNING) << "Parse method function is invalid.";
2008 return false;
2009 }
2010 } else {
2011 MS_LOG(WARNING) << "Parse obj is invalid, only can parse function and obj, type = " << type;
2012 return false;
2013 }
2014
2015 // Call python parse get ast tree
2016 parser_ = python_adapter::CallPyModFn(module_, PYTHON_MOD_PARSE_OBJECT_FUNCTION, function_, parse_method);
2017 py::tuple ast_info = python_adapter::CallPyObjMethod(parser_, "parse");
2018 const size_t ast_info_size = 2;
2019 if (ast_info.size() != ast_info_size) {
2020 MS_EXCEPTION(NameError) << "ast info size is not equal to 2.";
2021 }
2022 ast_tokens_ = ast_info[0];
2023 ast_tree_ = ast_info[1];
2024
2025 // Get fn name and module
2026 function_module_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_module"));
2027 function_name_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "function_name"));
2028 function_filename_ = py::cast<std::string>(python_adapter::GetPyObjAttr(parser_, "filename"));
2029 function_line_offset_ = py::cast<int64_t>(python_adapter::GetPyObjAttr(parser_, "line_offset"));
2030
2031 return true;
2032 }
2033
2034 // Get ast tree node : is the tree bode list[0]
GetAstNode()2035 py::object ParseFunctionAst::GetAstNode() {
2036 py::list tree_body = python_adapter::GetPyObjAttr(ast_tree_, "body");
2037 py::object ast_node = tree_body[0];
2038 return ast_node;
2039 }
2040
2041 // Get ast tokens node text.
GetAstNodeText(const py::object & node_obj)2042 py::str ParseFunctionAst::GetAstNodeText(const py::object &node_obj) {
2043 return python_adapter::CallPyObjMethod(ast_tokens_, "get_text", node_obj);
2044 }
2045
// Collect the argument nodes of `func_node` by delegating to the python parse module's
// PYTHON_PARSE_GET_ARGS function; the python result is converted to a py::list.
py::list ParseFunctionAst::GetArgs(const py::object &func_node) {
  return python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS, func_node);
}
2050
// Collect the argument default-value nodes of `func_node` by delegating to the python parse
// module's PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES function; the result is converted to a py::list.
py::list ParseFunctionAst::GetArgsDefaultValues(const py::object &func_node) {
  return python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES, func_node);
}
2055
GetNodeType(const py::object & node)2056 AstNodeTypePtr ParseFunctionAst::GetNodeType(const py::object &node) {
2057 py::list list_value = python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_NODE_TYPE, node);
2058 const size_t list_value_size = 2;
2059 if (list_value.size() < list_value_size) {
2060 MS_LOG(EXCEPTION) << "The node of python method must has 2 values.";
2061 }
2062 auto node_name = py::cast<std::string>(list_value[0]);
2063 auto type = AstMainType(py::cast<int32_t>(list_value[1]));
2064 return std::make_shared<AstNodeType>(node, node_name, type);
2065 }
2066
GetOpType(const py::object & node)2067 AstSubType ParseFunctionAst::GetOpType(const py::object &node) {
2068 auto op_type = AstSubType(python_adapter::CallPyModFn(module_, PYTHON_PARSE_GET_AST_TYPE, node).cast<int32_t>());
2069 return op_type;
2070 }
2071
IsClassMember(const py::object & node)2072 bool ParseFunctionAst::IsClassMember(const py::object &node) {
2073 py::object ret = CallParseModFunction(PYTHON_MOD_PARSE_CHECK_IS_CLASS_MEMBER, node);
2074 if (!py::isinstance<py::bool_>(ret)) {
2075 MS_LOG(ERROR) << "The result of mod function parse, should be bool type.";
2076 return false;
2077 }
2078 return ret.cast<bool>();
2079 }
2080
UpdateFuncGraphFlags(const py::object & obj,const FuncGraphPtr & func_graph)2081 bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph) {
2082 if (func_graph == nullptr) {
2083 MS_LOG(ERROR) << "FuncGraph is null";
2084 return false;
2085 }
2086
2087 if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) {
2088 MS_LOG(DEBUG) << "No flags";
2089 return true;
2090 }
2091 py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG);
2092 for (auto &item : flags) {
2093 if (!py::isinstance<py::str>(item.first)) {
2094 MS_LOG(ERROR) << "Type error in flags dict convert";
2095 return false;
2096 }
2097 auto name = py::cast<std::string>(item.first);
2098 if (py::isinstance<py::bool_>(item.second)) {
2099 auto value = py::cast<bool>(item.second);
2100 MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
2101 func_graph->set_flag(name, value);
2102 } else if (py::isinstance<py::str>(item.second)) {
2103 auto value = py::cast<std::string>(item.second);
2104 MS_LOG(DEBUG) << "Flag name: " << name << ". Value: " << value;
2105 func_graph->set_attr(name, MakeValue(value));
2106 } else {
2107 MS_LOG(ERROR) << "Type error in flags/attrs dict convert";
2108 return false;
2109 }
2110 }
2111 return true;
2112 }
2113
2114 // Generate and copy a ValueNode, or a CNode with its child nodes
CopyNodesFromParamDefaultValue(const FuncGraphPtr & func_graph,const AnfNodePtr & param_node)2115 static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m_node) {
2116 MS_EXCEPTION_IF_NULL(param_node);
2117 if (param_node->isa<ValueNode>()) {
2118 return std::make_shared<ValueNode>(param_node->cast<ValueNodePtr>()->value());
2119 }
2120
2121 // Parameter default value is CNode.
2122 std::size_t index = 0;
2123 std::vector<AnfNodePtr> old_cnodes;
2124 old_cnodes.emplace_back(param_node);
2125 MS_EXCEPTION_IF_NULL(func_graph);
2126 auto res = func_graph->NewCNodeInOrder({});
2127 std::vector<CNodePtr> new_cnodes;
2128 new_cnodes.emplace_back(res);
2129 while (index < old_cnodes.size()) {
2130 auto current = old_cnodes[index];
2131 auto current_new_cnode = new_cnodes[index];
2132 index++;
2133 if (current->isa<CNode>()) {
2134 auto &inputs = current->cast<CNodePtr>()->inputs();
2135 for (auto it = inputs.begin(); it != inputs.end(); it++) {
2136 AnfNodePtr input = *it;
2137 if (input != nullptr && input->isa<CNode>()) {
2138 old_cnodes.emplace_back(input);
2139 auto new_cnode = func_graph->NewCNodeInOrder({});
2140 new_cnodes.emplace_back(new_cnode);
2141 current_new_cnode->add_input(new_cnode);
2142 } else if (input->isa<ValueNode>()) {
2143 current_new_cnode->add_input(std::make_shared<ValueNode>(input->cast<ValueNodePtr>()->value()));
2144 } else {
2145 MS_LOG(EXCEPTION) << "Wrong type item in default parameters: " << input->ToString();
2146 }
2147 }
2148 }
2149 }
2150 return res;
2151 }
2152
MakeTopGraph(const py::object & cell,const ValuePtr & cell_ptr)2153 FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
2154 auto current_graph = dyn_cast<FuncGraph>(cell_ptr);
2155 if (current_graph == nullptr) {
2156 MS_LOG(EXCEPTION) << "Current graph cast failed from " << cell_ptr->ToString();
2157 }
2158
2159 auto func_graph = std::make_shared<FuncGraph>();
2160 func_graph->debug_info()->set_name(current_graph->debug_info()->name() + "_wrapper");
2161 func_graph->debug_info()->set_location(current_graph->debug_info()->location());
2162
2163 // Copy all parameters information
2164 for (auto ¶ : current_graph->parameters()) {
2165 auto param = func_graph->add_parameter();
2166 auto orig_param = para->cast<ParameterPtr>();
2167 auto name = orig_param->name();
2168 param->set_name(name);
2169 param->debug_info()->set_name(name);
2170 param->debug_info()->set_location(param->debug_info()->location());
2171 }
2172 func_graph->set_has_vararg(current_graph->has_vararg());
2173 func_graph->set_has_kwarg(current_graph->has_kwarg());
2174 func_graph->set_kwonlyargs_count(current_graph->kwonlyargs_count());
2175 // Copy all default values
2176 for (auto &d : current_graph->parameter_default_value()) {
2177 func_graph->set_param_default_value(d.first, CopyNodesFromParamDefaultValue(func_graph, d.second));
2178 }
2179
2180 // cell_obj
2181 MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
2182 parse::UpdateFuncGraphFlags(cell, func_graph);
2183 // Top graph's construct flag
2184 if (py::hasattr(cell, "construct")) {
2185 parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
2186 }
2187
2188 auto unpacking = func_graph->has_vararg() || func_graph->has_kwarg();
2189 if (!unpacking) {
2190 std::vector<AnfNodePtr> inputs;
2191 inputs.emplace_back(NewValueNode(cell_ptr));
2192 auto ¶ms = func_graph->parameters();
2193 (void)std::transform(params.begin(), params.end(), std::back_inserter(inputs),
2194 [](AnfNodePtr node) -> AnfNodePtr { return node; });
2195 auto call_node = func_graph->NewCNodeInOrder(inputs);
2196
2197 TraceGuard guard(current_graph->get_return()->debug_info()->location());
2198 func_graph->set_output(call_node);
2199 } else {
2200 // ret = cell_obj(*arg, *kwargs)
2201 auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters());
2202
2203 TraceGuard guard(current_graph->get_return()->debug_info()->location());
2204 // Set output as ret
2205 func_graph->set_output(call_fn);
2206 }
2207 return func_graph;
2208 }
2209 } // namespace parse
2210 } // namespace mindspore
2211