# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ControlFlow Node."""
from typing import Union, List
import ast
from mindspore import log as logger
from .node import Node
from .node_manager import NodeManager
from ..api.node_type import NodeType
from ..ast_helpers import AstModifier


class ControlFlow(Node, NodeManager):
    """ControlFlow node is used for statements like loops and `if` ."""
    def __init__(self, node_name: str, ast_node: Union[ast.If, ast.IfExp, ast.For, ast.While], is_orelse: bool,
                 args: Union[List, None], stree):
        """
        Constructor of ControlFlow.

        Args:
            node_name (str): A string represents name of node. Name of node will be unique when inserted into
                SymbolTree. Name of node also used as field name in network class.
            ast_node (ast.AST): An instance of ast.AST represents control flow statements, can be one of ast.If,
                ast.IfExp, ast.For, ast.While.
            is_orelse (bool): Whether current node presents the else branch of node.
            args (list[ScopedValue]): A list of instance of ScopedValue.
            stree (SymbolTree): Symbol tree used to get node_namer.
        """
        Node.__init__(self, NodeType.ControlFlow, ast_node, None, node_name, args, {}, node_name, None)
        NodeManager.__init__(self)
        NodeManager.set_manager_node_namer(self, stree.get_node_namer())
        NodeManager.set_manager_name(self, node_name)
        self.is_orelse = is_orelse
        self.body_node = None
        self.orelse_node = None
        # record node of another branch; the managed ast list is the statement list of the branch this node
        # represents (`orelse` for the else branch, `body` otherwise)
        if is_orelse:
            NodeManager.set_manager_ast(self, ast_node.orelse)
            self.orelse_node = self
        else:
            NodeManager.set_manager_ast(self, ast_node.body)
            self.body_node = self
        # record eval result of test code, used for ast.If
        self.test_result = None
        # record loop variables of control flow, e.g. 'item' of 'for item in self.cell_list:'
        self.loop_vars: List[str] = []

    def erase_node(self, node):
        """
        Erase node from container.

        Args:
            node (Node): Node to be erased. A nested ControlFlow node has its branch ast erased according to its
                `is_orelse` flag; any other node has its ast erased from the bodies of the managed ast.

        Raises:
            ValueError: If the ast of `node` cannot be found in the ast managed by this control flow.
        """
        # NOTE(review): the node is removed from NodeManager bookkeeping before its ast is erased, so when the
        # ast erase fails (ValueError below) the manager has already forgotten the node.
        NodeManager.erase_node(self, node)
        # erase node's ast
        if isinstance(node, ControlFlow):
            ret = AstModifier.earse_ast_of_control_flow(self.get_manager_ast(), node.get_ast(), node.is_orelse)
        else:
            ret = AstModifier.erase_ast_from_bodies(self.get_manager_ast(), node.get_ast())
        if not ret:
            raise ValueError(f"Erase node failed, node {node.get_name()} is not in ControlFlow ast tree.")

    def insert_node(self, new_node: Node, base_node: Node, before_node: bool, insert_to_ast: bool = True):
        """
        Insert a node before or after base_node.

        Args:
            new_node (Node): Node to be inserted.
            base_node (Node): New node will be inserted before or after base_node.
            before_node (bool): Indicate whether new node is inserted before base_node.
            insert_to_ast (bool): Indicate whether ast nodes need to be updated.
        """
        NodeManager.insert_node(self, new_node, base_node, before_node)
        if insert_to_ast:
            # the ast update is delegated to the symbol tree this control flow belongs to
            stree = self.get_belong_symbol_tree()
            stree.insert_to_ast_while_insert_node(new_node, base_node, before_node)

    def set_belong_symbol_tree(self, symbol_tree):
        """Set the symbol tree to which node belongs, propagating it to every node inside this control flow."""
        self._belong_tree = symbol_tree
        for node in self.nodes():
            node.set_belong_symbol_tree(symbol_tree)

    def set_body_node(self, body_node):
        """Set body_node of control flow"""
        self.body_node = body_node

    def set_orelse_node(self, orelse_node):
        """Set orelse_node of control flow"""
        self.orelse_node = orelse_node

    def get_source_code(self) -> str:
        """
        Get source code of the branch represented by this control flow node, overwriting the implementation in Node.

        The body branch and the else branch share the same ast node, so `Node.get_source_code` returns the code of
        both branches. When an else branch exists, the code is split at the `else:` (or `elif`) line that belongs
        to this statement itself, i.e. the one at the same indentation as the statement header; `else:`/`elif`
        lines of nested statements and identifiers that merely end with "else" are ignored.

        Returns:
            str, source code of the current branch, or the whole source code when no split position is found.
        """
        source_code = Node.get_source_code(self)
        if self.orelse_node is None:
            return source_code
        else_pos = ControlFlow._find_else_pos(source_code)
        if else_pos == -1:
            logger.warning(f"Failed to find code 'else:' in control flow node {self.get_name()}, "
                           f"return all codes.")
            return source_code
        if self.is_orelse:
            return source_code[else_pos:].strip()
        return source_code[:else_pos].strip()

    @staticmethod
    def _find_else_pos(source_code: str) -> int:
        """Return the offset of the `else:`/`elif` keyword at the header's indentation, or -1 if there is none."""
        header_indent = None
        offset = 0
        for line in source_code.splitlines(keepends=True):
            content = line.lstrip()
            if content.strip():
                indent = len(line) - len(content)
                if header_indent is None:
                    # first non-blank line is the statement header (`if ...:`, `for ...:`, `while ...:`)
                    header_indent = indent
                elif indent == header_indent and (content.startswith("else:") or content.startswith("elif ")):
                    return offset + indent
            offset += len(line)
        return -1