1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_ 20 #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_ 21 22 #include <limits> 23 #include <vector> 24 #include <string> 25 #include <map> 26 #include <set> 27 #include <stack> 28 #include <memory> 29 #include "utils/misc.h" 30 #include "ir/anf.h" 31 #include "pipeline/jit/parse/parse_base.h" 32 #include "pipeline/jit/parse/python_adapter.h" 33 #include "pipeline/jit/parse/function_block.h" 34 35 namespace mindspore { 36 namespace parse { 37 // Parse status define 38 enum ParseStatusCode : int64_t { 39 PARSE_SUCCESS = 0, 40 PARSE_FUNCTION_IS_NULL, // Python function is null 41 PARSE_PARAMETER_INVALID, // Parameter is invalid 42 PARSE_NO_RETURN, // Function no return node 43 PARSE_NODE_TYPE_NO_MATCH, // Ast node type is error 44 PARSE_NODE_TYPE_UNKNOWN, // Node type is unknown 45 PARSE_NODE_METHOD_UNSUPPORTED, // No method to parse the node 46 PARSE_DONT_RESOLVE_SYMBOL, // Can't resolve the string 47 PARSE_NOT_SUPPORTED_COMPARE_EXPR, // The comparison is not supported 48 PARSE_FAILURE = 0xFF 49 }; 50 51 // Max loop count of for statement, when loop count is less then this value, the for loop will be unrolled, otherwise it 52 // will be sunk(i.e. not unrolled) 53 // NOTE: Since when the for loop was unrolled, it depends backend operators `tuple_getitem` and `scalar_add` which were 54 // not implemented, so here set MAX_FOR_LOOP_COUNT to int64_t max limit to override default value `600`. This will make 55 // the for loop will always be unrolled, but don't worry about the memory were exhausted, an exception will be raised 56 // when function call depth exceeds the limit `context.get_context('max_call_depth')`. 57 const int64_t MAX_FOR_LOOP_COUNT = std::numeric_limits<int64_t>::max(); 58 59 class AstNodeType; 60 class ParseFunctionAst; 61 62 // Save loop info for 'continue' and 'break' statements. 63 struct Loop { 64 // Loop header block. 65 FunctionBlockPtr header; 66 // Loop iterator node, used in 'for loop'. 67 AnfNodePtr iterator; 68 // Loop end block. 69 FunctionBlockPtr end; 70 LoopLoop71 Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end) 72 : header(header), iterator(iterator), end(end) {} 73 ~Loop() = default; 74 }; 75 76 // Loop context for loop stack management. 77 class LoopContext { 78 public: LoopContext(std::stack<Loop> * loops,const FunctionBlockPtr & header,const AnfNodePtr & iterator)79 LoopContext(std::stack<Loop> *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) { 80 loops_->emplace(header, iterator, nullptr); 81 } ~LoopContext()82 ~LoopContext() { loops_->pop(); } EndBlock()83 const FunctionBlockPtr &EndBlock() const { return loops_->top().end; } 84 85 private: 86 std::stack<Loop> *loops_; 87 }; 88 89 // Parser to parse python function 90 class Parser { 91 public: 92 explicit Parser(const std::shared_ptr<ParseFunctionAst> &ast); 93 ~Parser()94 ~Parser() {} 95 FuncGraphPtr ParseFuncGraph(); func_graph()96 FuncGraphPtr func_graph() const { return func_graph_; } errcode()97 ParseStatusCode errcode() const { return errcode_; } ast()98 std::shared_ptr<ParseFunctionAst> ast() const { return ast_; } support_fallback()99 const std::string &support_fallback() const { return support_fallback_; } 100 // Get location info from the ast node 101 LocationPtr GetLocation(const py::object &node) const; 102 static void InitParserEnvironment(const py::object &obj); 103 static void CleanParserResource(); GetTopFuncGraph()104 static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); } 105 static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph); 106 107 private: 108 // Process the stmt node method list 109 FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node); 110 // Parse expression 111 FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node); 112 // Process a if statement 113 FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node); 114 // Process a while statement 115 FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node); 116 // Process a for statement 117 FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node); 118 FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node); 119 FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node); 120 // Process a function def statement 121 FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node); 122 // Process a augment assign 123 FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node); 124 // Process a global declaration 125 FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node); 126 // Process assign statement 127 FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node); 128 // Process break statement 129 FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node); 130 // Process continue statement 131 FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node); 132 // Process pass statement 133 FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node); 134 135 // Process the expr and slice node method list 136 AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node); 137 // Process a variable name 138 AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); 139 // Process NoneType 140 AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node); 141 // Process Ellipsis 142 AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node); 143 // Process a integer or float number 144 AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); 145 // Process a string variable 146 AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node); 147 // Process a Constant 148 AnfNodePtr ParseConstant(const FunctionBlockPtr &block, const py::object &node); 149 // Process a name 150 AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node); 151 // Process a function call 152 AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node); 153 // Process function 'super' 154 AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args); 155 // Process the if expression 156 AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node); 157 // Process class type define 158 AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node); 159 // Process a compare expression 160 AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node); 161 // Process a bool operation 162 AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node); 163 // Process a lambda operation 164 AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node); 165 // Process a tuple 166 AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node); 167 // Process a tuple 168 AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node); 169 // Process a tuple 170 AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node); 171 // Process a slice 172 AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node); 173 // Process a extslice 174 AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node); 175 // Process a tuple 176 AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node); 177 // Process a unaryop 178 AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node); 179 // Process a dict ast node expression 180 AnfNodePtr ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes, 181 const std::vector<AnfNodePtr> &value_nodes); 182 AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node); 183 // Process ListComp expression 184 AnfNodePtr ParseListComp(const FunctionBlockPtr &block, const py::object &node); 185 FunctionBlockPtr ParseListCompIter(const FunctionBlockPtr &block, const py::object &node, 186 const py::object &generator_node); 187 AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param, 188 const py::object &node, const py::object &generator_node); 189 190 // Check if the node need interpreting. 191 AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node, 192 const py::object &value_object); 193 194 // Generate argument nodes for ast function node 195 void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node); 196 // Generate argument default value for ast function node 197 void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node); 198 // Parse ast function node 199 FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr); 200 // Parse ast statements 201 FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node); 202 // Parse one ast statement node 203 FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node); 204 // Parse an ast expression node 205 AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node); 206 207 void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock, 208 const FunctionBlockPtr &falseBlock); 209 void RemoveUnnecessaryPhis(); 210 // Write a new var 211 void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node); 212 213 // Assign value to single variable name 214 void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); 215 216 // Assign value to tuple 217 void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); 218 219 // Assign value to class member 220 void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); 221 222 // Assign value to subscript 223 void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node); 224 225 // Process a bool operation value list 226 AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode); 227 228 CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node, 229 const AnfNodePtr &op_iter); 230 231 CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, 232 const AnfNodePtr &op_hasnext); 233 234 FunctionBlockPtr GenerateBlock(const TraceInfoPtr &trace_info); 235 236 bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, 237 std::vector<AnfNodePtr> *packed_arguments); 238 239 bool ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *packed_arguments, 240 std::vector<AnfNodePtr> *group_arguments); 241 242 AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, 243 const std::vector<AnfNodePtr> &packed_arguments, 244 const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const; 245 ScopePtr GetScopeForParseFunction(); 246 void BuildMethodMap(); MakeFunctionBlock(const Parser & parse)247 FunctionBlockPtr MakeFunctionBlock(const Parser &parse) { 248 FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse); 249 // In order to keep effect order in the sub-graphs which generated by control flow. 250 // We copy the flags from the top graph to the sub-graphs. 251 if (func_graph_ && !func_graph_->attrs().empty()) { 252 for (const auto &attr : func_graph_->attrs()) { 253 // The flag FUNC_GRAPH_OUTPUT_NO_RECOMPUTE should be only set in the top graph. 254 if (attr.first != FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) { 255 block->func_graph()->set_attr(attr.first, attr.second); 256 } 257 } 258 } 259 func_block_list_.push_back(block); 260 return block; 261 } 262 // Return a make tuple for input elements list 263 AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes); 264 int64_t GetForTransToWhileLoop(); 265 266 // The shared_ptr will be hold by GraphManager, so just hold a weak ref here. 267 static FuncGraphWeakPtr top_func_graph_; 268 // Python function id, used to indicate whether two CNodes come from the same Python function 269 const std::shared_ptr<ParseFunctionAst> &ast_; 270 FuncGraphPtr func_graph_; 271 // Error code setwhen parsing ast tree 272 ParseStatusCode errcode_; 273 274 // Hold all reference for FunctionBlock in this round of parsing, 275 // so in FunctionBlock class we can use FunctionBlock* in member 276 // pre_blocks_ and jumps_ to break reference cycle. 277 std::vector<FunctionBlockPtr> func_block_list_; 278 using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); 279 using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node); 280 // Define the function map to parse ast Statement 281 std::map<std::string, pStmtFunc> stmt_method_map_; 282 // Define the function map to parse ast expression 283 std::map<std::string, pExprFunc> expr_method_map_; 284 // Save current loops to support 'continue', 'break' statement. 285 std::stack<Loop> loops_; 286 string max_for_loop_count_str_; 287 string support_fallback_; 288 }; 289 290 // AST node type define code to ast 291 class AstNodeType { 292 public: AstNodeType(const py::object & node,const std::string & name,AstMainType type)293 AstNodeType(const py::object &node, const std::string &name, AstMainType type) 294 : node_(node), node_name_(name), main_type_(type) {} 295 ~AstNodeType()296 ~AstNodeType() {} 297 node_name()298 std::string node_name() const { return node_name_; } 299 node()300 py::object node() const { return node_; } 301 main_type()302 AstMainType main_type() const { return main_type_; } 303 304 private: 305 const py::object &node_; 306 const std::string node_name_; 307 AstMainType main_type_; 308 }; 309 310 using AstNodeTypePtr = std::shared_ptr<AstNodeType>; 311 312 // A helper class to parse python function 313 class ParseFunctionAst { 314 public: ParseFunctionAst(const py::object & obj)315 explicit ParseFunctionAst(const py::object &obj) 316 : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {} 317 318 ~ParseFunctionAst() = default; 319 320 bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); 321 322 py::object GetAstNode(); 323 324 py::str GetAstNodeText(const py::object &node); 325 326 py::list GetArgs(const py::object &func_node); 327 328 py::list GetArgsDefaultValues(const py::object &func_node); 329 330 AstNodeTypePtr GetNodeType(const py::object &node); 331 332 AstSubType GetOpType(const py::object &node); 333 334 template <class... T> CallParserObjMethod(const std::string & method,const T &...args)335 py::object CallParserObjMethod(const std::string &method, const T &... args) { 336 return python_adapter::CallPyObjMethod(parser_, method, args...); 337 } 338 339 template <class... T> CallParseModFunction(const std::string & function,const T &...args)340 py::object CallParseModFunction(const std::string &function, const T &... args) { 341 return python_adapter::CallPyModFn(module_, function, args...); 342 } 343 function_name()344 const std::string &function_name() const { return function_name_; } 345 function_module()346 const std::string &function_module() const { return function_module_; } 347 function_filename()348 const std::string &function_filename() const { return function_filename_; } 349 function_line_offset()350 int64_t function_line_offset() const { return function_line_offset_; } 351 function()352 py::function function() { return function_; } 353 target_type()354 ParseTargetTypeDef target_type() const { return target_type_; } 355 obj()356 py::object obj() { return obj_; } 357 parser()358 py::object parser() { return parser_; } 359 module()360 py::object module() { return module_; } 361 ast_tree()362 py::object ast_tree() { return ast_tree_; } 363 364 bool IsClassMember(const py::object &node); 365 366 private: 367 // Save obj,eg: class instance or function 368 py::object obj_; 369 370 // Function or class method. 371 py::function function_; 372 373 py::object ast_tokens_; 374 py::object ast_tree_; 375 py::object parser_; 376 py::module module_; 377 378 // Is function or method 379 ParseTargetTypeDef target_type_; 380 381 std::string function_name_; 382 std::string function_module_; 383 std::string function_filename_; 384 int64_t function_line_offset_; 385 }; 386 387 // Update the graph flags 388 bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph); 389 390 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); 391 TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph); 392 393 } // namespace parse 394 } // namespace mindspore 395 396 #endif // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_ 397