/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_

#include <exception>
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <stack>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "utils/misc.h"
#include "ir/anf.h"
#include "pipeline/jit/ps/parse/parse_base.h"
#include "include/common/utils/python_adapter.h"
#include "pipeline/jit/ps/parse/function_block.h"

namespace mindspore {
namespace parse {
// Parse status define.
40 enum ParseStatusCode : int64_t { 41 PARSE_SUCCESS = 0, 42 PARSE_FUNCTION_IS_NULL, // Python function is null 43 PARSE_PARAMETER_INVALID, // Parameter is invalid 44 PARSE_NO_RETURN, // Function no return node 45 PARSE_NODE_TYPE_NO_MATCH, // Ast node type is error 46 PARSE_NODE_TYPE_UNKNOWN, // Node type is unknown 47 PARSE_NODE_METHOD_UNSUPPORTED, // No method to parse the node 48 PARSE_DONT_RESOLVE_SYMBOL, // Can't resolve the string 49 PARSE_NOT_SUPPORTED_COMPARE_EXPR, // The comparison is not supported 50 PARSE_FAILURE = 0xFF 51 }; 52 53 constexpr char kStandardMethodModelName[] = "mindspore._extends.parse.standard_method"; 54 55 // Max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it 56 // will be sunk(i.e. not unrolled) 57 // NOTE: Since when the for loop was unrolled, it depends backend operators `tuple_getitem` and `scalar_add` which were 58 // not implemented, so here set MAX_FOR_LOOP_COUNT to int64_t max limit to override default value `600`. This will make 59 // the for loop will always be unrolled, but don't worry about the memory were exhausted, an exception will be raised 60 // when function call depth exceeds the limit `context.get_context('max_call_depth')`. 61 const int64_t MAX_FOR_LOOP_COUNT = std::numeric_limits<int64_t>::max(); 62 63 class AstNodeType; 64 class ParseFunctionAst; 65 66 // Save loop info for 'continue' and 'break' statements. 67 struct Loop { 68 // Loop header block. 69 FunctionBlockPtr header; 70 // Loop iterator node, used in 'for loop'. 71 AnfNodePtr iterator; 72 // Loop end block. 73 FunctionBlockPtr end; 74 LoopLoop75 Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end) 76 : header(header), iterator(iterator), end(end) {} 77 ~Loop() = default; 78 }; 79 80 // Loop context for loop stack management. 
81 class LoopContext { 82 public: LoopContext(std::stack<Loop> * loops,const FunctionBlockPtr & header,const AnfNodePtr & iterator)83 LoopContext(std::stack<Loop> *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) { 84 loops_->emplace(header, iterator, nullptr); 85 } ~LoopContext()86 ~LoopContext() { 87 try { 88 MS_EXCEPTION_IF_NULL(loops_); 89 loops_->pop(); 90 } catch (const std::exception &e) { 91 MS_LOG(ERROR) << "Exception when pop. Error info " << e.what(); 92 } catch (...) { 93 MS_LOG(ERROR) << "Throw exception when pop."; 94 } 95 loops_ = nullptr; 96 } 97 EndBlock()98 const FunctionBlockPtr &EndBlock() const { return loops_->top().end; } 99 100 private: 101 std::stack<Loop> *loops_; 102 }; 103 104 struct ArgsContext { 105 bool need_unpack{false}; 106 bool has_interpret_without_internal{false}; 107 bool has_interpret_internal{false}; 108 109 std::vector<AnfNodePtr> packed_arguments; 110 std::vector<AnfNodePtr> group_arguments; ArgsContextArgsContext111 ArgsContext() {} 112 ~ArgsContext() = default; 113 }; 114 115 // Parser to parse python function. 116 class Parser { 117 public: 118 explicit Parser(const std::shared_ptr<ParseFunctionAst> &ast, ValuePtrList args_value_list); 119 ~Parser()120 ~Parser() {} 121 FuncGraphPtr ParseFuncGraph(); func_graph()122 FuncGraphPtr func_graph() const { return func_graph_; } errcode()123 ParseStatusCode errcode() const { return errcode_; } ast()124 std::shared_ptr<ParseFunctionAst> ast() const { return ast_; } 125 // Get location info from the ast node. 
126 LocationPtr GetLocation(const py::object &node) const; 127 static void InitParserEnvironment(const py::object &obj); 128 static void CleanParserResource(); GetTopFuncGraph()129 static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); } 130 static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph); EnableDeferResolve(bool enabled)131 static void EnableDeferResolve(bool enabled) { defer_resolve_ = enabled; } defer_resolve()132 static bool defer_resolve() { return defer_resolve_; } 133 134 private: 135 // Process the stmt node method list. 136 FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node); 137 // Parse expression. 138 FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node); 139 // Process a if statement. 140 FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node); 141 // Process a while statement. 142 FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node); 143 // Process a for statement. 144 FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node); 145 FunctionBlockPtr ParseForUnroll(const FunctionBlockPtr &block, const py::object &node); 146 FunctionBlockPtr ParseForRepeat(const FunctionBlockPtr &block, const py::object &node); 147 // Process a function def statement. 148 FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node); 149 // Process a augment assign. 150 FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node); 151 // Process a global declaration. 152 FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node); 153 // Process assign statement. 154 FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node); 155 // Process annassign statement. 156 FunctionBlockPtr ParseAnnAssign(const FunctionBlockPtr &block, const py::object &node); 157 // Process break statement. 
158 FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node); 159 // Process continue statement. 160 FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node); 161 // Process pass statement. 162 FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node); 163 // Process raise statement. 164 FunctionBlockPtr ParseRaise(const FunctionBlockPtr &block, const py::object &node); 165 // Process assert statement. 166 FunctionBlockPtr ParseAssert(const FunctionBlockPtr &block, const py::object &node); 167 // Process with statement. 168 FunctionBlockPtr ParseWith(const FunctionBlockPtr &block, const py::object &node); 169 170 // Process withitem. 171 AnfNodePtr ParseWithitem(const FunctionBlockPtr &block, const py::object &node, const AnfNodePtr &context_expr_node); 172 // Process the expr and slice node method list. 173 AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node); 174 // Process a variable name. 175 AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); 176 // Process NoneType. 177 AnfNodePtr ParseNone(const FunctionBlockPtr &, const py::object &); 178 // Process Ellipsis. 179 AnfNodePtr ParseEllipsis(const FunctionBlockPtr &, const py::object &); 180 // Process an integer or float number. 181 AnfNodePtr ParseNum(const FunctionBlockPtr &, const py::object &node); 182 // Process a string variable. 183 AnfNodePtr ParseStr(const FunctionBlockPtr &, const py::object &node); 184 // Process a Constant. 185 AnfNodePtr ParseConstant(const FunctionBlockPtr &, const py::object &node); 186 // Process a name. 187 AnfNodePtr ParseNameConstant(const FunctionBlockPtr &, const py::object &node); 188 // Process a function call. 189 AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node); 190 // Process function 'super'. 191 AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args); 192 // Process the if expression. 
193 AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node); 194 // get vector of getattr node from getattr map. 195 std::vector<AnfNodePtr> GetGetAttrVectotFromMap(const std::string &obj_name, const std::string &attr_name); 196 // get setattr node from setattr map. 197 AnfNodePtr GetSetAttrFromMap(const std::string &obj_name, const std::string &attr_name); 198 // make getattr node using interpret node as target 199 AnfNodePtr MakeGetAttrWithInterpret(const std::string &obj_name, const std::string &attr_name, 200 const py::object &getattr_obj, const FuncGraphPtr &cur_fg); 201 // Process class type define. 202 AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node); 203 // Process ms Tensor. 204 AnfNodePtr ParseMsTensor(const FunctionBlockPtr &block, const py::object &node, const py::object &value_body, 205 const AnfNodePtr &value_node); 206 // Process dtype._null. 207 AnfNodePtr ParseNull(const FunctionBlockPtr &block, const py::object &value_body) const; 208 // Process a compare expression. 209 AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node); 210 AnfNodePtr ParseSingleCompare(const FunctionBlockPtr &block, const py::object &left, const py::object &right, 211 const py::object &compare_op); 212 AnfNodePtr ConnectSingleCompare(const FunctionBlockPtr &block, const AnfNodePtr &left_node, 213 const AnfNodePtr &right_node); 214 // Process a bool operation. 215 AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node); 216 // Process a lambda operation. 217 AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node); 218 // Process a tuple. 219 AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node); 220 // Process a list. 221 AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node); 222 // Process a tuple or list. 
223 AnfNodePtr ParseTupleOrList(const FunctionBlockPtr &block, const py::object &node, bool is_tuple); 224 // Process a tuple or list with starred expression. 225 AnfNodePtr ParseTupleOrListWithStarred(const FunctionBlockPtr &block, const py::object &node, bool is_tuple, 226 const std::vector<AnfNodePtr> &starred_flags); 227 // Process a subscript. 228 AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node); 229 // Process a slice. 230 AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node); 231 // Process a extslice. 232 AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node); 233 // Process a index. 234 AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node); 235 // Process a unaryop. 236 AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node); 237 // Process a dict ast node expression. 238 AnfNodePtr ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes, 239 const std::vector<AnfNodePtr> &value_nodes); 240 // Process a dict. 241 AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node); 242 243 std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> GetRealKeysValues(const FunctionBlockPtr &block, 244 const py::object &node); 245 std::pair<AnfNodePtr, AnfNodePtr> GetRealKeysValuesFromName(const FunctionBlockPtr &block, const py::object &node); 246 // Process DictComp expression. 247 AnfNodePtr ParseDictComp(const FunctionBlockPtr &block, const py::object &node); 248 FunctionBlockPtr ParseDictCompIter(const FunctionBlockPtr &block, const py::object &node, 249 const py::object &generator_node); 250 AnfNodePtr ParseDictCompIfs(const FunctionBlockPtr &dict_body_block, const ParameterPtr &dict_param, 251 const py::object &node, const py::object &generator_node); 252 // Process ListComp expression. 
253 AnfNodePtr ParseListComp(const FunctionBlockPtr &block, const py::object &node); 254 FunctionBlockPtr ParseListCompIter(const FunctionBlockPtr &block, const py::object &node, 255 const py::object &generator_node); 256 AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param, 257 const py::object &node, const py::object &generator_node); 258 AnfNodePtr ParseJoinedStr(const FunctionBlockPtr &block, const py::object &node); 259 AnfNodePtr ParseFormattedValue(const FunctionBlockPtr &block, const py::object &node); 260 AnfNodePtr ParseStarred(const FunctionBlockPtr &block, const py::object &node); 261 std::vector<AnfNodePtr> HandleException(const FunctionBlockPtr &block, const py::list &args, const std::string &name); 262 std::vector<AnfNodePtr> ParseRaiseCall(const FunctionBlockPtr &block, const py::object &node); 263 void HandleStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes); 264 265 bool GetBoolObjForAstCompare(const FunctionBlockPtr &block, const py::object &node, bool *bool_res) const; 266 py::object GetPyObjForAstAttr(const FunctionBlockPtr &block, const py::object &attr_ast_node, 267 bool *is_constant) const; 268 bool GetConstantConditionFromComment(const FunctionBlockPtr &block, const py::object &if_node, 269 bool *is_true_cond) const; 270 bool CheckConstantCondition(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond, 271 const py::object &if_node = py::none()) const; 272 273 FunctionBlockPtr MakeAssertErrorBlock(const FunctionBlockPtr &block, const py::object &node); 274 AnfNodePtr ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) const; 275 276 // Transform tail call to parallel call. 277 void TransformParallelCall(); 278 void LiftRolledBodyGraphFV(); 279 void LiftIfBranchGraphFV(); 280 281 // Check if script_text is in global/local params. 
282 bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict, 283 const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) const; 284 // Make interpret node. 285 AnfNodePtr MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node, const string &script_text); 286 // Check if the node need interpreting. 287 AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node, 288 const py::object &value_object); 289 290 bool CheckNeedConvertInterpret(const FunctionBlockPtr &block, const AnfNodePtr &node, 291 const string &script_text) const; 292 293 // Generate argument nodes for ast function node. 294 void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node); 295 // Generate argument default value for ast function node. 296 void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node); 297 // Parse ast function node. 298 FunctionBlockPtr ParseDefFunction(const py::object &node, const FunctionBlockPtr &block = nullptr); 299 // Parse lambda function node. 300 FunctionBlockPtr ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block = nullptr); 301 // Parse ast statements. 302 FunctionBlockPtr ParseStatements(const FunctionBlockPtr &block, const py::object &nodes); 303 // Parse one ast statement node. 304 FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node); 305 // Parse an ast expression node. 
306 AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node); 307 308 void MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block, 309 const FunctionBlockPtr &false_block) const; 310 std::shared_ptr<std::map<ParameterPtr, AnfNodePtr>> CalRemovablePhis(); 311 void CreatePhiArgMaps(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args, 312 std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis); 313 static void PrintPhiArgMaps(const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_to_args, 314 const std::map<AnfNodePtr, std::set<ParameterPtr>> &arg_to_phis); 315 static void UpdatePhiArgMapsRepeatedly(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args, 316 std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis); 317 static std::shared_ptr<std::map<ParameterPtr, AnfNodePtr>> CollectRemovablePhiArgs( 318 const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_to_args); 319 void RemoveUnnecessaryPhis(const FuncGraphManagerPtr &manager); 320 void ConvertGetattrNodes(); 321 // Write a new var. 322 void WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object, const AnfNodePtr &value_node); 323 324 // Create a setattr CNode. 325 void MakeSetAttrNode(const FunctionBlockPtr &block, const AnfNodePtr &target_node, const AnfNodePtr &value_node, 326 const std::string &target_id_str, const std::string &attr_str); 327 328 // Assign value to single variable name. 329 void HandleAssignName(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node) const; 330 331 // Assign value to starred expression. 332 void HandleAssignStarred(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node); 333 334 // Assign value to tuple. 335 void HandleAssignTupleOrList(const FunctionBlockPtr &block, const py::object &target, 336 const AnfNodePtr &assigned_node); 337 338 // Assign value to tuple with starred expression. 
339 void HandleAssignTupleWithStarredExpression(const FunctionBlockPtr &block, const py::object &target, 340 const AnfNodePtr &assigned_node, const std::vector<int64_t> &positions); 341 342 // Assign value to class Parameter member. Return false if not a Parameter member. 343 bool HandleAssignClassParameterMember(const FunctionBlockPtr &block, const py::object &target, 344 const AnfNodePtr &value_node); 345 346 // Handle set attribute change for class member. 347 bool HandleSetAttrClassMemberForInplace(const FunctionBlockPtr &block, const AnfNodePtr &node); 348 349 // Assign value to class member. 350 void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &value_node); 351 352 // Assign value to subscript. 353 void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node); 354 355 // Process a bool operation value list. 356 AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode); 357 358 void ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, ArgsContext *args_context); 359 360 void ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, ArgsContext *args_context); 361 AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node, 362 const ArgsContext &args_context) const; 363 ScopePtr GetScopeForParseFunction(); 364 // Check the value is subscript is reference type. 365 bool IsSubscriptReferenceType(const py::object &obj); 366 void BuildMethodMap(); 367 // Must add a TraceGuard before call it. 368 FunctionBlockPtr MakeFunctionBlock(); 369 FunctionBlockPtr MakeFunctionBlock(const TraceInfoPtr &trace_info); 370 // Return a make tuple for input elements list. 371 AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes); 372 // Check if the node is pop operation. 
373 bool IsPopOperation(const AnfNodePtr &node) const; 374 // Check if branch block contains break/continue/return statement, and propagate that flag back to block. 375 void CheckControlFlowAlterationInIf(std::pair<FunctionBlockPtr, FunctionBlockPtr> *branch_graphs_pair, 376 const FunctionBlockPtr &branch_block, const FunctionBlockPtr &branch_end, 377 const FunctionBlockPtr &after_block, const FunctionBlockPtr &block) const; 378 // Check if body block contains return statement, and propagate that flag back to block. 379 void CheckReturnInLoop(const FunctionBlockPtr &block, const FunctionBlockPtr &body_block) const; 380 381 // Check whether the functions referred by this function and itself are missing 'return' statement. 382 void CheckFuncReturn(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fn); 383 384 // If the node is Parameter member of class. 385 bool IsClassParameterMember(const py::object &target_obj, const AnfNodePtr &target_node) const; 386 387 ValuePtr GetParameterValue(const AnfNodePtr ¶meter) const; 388 bool CheckAttributeConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const; 389 bool CheckNameConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const; 390 bool CheckUnaryOpConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const; 391 bool CheckCompareConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const; 392 bool CheckBoolOpConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const; 393 bool CompareIs(const FunctionBlockPtr &, const py::object &left_obj, const py::object &comparator_obj, 394 bool *bool_res) const; 395 bool CompareIsNot(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj, 396 bool *bool_res) const; 397 bool CompareEqual(const FunctionBlockPtr &block, const py::object &left_obj, 
const py::object &comparator_obj, 398 bool *bool_res) const; 399 bool CompareNotEqual(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj, 400 bool *bool_res) const; 401 bool CompareGreater(const FunctionBlockPtr &, const py::object &left_obj, const py::object &comparator_obj, 402 bool *bool_res) const; 403 bool CompareGreaterEqual(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj, 404 bool *bool_res) const; 405 bool CompareLess(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj, 406 bool *bool_res) const; 407 bool CompareLessEqual(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj, 408 bool *bool_res) const; 409 py::object GetValuePythonObject(const py::object &value_node); 410 CNodePtr MakeSetitemNode(const FunctionBlockPtr &block, const py::object &value_obj, const py::object &slice_obj, 411 const AnfNodePtr &assigned_node, const AnfNodePtr &value_node); 412 413 void ProcessPopOperation(const FunctionBlockPtr &block, const AnfNodePtr &value_node, 414 const py::object &target_object); 415 416 void ProcessPopOperationInAugAssign(const FunctionBlockPtr &block, const AnfNodePtr &value_node, 417 const AnfNodePtr &target_node, const AnfNodePtr &op_node, 418 const py::object &target_object); 419 420 // The shared_ptr will be hold by GraphManager, so just hold a weak ref here. 421 static FuncGraphWeakPtr top_func_graph_; 422 // Set if defer resolve during parsing. 423 inline static bool defer_resolve_{false}; 424 // Python function id, used to indicate whether two CNodes come from the same Python function. 425 const std::shared_ptr<ParseFunctionAst> &ast_; 426 FuncGraphPtr func_graph_; 427 // Error code setwhen parsing ast tree. 
428 ParseStatusCode errcode_; 429 py::object list_pop_target_obj_; 430 431 // Hold all reference for FunctionBlock in this round of parsing, 432 // so in FunctionBlock class we can use FunctionBlock* in member 433 // pre_blocks_ and jumps_ to break reference cycle. 434 std::vector<FunctionBlockPtr> func_block_list_; 435 using StmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); 436 using ExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); 437 using CompareFunc = bool (Parser::*)(const FunctionBlockPtr &block, const py::object &left_obj, 438 const py::object &comparator_obj, bool *bool_res) const; 439 using ConditionFunc = bool (Parser::*)(const FunctionBlockPtr &block, const py::object &test_node, 440 bool *is_true_cond) const; 441 // Define the function map to parse ast Statement. 442 std::map<std::string, StmtFunc> stmt_method_map_; 443 // Define the function map to parse ast expression. 444 std::map<std::string, ExprFunc> expr_method_map_; 445 // Define the function map to parse compare expression. 446 std::map<std::string, CompareFunc> compare_method_map_; 447 // Define the function map to parse constant condition expression. 448 std::map<std::string, ConditionFunc> condition_method_map_; 449 // Save current loops to support 'continue', 'break' statement. 450 std::stack<Loop> loops_; 451 452 // The func graphs to transform tail call ir to independent call ir. 453 // Contains: {former_graph, middle_graph}, latter_graph is no need. 454 std::vector<std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>>> parallel_call_graphs_; 455 // The true branch and false branch call info. of if statement. 456 std::vector<std::tuple<CNodePtr, FunctionBlockPtr, FunctionBlockPtr>> if_branch_calls_; 457 // The rolled_body callers info. for later lifting operation. 
458 std::vector<std::pair<CNodePtr, FunctionBlockPtr>> rolled_body_calls_; 459 460 // Record all setattr nodes and their targets and values. 461 std::map<std::string, std::map<std::string, AnfNodePtr>> setattr_nodes_map_; 462 // Record all getattr node for specific object and attribute name. 463 std::map<std::string, std::map<std::string, std::vector<AnfNodePtr>>> getattr_nodes_map_; 464 // Store the values for input args of top graph. 465 ValuePtrList args_value_list_; 466 }; 467 468 // AST node type define code to ast. 469 class AstNodeType { 470 public: AstNodeType(const py::object & node,const std::string & name,AstMainType type)471 AstNodeType(const py::object &node, const std::string &name, AstMainType type) 472 : node_(node), node_name_(name), main_type_(type) {} 473 ~AstNodeType()474 ~AstNodeType() {} 475 node_name()476 std::string node_name() const { return node_name_; } 477 node()478 py::object node() const { return node_; } 479 main_type()480 AstMainType main_type() const { return main_type_; } 481 482 private: 483 const py::object &node_; 484 const std::string node_name_; 485 AstMainType main_type_; 486 }; 487 488 using AstNodeTypePtr = std::shared_ptr<AstNodeType>; 489 490 // A helper class to parse python function. 491 class ParseFunctionAst { 492 public: ParseFunctionAst(const py::object & obj)493 explicit ParseFunctionAst(const py::object &obj) 494 : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {} 495 496 ~ParseFunctionAst() = default; 497 498 bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); 499 500 py::object GetAstNode(); 501 502 py::str GetAstNodeText(const py::object &node); 503 504 py::list GetArgs(const py::object &func_node); 505 506 py::list GetArgsDefaultValues(const py::object &func_node); 507 508 AstNodeTypePtr GetNodeType(const py::object &node); 509 510 AstSubType GetOpType(const py::object &node); 511 512 template <class... 
T> CallParserObjMethod(const std::string & method,const T &...args)513 py::object CallParserObjMethod(const std::string &method, const T &... args) { 514 return python_adapter::CallPyObjMethod(parser_, method, args...); 515 } 516 517 template <class... T> CallParseModFunction(const std::string & function,const T &...args)518 py::object CallParseModFunction(const std::string &function, const T &... args) { 519 return python_adapter::CallPyModFn(module_, function, args...); 520 } 521 function_name()522 const std::string &function_name() const { return function_name_; } 523 function_module()524 const std::string &function_module() const { return function_module_; } 525 function_filename()526 const std::string &function_filename() const { return function_filename_; } 527 function_line_offset()528 int64_t function_line_offset() const { return function_line_offset_; } 529 function()530 py::function function() { return function_; } 531 target_type()532 ParseTargetType target_type() const { return target_type_; } 533 obj()534 py::object obj() { return obj_; } 535 parser()536 py::object parser() { return parser_; } 537 module()538 py::object module() { return module_; } 539 ast_tree()540 py::object ast_tree() { return ast_tree_; } 541 542 bool IsClassMemberOfSelf(const py::object &node); 543 bool IsClassMemberRecursive(const py::object &node); 544 545 private: 546 // Save obj, eg: class instance or function. 547 py::object obj_; 548 549 // Function or class method. 550 py::function function_; 551 552 py::object ast_tokens_; 553 py::object ast_tree_; 554 py::object parser_; 555 py::module module_; 556 557 // Is function or method. 558 ParseTargetType target_type_; 559 560 std::string function_name_; 561 std::string function_module_; 562 std::string function_filename_; 563 int64_t function_line_offset_; 564 }; 565 566 // Update the graph flags. 
567 bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph, bool is_construct_function = false); 568 569 // Update recomputed scope for the graph. 570 void UpdateRecomputeScope(const FuncGraphPtr &func_graph); 571 572 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); 573 } // namespace parse 574 } // namespace mindspore 575 576 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_ 577