• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
21 
22 #include <limits>
23 #include <vector>
24 #include <string>
25 #include <map>
26 #include <set>
27 #include <stack>
28 #include <memory>
29 #include "utils/misc.h"
30 #include "ir/anf.h"
31 #include "pipeline/jit/parse/parse_base.h"
32 #include "pipeline/jit/parse/python_adapter.h"
33 #include "pipeline/jit/parse/function_block.h"
34 
35 namespace mindspore {
36 namespace parse {
// Result codes reported by the parser (see Parser::errcode()).
// Every value is spelled out explicitly so the numeric codes are visible at a glance.
enum ParseStatusCode : int64_t {
  PARSE_SUCCESS = 0,                     // Parsing succeeded
  PARSE_FUNCTION_IS_NULL = 1,            // Python function is null
  PARSE_PARAMETER_INVALID = 2,           // Parameter is invalid
  PARSE_NO_RETURN = 3,                   // Function has no return node
  PARSE_NODE_TYPE_NO_MATCH = 4,          // AST node type is wrong
  PARSE_NODE_TYPE_UNKNOWN = 5,           // Node type is unknown
  PARSE_NODE_METHOD_UNSUPPORTED = 6,     // No method to parse the node
  PARSE_DONT_RESOLVE_SYMBOL = 7,         // Can't resolve the string
  PARSE_NOT_SUPPORTED_COMPARE_EXPR = 8,  // The comparison is not supported
  PARSE_FAILURE = 0xFF                   // Generic failure
};
50 
// Upper bound on the iteration count of a `for` statement: a loop whose count is below this value is unrolled,
// otherwise it is sunk (i.e. not unrolled).
// NOTE: Unrolling relies on the backend operators `tuple_getitem` and `scalar_add`, which were not implemented, so
// this bound is set to the int64_t maximum to override the default of `600`. As a result `for` loops are always
// unrolled. Memory exhaustion is not a concern: an exception is raised once the function call depth exceeds
// `context.get_context('max_call_depth')`.
constexpr int64_t MAX_FOR_LOOP_COUNT = std::numeric_limits<int64_t>::max();

class ParseFunctionAst;
class AstNodeType;
61 
62 // Save loop info for 'continue' and 'break' statements.
63 struct Loop {
64   // Loop header block.
65   FunctionBlockPtr header;
66   // Loop iterator node, used in 'for loop'.
67   AnfNodePtr iterator;
68   // Loop end block.
69   FunctionBlockPtr end;
70 
LoopLoop71   Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end)
72       : header(header), iterator(iterator), end(end) {}
73   ~Loop() = default;
74 };
75 
76 // Loop context for loop stack management.
77 class LoopContext {
78  public:
LoopContext(std::stack<Loop> * loops,const FunctionBlockPtr & header,const AnfNodePtr & iterator)79   LoopContext(std::stack<Loop> *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) {
80     loops_->emplace(header, iterator, nullptr);
81   }
~LoopContext()82   ~LoopContext() { loops_->pop(); }
EndBlock()83   const FunctionBlockPtr &EndBlock() const { return loops_->top().end; }
84 
85  private:
86   std::stack<Loop> *loops_;
87 };
88 
89 // Parser to parse python function
90 class Parser {
91  public:
92   explicit Parser(const std::shared_ptr<ParseFunctionAst> &ast);
93 
~Parser()94   ~Parser() {}
95   FuncGraphPtr ParseFuncGraph();
func_graph()96   FuncGraphPtr func_graph() const { return func_graph_; }
errcode()97   ParseStatusCode errcode() const { return errcode_; }
ast()98   std::shared_ptr<ParseFunctionAst> ast() const { return ast_; }
support_fallback()99   const std::string &support_fallback() const { return support_fallback_; }
100   // Get location info from the ast node
101   LocationPtr GetLocation(const py::object &node) const;
102   static void InitParserEnvironment(const py::object &obj);
103   static void CleanParserResource();
GetTopFuncGraph()104   static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); }
105   static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph);
106 
107  private:
108   // Process the stmt node method list
109   FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node);
110   // Parse expression
111   FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node);
112   // Process a if statement
113   FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node);
114   // Process a while statement
115   FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
116   // Process a for statement
117   FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
118   FunctionBlockPtr ParseForIter(const FunctionBlockPtr &block, const py::object &node);
119   FunctionBlockPtr ParseForLoop(const FunctionBlockPtr &block, const py::object &node);
120   // Process a function def statement
121   FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node);
122   // Process a augment assign
123   FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node);
124   // Process a global declaration
125   FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node);
126   // Process assign statement
127   FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node);
128   // Process break statement
129   FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node);
130   // Process continue statement
131   FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node);
132   // Process pass statement
133   FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node);
134 
135   // Process the expr and slice node method list
136   AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node);
137   // Process a variable name
138   AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
139   // Process NoneType
140   AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node);
141   // Process Ellipsis
142   AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node);
143   // Process a integer or float number
144   AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node);
145   // Process a string variable
146   AnfNodePtr ParseStr(const FunctionBlockPtr &block, const py::object &node);
147   // Process a Constant
148   AnfNodePtr ParseConstant(const FunctionBlockPtr &block, const py::object &node);
149   // Process a name
150   AnfNodePtr ParseNameConstant(const FunctionBlockPtr &block, const py::object &node);
151   // Process a function call
152   AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
153   // Process function 'super'
154   AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args);
155   // Process the if expression
156   AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node);
157   // Process class type define
158   AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node);
159   // Process a compare expression
160   AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node);
161   // Process a bool operation
162   AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node);
163   // Process a lambda operation
164   AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node);
165   // Process a tuple
166   AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node);
167   // Process a tuple
168   AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node);
169   // Process a tuple
170   AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node);
171   // Process a slice
172   AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node);
173   // Process a extslice
174   AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node);
175   // Process a tuple
176   AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node);
177   // Process a unaryop
178   AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
179   // Process a dict ast node expression
180   AnfNodePtr ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
181                                       const std::vector<AnfNodePtr> &value_nodes);
182   AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
183   // Process ListComp expression
184   AnfNodePtr ParseListComp(const FunctionBlockPtr &block, const py::object &node);
185   FunctionBlockPtr ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
186                                      const py::object &generator_node);
187   AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
188                               const py::object &node, const py::object &generator_node);
189 
190   // Check if the node need interpreting.
191   AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
192                              const py::object &value_object);
193 
194   // Generate argument nodes for ast  function node
195   void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &function_node);
196   // Generate argument default value for ast  function node
197   void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &function_node);
198   // Parse ast function node
199   FunctionBlockPtr ParseFunction(const py::object &function_node, const FunctionBlockPtr &block = nullptr);
200   // Parse ast statements
201   FunctionBlockPtr ParseStatements(FunctionBlockPtr block, const py::object &stmt_node);
202   // Parse one ast statement node
203   FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
204   // Parse an ast expression node
205   AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);
206 
207   void MakeConditionBlocks(const FunctionBlockPtr &block, const FunctionBlockPtr &trueBlock,
208                            const FunctionBlockPtr &falseBlock);
209   void RemoveUnnecessaryPhis();
210   // Write a new var
211   void WriteAssignVars(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &value_node);
212 
213   // Assign value to single variable name
214   void HandleAssignName(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
215 
216   // Assign value to tuple
217   void HandleAssignTuple(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
218 
219   // Assign value to class member
220   void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
221 
222   // Assign value to subscript
223   void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, const AnfNodePtr &assigned_node);
224 
225   // Process a bool operation value list
226   AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
227 
228   CNodePtr GenerateIteratorInFor(const FunctionBlockPtr &block, const pybind11::object &node,
229                                  const AnfNodePtr &op_iter);
230 
231   CNodePtr GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block,
232                              const AnfNodePtr &op_hasnext);
233 
234   FunctionBlockPtr GenerateBlock(const TraceInfoPtr &trace_info);
235 
236   bool ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node,
237                            std::vector<AnfNodePtr> *packed_arguments);
238 
239   bool ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *packed_arguments,
240                        std::vector<AnfNodePtr> *group_arguments);
241 
242   AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
243                                     const std::vector<AnfNodePtr> &packed_arguments,
244                                     const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const;
245   ScopePtr GetScopeForParseFunction();
246   void BuildMethodMap();
MakeFunctionBlock(const Parser & parse)247   FunctionBlockPtr MakeFunctionBlock(const Parser &parse) {
248     FunctionBlockPtr block = std::make_shared<FunctionBlock>(parse);
249     // In order to keep effect order in the sub-graphs which generated by control flow.
250     // We copy the flags from the top graph to the sub-graphs.
251     if (func_graph_ && !func_graph_->attrs().empty()) {
252       for (const auto &attr : func_graph_->attrs()) {
253         // The flag FUNC_GRAPH_OUTPUT_NO_RECOMPUTE should be only set in the top graph.
254         if (attr.first != FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) {
255           block->func_graph()->set_attr(attr.first, attr.second);
256         }
257       }
258     }
259     func_block_list_.push_back(block);
260     return block;
261   }
262   // Return a make tuple for input elements list
263   AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
264   int64_t GetForTransToWhileLoop();
265 
266   // The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
267   static FuncGraphWeakPtr top_func_graph_;
268   // Python function id, used to indicate whether two CNodes come from the same Python function
269   const std::shared_ptr<ParseFunctionAst> &ast_;
270   FuncGraphPtr func_graph_;
271   // Error code setwhen parsing ast tree
272   ParseStatusCode errcode_;
273 
274   // Hold all reference for FunctionBlock in this round of parsing,
275   // so in FunctionBlock class we can use FunctionBlock* in member
276   // pre_blocks_ and jumps_ to break reference cycle.
277   std::vector<FunctionBlockPtr> func_block_list_;
278   using pStmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
279   using pExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
280   // Define the function map to parse ast Statement
281   std::map<std::string, pStmtFunc> stmt_method_map_;
282   // Define the function map to parse ast expression
283   std::map<std::string, pExprFunc> expr_method_map_;
284   // Save current loops to support 'continue', 'break' statement.
285   std::stack<Loop> loops_;
286   string max_for_loop_count_str_;
287   string support_fallback_;
288 };
289 
290 // AST node type define code to ast
291 class AstNodeType {
292  public:
AstNodeType(const py::object & node,const std::string & name,AstMainType type)293   AstNodeType(const py::object &node, const std::string &name, AstMainType type)
294       : node_(node), node_name_(name), main_type_(type) {}
295 
~AstNodeType()296   ~AstNodeType() {}
297 
node_name()298   std::string node_name() const { return node_name_; }
299 
node()300   py::object node() const { return node_; }
301 
main_type()302   AstMainType main_type() const { return main_type_; }
303 
304  private:
305   const py::object &node_;
306   const std::string node_name_;
307   AstMainType main_type_;
308 };
309 
310 using AstNodeTypePtr = std::shared_ptr<AstNodeType>;
311 
312 // A helper class to parse python function
313 class ParseFunctionAst {
314  public:
ParseFunctionAst(const py::object & obj)315   explicit ParseFunctionAst(const py::object &obj)
316       : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {}
317 
318   ~ParseFunctionAst() = default;
319 
320   bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
321 
322   py::object GetAstNode();
323 
324   py::str GetAstNodeText(const py::object &node);
325 
326   py::list GetArgs(const py::object &func_node);
327 
328   py::list GetArgsDefaultValues(const py::object &func_node);
329 
330   AstNodeTypePtr GetNodeType(const py::object &node);
331 
332   AstSubType GetOpType(const py::object &node);
333 
334   template <class... T>
CallParserObjMethod(const std::string & method,const T &...args)335   py::object CallParserObjMethod(const std::string &method, const T &... args) {
336     return python_adapter::CallPyObjMethod(parser_, method, args...);
337   }
338 
339   template <class... T>
CallParseModFunction(const std::string & function,const T &...args)340   py::object CallParseModFunction(const std::string &function, const T &... args) {
341     return python_adapter::CallPyModFn(module_, function, args...);
342   }
343 
function_name()344   const std::string &function_name() const { return function_name_; }
345 
function_module()346   const std::string &function_module() const { return function_module_; }
347 
function_filename()348   const std::string &function_filename() const { return function_filename_; }
349 
function_line_offset()350   int64_t function_line_offset() const { return function_line_offset_; }
351 
function()352   py::function function() { return function_; }
353 
target_type()354   ParseTargetTypeDef target_type() const { return target_type_; }
355 
obj()356   py::object obj() { return obj_; }
357 
parser()358   py::object parser() { return parser_; }
359 
module()360   py::object module() { return module_; }
361 
ast_tree()362   py::object ast_tree() { return ast_tree_; }
363 
364   bool IsClassMember(const py::object &node);
365 
366  private:
367   // Save obj,eg: class instance or function
368   py::object obj_;
369 
370   // Function or class method.
371   py::function function_;
372 
373   py::object ast_tokens_;
374   py::object ast_tree_;
375   py::object parser_;
376   py::module module_;
377 
378   // Is function or method
379   ParseTargetTypeDef target_type_;
380 
381   std::string function_name_;
382   std::string function_module_;
383   std::string function_filename_;
384   int64_t function_line_offset_;
385 };
386 
387 // Update the graph flags
388 bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph);
389 
390 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
391 TypePtr GetMixedPrecisionTargetType(const FuncGraphPtr &func_graph);
392 
393 }  // namespace parse
394 }  // namespace mindspore
395 
396 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
397