• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2023 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
20 #define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
21 
22 #include <limits>
23 #include <utility>
24 #include <tuple>
25 #include <vector>
26 #include <string>
27 #include <map>
28 #include <set>
29 #include <stack>
30 #include <memory>
31 #include "utils/misc.h"
32 #include "ir/anf.h"
33 #include "pipeline/jit/ps/parse/parse_base.h"
34 #include "include/common/utils/python_adapter.h"
35 #include "pipeline/jit/ps/parse/function_block.h"
36 
37 namespace mindspore {
38 namespace parse {
// Status codes describing the outcome of parsing; this is the type reported by Parser::errcode().
enum ParseStatusCode : int64_t {
  PARSE_SUCCESS = 0,                     // Parsing succeeded.
  PARSE_FUNCTION_IS_NULL = 1,            // The Python function is null.
  PARSE_PARAMETER_INVALID = 2,           // A parameter is invalid.
  PARSE_NO_RETURN = 3,                   // The function has no return node.
  PARSE_NODE_TYPE_NO_MATCH = 4,          // The AST node type is wrong.
  PARSE_NODE_TYPE_UNKNOWN = 5,           // The node type is unknown.
  PARSE_NODE_METHOD_UNSUPPORTED = 6,     // There is no method to parse the node.
  PARSE_DONT_RESOLVE_SYMBOL = 7,         // The string cannot be resolved.
  PARSE_NOT_SUPPORTED_COMPARE_EXPR = 8,  // The comparison is not supported.
  PARSE_FAILURE = 0xFF                   // Generic failure.
};
52 
// Name of the Python module holding the standard methods used by the parser.
constexpr char kStandardMethodModelName[] = "mindspore._extends.parse.standard_method";

// Maximum loop count of a `for` statement: when the loop count is less than this value the loop is
// unrolled, otherwise it is sunk (i.e. not unrolled).
// NOTE: Unrolled `for` loops depend on the backend operators `tuple_getitem` and `scalar_add`, which
// were not implemented, so this is set to the int64_t maximum to override the former default of `600`.
// As a result `for` loops are always unrolled. Memory exhaustion is not a concern: an exception is
// raised once the function call depth exceeds `context.get_context('max_call_depth')`.
constexpr int64_t MAX_FOR_LOOP_COUNT = std::numeric_limits<int64_t>::max();

// Forward declarations; both classes are defined later in this header.
class AstNodeType;
class ParseFunctionAst;
65 
66 // Save loop info for 'continue' and 'break' statements.
67 struct Loop {
68   // Loop header block.
69   FunctionBlockPtr header;
70   // Loop iterator node, used in 'for loop'.
71   AnfNodePtr iterator;
72   // Loop end block.
73   FunctionBlockPtr end;
74 
LoopLoop75   Loop(const FunctionBlockPtr &header, const AnfNodePtr &iterator, const FunctionBlockPtr &end)
76       : header(header), iterator(iterator), end(end) {}
77   ~Loop() = default;
78 };
79 
80 // Loop context for loop stack management.
81 class LoopContext {
82  public:
LoopContext(std::stack<Loop> * loops,const FunctionBlockPtr & header,const AnfNodePtr & iterator)83   LoopContext(std::stack<Loop> *loops, const FunctionBlockPtr &header, const AnfNodePtr &iterator) : loops_(loops) {
84     loops_->emplace(header, iterator, nullptr);
85   }
~LoopContext()86   ~LoopContext() {
87     try {
88       MS_EXCEPTION_IF_NULL(loops_);
89       loops_->pop();
90     } catch (const std::exception &e) {
91       MS_LOG(ERROR) << "Exception when pop. Error info " << e.what();
92     } catch (...) {
93       MS_LOG(ERROR) << "Throw exception when pop.";
94     }
95     loops_ = nullptr;
96   }
97 
EndBlock()98   const FunctionBlockPtr &EndBlock() const { return loops_->top().end; }
99 
100  private:
101   std::stack<Loop> *loops_;
102 };
103 
104 struct ArgsContext {
105   bool need_unpack{false};
106   bool has_interpret_without_internal{false};
107   bool has_interpret_internal{false};
108 
109   std::vector<AnfNodePtr> packed_arguments;
110   std::vector<AnfNodePtr> group_arguments;
ArgsContextArgsContext111   ArgsContext() {}
112   ~ArgsContext() = default;
113 };
114 
115 // Parser to parse python function.
116 class Parser {
117  public:
118   explicit Parser(const std::shared_ptr<ParseFunctionAst> &ast, ValuePtrList args_value_list);
119 
~Parser()120   ~Parser() {}
121   FuncGraphPtr ParseFuncGraph();
func_graph()122   FuncGraphPtr func_graph() const { return func_graph_; }
errcode()123   ParseStatusCode errcode() const { return errcode_; }
ast()124   std::shared_ptr<ParseFunctionAst> ast() const { return ast_; }
125   // Get location info from the ast node.
126   LocationPtr GetLocation(const py::object &node) const;
127   static void InitParserEnvironment(const py::object &obj);
128   static void CleanParserResource();
GetTopFuncGraph()129   static FuncGraphPtr GetTopFuncGraph() { return top_func_graph_.lock(); }
130   static void UpdateTopFuncGraph(const FuncGraphPtr &func_graph);
EnableDeferResolve(bool enabled)131   static void EnableDeferResolve(bool enabled) { defer_resolve_ = enabled; }
defer_resolve()132   static bool defer_resolve() { return defer_resolve_; }
133 
134  private:
135   // Process the stmt node method list.
136   FunctionBlockPtr ParseReturn(const FunctionBlockPtr &block, const py::object &node);
137   // Parse expression.
138   FunctionBlockPtr ParseExpr(const FunctionBlockPtr &block, const py::object &node);
139   // Process a if statement.
140   FunctionBlockPtr ParseIf(const FunctionBlockPtr &block, const py::object &node);
141   // Process a while statement.
142   FunctionBlockPtr ParseWhile(const FunctionBlockPtr &block, const py::object &node);
143   // Process a for statement.
144   FunctionBlockPtr ParseFor(const FunctionBlockPtr &block, const py::object &node);
145   FunctionBlockPtr ParseForUnroll(const FunctionBlockPtr &block, const py::object &node);
146   FunctionBlockPtr ParseForRepeat(const FunctionBlockPtr &block, const py::object &node);
147   // Process a function def statement.
148   FunctionBlockPtr ParseFunctionDef(const FunctionBlockPtr &block, const py::object &node);
149   // Process a augment assign.
150   FunctionBlockPtr ParseAugAssign(const FunctionBlockPtr &block, const py::object &node);
151   // Process a global declaration.
152   FunctionBlockPtr ParseGlobal(const FunctionBlockPtr &block, const py::object &node);
153   // Process assign statement.
154   FunctionBlockPtr ParseAssign(const FunctionBlockPtr &block, const py::object &node);
155   // Process annassign statement.
156   FunctionBlockPtr ParseAnnAssign(const FunctionBlockPtr &block, const py::object &node);
157   // Process break statement.
158   FunctionBlockPtr ParseBreak(const FunctionBlockPtr &block, const py::object &node);
159   // Process continue statement.
160   FunctionBlockPtr ParseContinue(const FunctionBlockPtr &block, const py::object &node);
161   // Process pass statement.
162   FunctionBlockPtr ParsePass(const FunctionBlockPtr &block, const py::object &node);
163   // Process raise statement.
164   FunctionBlockPtr ParseRaise(const FunctionBlockPtr &block, const py::object &node);
165   // Process assert statement.
166   FunctionBlockPtr ParseAssert(const FunctionBlockPtr &block, const py::object &node);
167   // Process with statement.
168   FunctionBlockPtr ParseWith(const FunctionBlockPtr &block, const py::object &node);
169 
170   // Process withitem.
171   AnfNodePtr ParseWithitem(const FunctionBlockPtr &block, const py::object &node, const AnfNodePtr &context_expr_node);
172   // Process the expr and slice node method list.
173   AnfNodePtr ParseBinOp(const FunctionBlockPtr &block, const py::object &node);
174   // Process a variable name.
175   AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node);
176   // Process NoneType.
177   AnfNodePtr ParseNone(const FunctionBlockPtr &, const py::object &);
178   // Process Ellipsis.
179   AnfNodePtr ParseEllipsis(const FunctionBlockPtr &, const py::object &);
180   // Process an integer or float number.
181   AnfNodePtr ParseNum(const FunctionBlockPtr &, const py::object &node);
182   // Process a string variable.
183   AnfNodePtr ParseStr(const FunctionBlockPtr &, const py::object &node);
184   // Process a Constant.
185   AnfNodePtr ParseConstant(const FunctionBlockPtr &, const py::object &node);
186   // Process a name.
187   AnfNodePtr ParseNameConstant(const FunctionBlockPtr &, const py::object &node);
188   // Process a function call.
189   AnfNodePtr ParseCall(const FunctionBlockPtr &block, const py::object &node);
190   // Process function 'super'.
191   AnfNodePtr ParseSuper(const FunctionBlockPtr &block, const py::list &args);
192   // Process the if expression.
193   AnfNodePtr ParseIfExp(const FunctionBlockPtr &block, const py::object &node);
194   // get vector of getattr node from getattr map.
195   std::vector<AnfNodePtr> GetGetAttrVectotFromMap(const std::string &obj_name, const std::string &attr_name);
196   // get setattr node from setattr map.
197   AnfNodePtr GetSetAttrFromMap(const std::string &obj_name, const std::string &attr_name);
198   // make getattr node using interpret node as target
199   AnfNodePtr MakeGetAttrWithInterpret(const std::string &obj_name, const std::string &attr_name,
200                                       const py::object &getattr_obj, const FuncGraphPtr &cur_fg);
201   // Process class type define.
202   AnfNodePtr ParseAttribute(const FunctionBlockPtr &block, const py::object &node);
203   // Process ms Tensor.
204   AnfNodePtr ParseMsTensor(const FunctionBlockPtr &block, const py::object &node, const py::object &value_body,
205                            const AnfNodePtr &value_node);
206   // Process dtype._null.
207   AnfNodePtr ParseNull(const FunctionBlockPtr &block, const py::object &value_body) const;
208   // Process a compare expression.
209   AnfNodePtr ParseCompare(const FunctionBlockPtr &block, const py::object &node);
210   AnfNodePtr ParseSingleCompare(const FunctionBlockPtr &block, const py::object &left, const py::object &right,
211                                 const py::object &compare_op);
212   AnfNodePtr ConnectSingleCompare(const FunctionBlockPtr &block, const AnfNodePtr &left_node,
213                                   const AnfNodePtr &right_node);
214   // Process a bool operation.
215   AnfNodePtr ParseBoolOp(const FunctionBlockPtr &block, const py::object &node);
216   // Process a lambda operation.
217   AnfNodePtr ParseLambda(const FunctionBlockPtr &block, const py::object &node);
218   // Process a tuple.
219   AnfNodePtr ParseTuple(const FunctionBlockPtr &block, const py::object &node);
220   // Process a list.
221   AnfNodePtr ParseList(const FunctionBlockPtr &block, const py::object &node);
222   // Process a tuple or list.
223   AnfNodePtr ParseTupleOrList(const FunctionBlockPtr &block, const py::object &node, bool is_tuple);
224   // Process a tuple or list with starred expression.
225   AnfNodePtr ParseTupleOrListWithStarred(const FunctionBlockPtr &block, const py::object &node, bool is_tuple,
226                                          const std::vector<AnfNodePtr> &starred_flags);
227   // Process a subscript.
228   AnfNodePtr ParseSubscript(const FunctionBlockPtr &block, const py::object &node);
229   // Process a slice.
230   AnfNodePtr ParseSlice(const FunctionBlockPtr &block, const py::object &node);
231   // Process a extslice.
232   AnfNodePtr ParseExtSlice(const FunctionBlockPtr &block, const py::object &node);
233   // Process a index.
234   AnfNodePtr ParseIndex(const FunctionBlockPtr &block, const py::object &node);
235   // Process a unaryop.
236   AnfNodePtr ParseUnaryOp(const FunctionBlockPtr &block, const py::object &node);
237   // Process a dict ast node expression.
238   AnfNodePtr ParseDictByKeysAndValues(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &key_nodes,
239                                       const std::vector<AnfNodePtr> &value_nodes);
240   // Process a dict.
241   AnfNodePtr ParseDict(const FunctionBlockPtr &block, const py::object &node);
242 
243   std::pair<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> GetRealKeysValues(const FunctionBlockPtr &block,
244                                                                                 const py::object &node);
245   std::pair<AnfNodePtr, AnfNodePtr> GetRealKeysValuesFromName(const FunctionBlockPtr &block, const py::object &node);
246   // Process DictComp expression.
247   AnfNodePtr ParseDictComp(const FunctionBlockPtr &block, const py::object &node);
248   FunctionBlockPtr ParseDictCompIter(const FunctionBlockPtr &block, const py::object &node,
249                                      const py::object &generator_node);
250   AnfNodePtr ParseDictCompIfs(const FunctionBlockPtr &dict_body_block, const ParameterPtr &dict_param,
251                               const py::object &node, const py::object &generator_node);
252   // Process ListComp expression.
253   AnfNodePtr ParseListComp(const FunctionBlockPtr &block, const py::object &node);
254   FunctionBlockPtr ParseListCompIter(const FunctionBlockPtr &block, const py::object &node,
255                                      const py::object &generator_node);
256   AnfNodePtr ParseListCompIfs(const FunctionBlockPtr &list_body_block, const ParameterPtr &list_param,
257                               const py::object &node, const py::object &generator_node);
258   AnfNodePtr ParseJoinedStr(const FunctionBlockPtr &block, const py::object &node);
259   AnfNodePtr ParseFormattedValue(const FunctionBlockPtr &block, const py::object &node);
260   AnfNodePtr ParseStarred(const FunctionBlockPtr &block, const py::object &node);
261   std::vector<AnfNodePtr> HandleException(const FunctionBlockPtr &block, const py::list &args, const std::string &name);
262   std::vector<AnfNodePtr> ParseRaiseCall(const FunctionBlockPtr &block, const py::object &node);
263   void HandleStrInError(const FunctionBlockPtr &block, const py::list &args, std::vector<AnfNodePtr> *str_nodes);
264 
265   bool GetBoolObjForAstCompare(const FunctionBlockPtr &block, const py::object &node, bool *bool_res) const;
266   py::object GetPyObjForAstAttr(const FunctionBlockPtr &block, const py::object &attr_ast_node,
267                                 bool *is_constant) const;
268   bool GetConstantConditionFromComment(const FunctionBlockPtr &block, const py::object &if_node,
269                                        bool *is_true_cond) const;
270   bool CheckConstantCondition(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond,
271                               const py::object &if_node = py::none()) const;
272 
273   FunctionBlockPtr MakeAssertErrorBlock(const FunctionBlockPtr &block, const py::object &node);
274   AnfNodePtr ProcessAttributeWithClassMember(const FunctionBlockPtr &block, const py::object &node) const;
275 
276   // Transform tail call to parallel call.
277   void TransformParallelCall();
278   void LiftRolledBodyGraphFV();
279   void LiftIfBranchGraphFV();
280 
281   // Check if script_text is in global/local params.
282   bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict,
283                         const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) const;
284   // Make interpret node.
285   AnfNodePtr MakeInterpretNode(const FunctionBlockPtr &block, const AnfNodePtr &value_node, const string &script_text);
286   // Check if the node need interpreting.
287   AnfNodePtr HandleInterpret(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
288                              const py::object &value_object);
289 
290   bool CheckNeedConvertInterpret(const FunctionBlockPtr &block, const AnfNodePtr &node,
291                                  const string &script_text) const;
292 
293   // Generate argument nodes for ast function node.
294   void GenerateArgsNodeForFunction(const FunctionBlockPtr &block, const py::object &fn_node);
295   // Generate argument default value for ast function node.
296   void GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, const py::object &fn_node);
297   // Parse ast function node.
298   FunctionBlockPtr ParseDefFunction(const py::object &node, const FunctionBlockPtr &block = nullptr);
299   // Parse lambda function node.
300   FunctionBlockPtr ParseLambdaFunction(const py::object &node, const FunctionBlockPtr &block = nullptr);
301   // Parse ast statements.
302   FunctionBlockPtr ParseStatements(const FunctionBlockPtr &block, const py::object &nodes);
303   // Parse one ast statement node.
304   FunctionBlockPtr ParseStatement(const FunctionBlockPtr &block, const py::object &node);
305   // Parse an ast expression node.
306   AnfNodePtr ParseExprNode(const FunctionBlockPtr &block, const py::object &node);
307 
308   void MakeConditionBlocks(const FunctionBlockPtr &pre_block, const FunctionBlockPtr &true_block,
309                            const FunctionBlockPtr &false_block) const;
310   std::shared_ptr<std::map<ParameterPtr, AnfNodePtr>> CalRemovablePhis();
311   void CreatePhiArgMaps(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args,
312                         std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis);
313   static void PrintPhiArgMaps(const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_to_args,
314                               const std::map<AnfNodePtr, std::set<ParameterPtr>> &arg_to_phis);
315   static void UpdatePhiArgMapsRepeatedly(std::map<ParameterPtr, std::set<AnfNodePtr>> *phi_to_args,
316                                          std::map<AnfNodePtr, std::set<ParameterPtr>> *arg_to_phis);
317   static std::shared_ptr<std::map<ParameterPtr, AnfNodePtr>> CollectRemovablePhiArgs(
318     const std::map<ParameterPtr, std::set<AnfNodePtr>> &phi_to_args);
319   void RemoveUnnecessaryPhis(const FuncGraphManagerPtr &manager);
320   void ConvertGetattrNodes();
321   // Write a new var.
322   void WriteAssignVars(const FunctionBlockPtr &block, const py::object &target_object, const AnfNodePtr &value_node);
323 
324   // Create a setattr CNode.
325   void MakeSetAttrNode(const FunctionBlockPtr &block, const AnfNodePtr &target_node, const AnfNodePtr &value_node,
326                        const std::string &target_id_str, const std::string &attr_str);
327 
328   // Assign value to single variable name.
329   void HandleAssignName(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node) const;
330 
331   // Assign value to starred expression.
332   void HandleAssignStarred(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node);
333 
334   // Assign value to tuple.
335   void HandleAssignTupleOrList(const FunctionBlockPtr &block, const py::object &target,
336                                const AnfNodePtr &assigned_node);
337 
338   // Assign value to tuple with starred expression.
339   void HandleAssignTupleWithStarredExpression(const FunctionBlockPtr &block, const py::object &target,
340                                               const AnfNodePtr &assigned_node, const std::vector<int64_t> &positions);
341 
342   // Assign value to class Parameter member. Return false if not a Parameter member.
343   bool HandleAssignClassParameterMember(const FunctionBlockPtr &block, const py::object &target,
344                                         const AnfNodePtr &value_node);
345 
346   // Handle set attribute change for class member.
347   bool HandleSetAttrClassMemberForInplace(const FunctionBlockPtr &block, const AnfNodePtr &node);
348 
349   // Assign value to class member.
350   void HandleAssignClassMember(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &value_node);
351 
352   // Assign value to subscript.
353   void HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &target, const AnfNodePtr &assigned_node);
354 
355   // Process a bool operation value list.
356   AnfNodePtr ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode);
357 
358   void ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object &node, ArgsContext *args_context);
359 
360   void ParseArgsInCall(const FunctionBlockPtr &block, const py::list &args, ArgsContext *args_context);
361   AnfNodePtr GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_node,
362                                     const ArgsContext &args_context) const;
363   ScopePtr GetScopeForParseFunction();
364   // Check the value is subscript is reference type.
365   bool IsSubscriptReferenceType(const py::object &obj);
366   void BuildMethodMap();
367   // Must add a TraceGuard before call it.
368   FunctionBlockPtr MakeFunctionBlock();
369   FunctionBlockPtr MakeFunctionBlock(const TraceInfoPtr &trace_info);
370   // Return a make tuple for input elements list.
371   AnfNodePtr GenerateMakeTuple(const FunctionBlockPtr &block, const std::vector<AnfNodePtr> &element_nodes);
372   // Check if the node is pop operation.
373   bool IsPopOperation(const AnfNodePtr &node) const;
374   // Check if branch block contains break/continue/return statement, and propagate that flag back to block.
375   void CheckControlFlowAlterationInIf(std::pair<FunctionBlockPtr, FunctionBlockPtr> *branch_graphs_pair,
376                                       const FunctionBlockPtr &branch_block, const FunctionBlockPtr &branch_end,
377                                       const FunctionBlockPtr &after_block, const FunctionBlockPtr &block) const;
378   // Check if body block contains return statement, and propagate that flag back to block.
379   void CheckReturnInLoop(const FunctionBlockPtr &block, const FunctionBlockPtr &body_block) const;
380 
381   // Check whether the functions referred by this function and itself are missing 'return' statement.
382   void CheckFuncReturn(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fn);
383 
384   // If the node is Parameter member of class.
385   bool IsClassParameterMember(const py::object &target_obj, const AnfNodePtr &target_node) const;
386 
387   ValuePtr GetParameterValue(const AnfNodePtr &parameter) const;
388   bool CheckAttributeConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const;
389   bool CheckNameConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const;
390   bool CheckUnaryOpConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const;
391   bool CheckCompareConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const;
392   bool CheckBoolOpConstantCond(const FunctionBlockPtr &block, const py::object &test_node, bool *is_true_cond) const;
393   bool CompareIs(const FunctionBlockPtr &, const py::object &left_obj, const py::object &comparator_obj,
394                  bool *bool_res) const;
395   bool CompareIsNot(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
396                     bool *bool_res) const;
397   bool CompareEqual(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
398                     bool *bool_res) const;
399   bool CompareNotEqual(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
400                        bool *bool_res) const;
401   bool CompareGreater(const FunctionBlockPtr &, const py::object &left_obj, const py::object &comparator_obj,
402                       bool *bool_res) const;
403   bool CompareGreaterEqual(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
404                            bool *bool_res) const;
405   bool CompareLess(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
406                    bool *bool_res) const;
407   bool CompareLessEqual(const FunctionBlockPtr &block, const py::object &left_obj, const py::object &comparator_obj,
408                         bool *bool_res) const;
409   py::object GetValuePythonObject(const py::object &value_node);
410   CNodePtr MakeSetitemNode(const FunctionBlockPtr &block, const py::object &value_obj, const py::object &slice_obj,
411                            const AnfNodePtr &assigned_node, const AnfNodePtr &value_node);
412 
413   void ProcessPopOperation(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
414                            const py::object &target_object);
415 
416   void ProcessPopOperationInAugAssign(const FunctionBlockPtr &block, const AnfNodePtr &value_node,
417                                       const AnfNodePtr &target_node, const AnfNodePtr &op_node,
418                                       const py::object &target_object);
419 
420   // The shared_ptr will be hold by GraphManager, so just hold a weak ref here.
421   static FuncGraphWeakPtr top_func_graph_;
422   // Set if defer resolve during parsing.
423   inline static bool defer_resolve_{false};
424   // Python function id, used to indicate whether two CNodes come from the same Python function.
425   const std::shared_ptr<ParseFunctionAst> &ast_;
426   FuncGraphPtr func_graph_;
427   // Error code setwhen parsing ast tree.
428   ParseStatusCode errcode_;
429   py::object list_pop_target_obj_;
430 
431   // Hold all reference for FunctionBlock in this round of parsing,
432   // so in FunctionBlock class we can use FunctionBlock* in member
433   // pre_blocks_ and jumps_ to break reference cycle.
434   std::vector<FunctionBlockPtr> func_block_list_;
435   using StmtFunc = FunctionBlockPtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
436   using ExprFunc = AnfNodePtr (Parser::*)(const FunctionBlockPtr &block, const py::object &node);
437   using CompareFunc = bool (Parser::*)(const FunctionBlockPtr &block, const py::object &left_obj,
438                                        const py::object &comparator_obj, bool *bool_res) const;
439   using ConditionFunc = bool (Parser::*)(const FunctionBlockPtr &block, const py::object &test_node,
440                                          bool *is_true_cond) const;
441   // Define the function map to parse ast Statement.
442   std::map<std::string, StmtFunc> stmt_method_map_;
443   // Define the function map to parse ast expression.
444   std::map<std::string, ExprFunc> expr_method_map_;
445   // Define the function map to parse compare expression.
446   std::map<std::string, CompareFunc> compare_method_map_;
447   // Define the function map to parse constant condition expression.
448   std::map<std::string, ConditionFunc> condition_method_map_;
449   // Save current loops to support 'continue', 'break' statement.
450   std::stack<Loop> loops_;
451 
452   // The func graphs to transform tail call ir to independent call ir.
453   // Contains: {former_graph, middle_graph}, latter_graph is no need.
454   std::vector<std::vector<std::pair<FunctionBlockPtr, FunctionBlockPtr>>> parallel_call_graphs_;
455   // The true branch and false branch call info. of if statement.
456   std::vector<std::tuple<CNodePtr, FunctionBlockPtr, FunctionBlockPtr>> if_branch_calls_;
457   // The rolled_body callers info. for later lifting operation.
458   std::vector<std::pair<CNodePtr, FunctionBlockPtr>> rolled_body_calls_;
459 
460   // Record all setattr nodes and their targets and values.
461   std::map<std::string, std::map<std::string, AnfNodePtr>> setattr_nodes_map_;
462   // Record all getattr node for specific object and attribute name.
463   std::map<std::string, std::map<std::string, std::vector<AnfNodePtr>>> getattr_nodes_map_;
464   // Store the values for input args of top graph.
465   ValuePtrList args_value_list_;
466 };
467 
468 // AST node type define code to ast.
469 class AstNodeType {
470  public:
AstNodeType(const py::object & node,const std::string & name,AstMainType type)471   AstNodeType(const py::object &node, const std::string &name, AstMainType type)
472       : node_(node), node_name_(name), main_type_(type) {}
473 
~AstNodeType()474   ~AstNodeType() {}
475 
node_name()476   std::string node_name() const { return node_name_; }
477 
node()478   py::object node() const { return node_; }
479 
main_type()480   AstMainType main_type() const { return main_type_; }
481 
482  private:
483   const py::object &node_;
484   const std::string node_name_;
485   AstMainType main_type_;
486 };
487 
488 using AstNodeTypePtr = std::shared_ptr<AstNodeType>;
489 
490 // A helper class to parse python function.
491 class ParseFunctionAst {
492  public:
ParseFunctionAst(const py::object & obj)493   explicit ParseFunctionAst(const py::object &obj)
494       : obj_(obj), target_type_(PARSE_TARGET_UNKNOW), function_line_offset_(-1) {}
495 
496   ~ParseFunctionAst() = default;
497 
498   bool InitParseAstInfo(const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
499 
500   py::object GetAstNode();
501 
502   py::str GetAstNodeText(const py::object &node);
503 
504   py::list GetArgs(const py::object &func_node);
505 
506   py::list GetArgsDefaultValues(const py::object &func_node);
507 
508   AstNodeTypePtr GetNodeType(const py::object &node);
509 
510   AstSubType GetOpType(const py::object &node);
511 
512   template <class... T>
CallParserObjMethod(const std::string & method,const T &...args)513   py::object CallParserObjMethod(const std::string &method, const T &... args) {
514     return python_adapter::CallPyObjMethod(parser_, method, args...);
515   }
516 
517   template <class... T>
CallParseModFunction(const std::string & function,const T &...args)518   py::object CallParseModFunction(const std::string &function, const T &... args) {
519     return python_adapter::CallPyModFn(module_, function, args...);
520   }
521 
function_name()522   const std::string &function_name() const { return function_name_; }
523 
function_module()524   const std::string &function_module() const { return function_module_; }
525 
function_filename()526   const std::string &function_filename() const { return function_filename_; }
527 
function_line_offset()528   int64_t function_line_offset() const { return function_line_offset_; }
529 
function()530   py::function function() { return function_; }
531 
target_type()532   ParseTargetType target_type() const { return target_type_; }
533 
obj()534   py::object obj() { return obj_; }
535 
parser()536   py::object parser() { return parser_; }
537 
module()538   py::object module() { return module_; }
539 
ast_tree()540   py::object ast_tree() { return ast_tree_; }
541 
542   bool IsClassMemberOfSelf(const py::object &node);
543   bool IsClassMemberRecursive(const py::object &node);
544 
545  private:
546   // Save obj, eg: class instance or function.
547   py::object obj_;
548 
549   // Function or class method.
550   py::function function_;
551 
552   py::object ast_tokens_;
553   py::object ast_tree_;
554   py::object parser_;
555   py::module module_;
556 
557   // Is function or method.
558   ParseTargetType target_type_;
559 
560   std::string function_name_;
561   std::string function_module_;
562   std::string function_filename_;
563   int64_t function_line_offset_;
564 };
565 
566 // Update the graph flags.
567 bool UpdateFuncGraphFlags(const py::object &obj, const FuncGraphPtr &func_graph, bool is_construct_function = false);
568 
569 // Update recomputed scope for the graph.
570 void UpdateRecomputeScope(const FuncGraphPtr &func_graph);
571 
572 AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr &param);
573 }  // namespace parse
574 }  // namespace mindspore
575 
576 #endif  // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_PARSE_H_
577