• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""ControlFlow Node."""
16from typing import Union, List
17import ast
18from mindspore import log as logger
19from .node import Node
20from .node_manager import NodeManager
21from ..api.node_type import NodeType
22from ..ast_helpers import AstModifier
23
24
class ControlFlow(Node, NodeManager):
    """
    ControlFlow node is used for statements like loops and `if` .

    A ControlFlow instance is both a Node and a NodeManager. Each instance manages one branch of the
    statement: `ast_node.body` when `is_orelse` is False, `ast_node.orelse` when `is_orelse` is True.
    The node of the other branch can be attached with `set_body_node` / `set_orelse_node`.
    """
    def __init__(self, node_name: str, ast_node: Union[ast.If, ast.IfExp, ast.For, ast.While], is_orelse: bool,
                 args: Union[List, None], stree):
        """
        Constructor of ControlFlow.

        Args:
            node_name (str): A string represents name of node. Name of node will be unique when inserted into
                SymbolTree. Name of node also used as field name in network class.
            ast_node (ast.AST): An instance of ast.AST represents control flow statements, can be one of ast.If,
                ast.IfExp, ast.For, ast.While.
            is_orelse (bool): Whether current node presents the else branch of node.
            args (list[ScopedValue]): A list of instance of ScopedValue.
            stree (SymbolTree): Symbol tree used to get node_namer.
        """
        Node.__init__(self, NodeType.ControlFlow, ast_node, None, node_name, args, {}, node_name, None)
        NodeManager.__init__(self)
        NodeManager.set_manager_node_namer(self, stree.get_node_namer())
        NodeManager.set_manager_name(self, node_name)
        self.is_orelse = is_orelse
        # body_node / orelse_node point to the nodes of the two branches; the branch represented by this
        # instance points to itself, the other one stays None until set_body_node/set_orelse_node is called.
        self.body_node = None
        self.orelse_node = None
        if is_orelse:
            NodeManager.set_manager_ast(self, ast_node.orelse)
            self.orelse_node = self
        else:
            NodeManager.set_manager_ast(self, ast_node.body)
            self.body_node = self
        # record eval result of test code, used for ast.If
        self.test_result = None
        # record loop variables of control flow, e.g. 'item' of 'for item in self.cell_list:'
        self.loop_vars: List[str] = []

    def erase_node(self, node):
        """
        Erase node from container and erase its ast from the ast managed by this control flow.

        Args:
            node (Node): Node to be erased.

        Raises:
            ValueError: If the ast of `node` could not be erased from the managed ast.
        """
        NodeManager.erase_node(self, node)
        # erase node's ast; nested control flow nodes are erased through a dedicated helper which also
        # receives the branch flag of the nested node.
        if isinstance(node, ControlFlow):
            ret = AstModifier.earse_ast_of_control_flow(self.get_manager_ast(), node.get_ast(), node.is_orelse)
        else:
            ret = AstModifier.erase_ast_from_bodies(self.get_manager_ast(), node.get_ast())
        if not ret:
            raise ValueError(f"Erase node failed, node {node.get_name()} is not in ControlFlow ast tree.")

    def insert_node(self, new_node: Node, base_node: Node, before_node: bool, insert_to_ast: bool = True):
        """
        Insert a node before or after base_node.

        Args:
            new_node (Node): Node to be inserted.
            base_node (Node): New node will be inserted before or after base_node.
            before_node (bool): Indicate whether new node is inserted before base_node.
            insert_to_ast (bool): Indicate whether ast nodes need to be updated.
        """
        NodeManager.insert_node(self, new_node, base_node, before_node)
        if insert_to_ast:
            # ast update is delegated to the symbol tree this node belongs to
            stree = self.get_belong_symbol_tree()
            stree.insert_to_ast_while_insert_node(new_node, base_node, before_node)

    def set_belong_symbol_tree(self, symbol_tree):
        """
        Set the symbol tree to which node belongs.

        The symbol tree is also propagated to every node managed by this control flow.

        Args:
            symbol_tree (SymbolTree): Symbol tree which this node and its sub-nodes belong to.
        """
        self._belong_tree = symbol_tree
        for node in self.nodes():
            node.set_belong_symbol_tree(symbol_tree)

    def set_body_node(self, body_node):
        """Set body_node of control flow"""
        self.body_node = body_node

    def set_orelse_node(self, orelse_node):
        """Set orelse_node of control flow"""
        self.orelse_node = orelse_node

    @staticmethod
    def _find_else_pos(source_code: str) -> int:
        """
        Find the position of the `else:` which separates the two branches in `source_code`.

        An `else:` which starts a line at the same indentation as the first non-blank line of `source_code`
        is preferred, so that an `else:` of a statement nested in the body (or text such as `orelse:`)
        is not mistaken for the branch separator. If no such line exists, the first occurrence of
        `else:` anywhere in the code is used.

        Args:
            source_code (str): Source code of the whole control flow statement.

        Returns:
            int, offset of `else:` in `source_code`, or -1 if `else:` is not found.
        """
        base_indent = None
        offset = 0
        for line in source_code.splitlines(keepends=True):
            content = line.lstrip()
            if content.strip():
                indent = len(line) - len(content)
                if base_indent is None:
                    # first non-blank line is the header of the statement, it defines the reference indent
                    base_indent = indent
                elif indent == base_indent and content.startswith("else:"):
                    return offset + indent
            offset += len(line)
        # no aligned 'else:' line, fall back to the first textual match
        return source_code.find("else:")

    def get_source_code(self) -> str:
        """
        Print source code of control flow, overwriting the implementation in Node.

        When an else-branch node is recorded, only the code of the branch represented by this node is
        returned: code from `else:` on for the else branch, code before `else:` for the body branch.
        Otherwise the code of the whole statement is returned.

        Returns:
            str, source code of this control flow node.
        """
        source_code = Node.get_source_code(self)
        if not self.orelse_node:
            return source_code
        else_pos = self._find_else_pos(source_code)
        if else_pos == -1:
            logger.warning(f"Failed to find code 'else:' in control flow node {self.get_name()}, "
                           f"return all codes.")
            return source_code
        if self.is_orelse:
            return source_code[else_pos:].strip()
        return source_code[:else_pos].strip()
114