# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Ast optimizer for flatten recursive call."""

from typing import Any, Tuple, List, Dict, Union, Optional
import keyword
import ast
import copy
import sys

from mindspore import log as logger

FLATTEN_BLACK_LIST = ["set_vertex_attr"]

# Leaf expressions that are never hoisted into a separate assignment.
# Since Python 3.8 every literal is parsed as ast.Constant; the legacy aliases (ast.Num, ast.Str, ast.Bytes,
# ast.NameConstant, ast.Ellipsis) are deprecated and were removed in Python 3.14, so they are only referenced
# on interpreters whose parser still produces them.
if sys.version_info >= (3, 8):
    _UNFLATTENABLE_NODES = (ast.Name, ast.Constant)
else:
    _UNFLATTENABLE_NODES = (ast.Name, ast.NameConstant, ast.Constant, ast.Num, ast.Str, ast.Bytes,
                            ast.Ellipsis)


class AstFlattener(ast.NodeTransformer):
    """Ast optimizer for flatten recursive call."""

    # Record origin test ast, used to judge direction of static if control flow.
    # This is a class attribute, so it is shared by every AstFlattener instance.
    ast_if_test_cache: Dict[ast.If, ast.AST] = {}

    def __init__(self):
        """
        Constructor of AstFlattener.

        Returns:
            An instance of ast optimizer for flatten recursive call.
        """
        # Node type -> names of the fields whose sub-expressions are hoisted into new assignments.
        self._flatten_table: dict = {
            ast.Return: ["value"],
            ast.Call: ["func", "args", "keywords"],
            ast.BinOp: ["left", "right"],
            ast.BoolOp: ["values"],
            ast.UnaryOp: ["operand"],
            ast.Compare: ["left", "comparators"],
            ast.If: ["test"],
            ast.For: ["iter"],
            ast.Tuple: ["elts"],
            ast.List: ["elts"],
        }
        self._transform_functions = []
        self._symbol_tree = None  # Used to get unique name

    @staticmethod
    def _check_flatten_black_list(node: ast.AST):
        """Check whether node is a call to a function listed in FLATTEN_BLACK_LIST."""
        func_name = ""
        # Get func name of node
        if isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                func_name = node.func.id
            elif isinstance(node.func, ast.Attribute):
                func_name = node.func.attr
        # Check func name of node
        if func_name and func_name in FLATTEN_BLACK_LIST:
            return True
        return False

    @staticmethod
    def _to_load_context(node: ast.AST) -> ast.AST:
        """Return a deep copy of `node` with every expression context switched to ast.Load."""
        new_node = copy.deepcopy(node)
        for sub_node in ast.walk(new_node):
            if hasattr(sub_node, "ctx"):
                sub_node.ctx = ast.Load()
        return new_node

    @staticmethod
    def _is_lazily_evaluated(node: ast.AST, todo_name: str, idx: int) -> bool:
        """
        Check whether the `idx`-th element of `node.<todo_name>` is only evaluated conditionally.

        For 'xxx and yyy', 'xxx or yyy' and chained comparisons such as 'a < b < c', Python evaluates
        the later operands only when the earlier ones do not already decide the result. Hoisting them into
        an assignment in front of the statement would evaluate them unconditionally, which may raise.
        """
        if idx == 0:
            return False
        if isinstance(node, ast.BoolOp):
            return True
        return isinstance(node, ast.Compare) and todo_name == "comparators"

    @staticmethod
    def _flatten_continuous_assign(ast_body: List[ast.AST]):
        """
        Flatten ast.Assign with continuous targets, in place.

        `a = b = c = value` becomes::

            c = value
            b = c
            a = b

        Every target must end up holding `value`, so each generated assignment reads a target that has
        already been written; hence the extra assignments are emitted from right to left.
        """
        new_body = []
        for ast_node in ast_body:
            new_body.append(ast_node)
            if not isinstance(ast_node, ast.Assign) or len(ast_node.targets) <= 1:
                continue
            targets = ast_node.targets
            ast_node.targets = [targets[-1]]
            for idx in range(len(targets) - 2, -1, -1):
                # The right-hand side is a Load-context copy of the previously assigned target, so the
                # same node object is never shared between two statements.
                new_body.append(ast.Assign(targets=[targets[idx]],
                                           value=AstFlattener._to_load_context(targets[idx + 1])))
        # Slice assignment keeps the identity of the caller's list (e.g. FunctionDef.body).
        ast_body[:] = new_body

    @staticmethod
    def _save_target_names(ast_body: List[ast.AST]):
        """
        Collect the names assigned by the ast.Assign statements directly in `ast_body`.

        Names nested in tuple/list targets (including starred ones) are included. The result is used to
        avoid collisions when generating new target names.
        """
        target_names = []
        for child in ast_body:
            if not isinstance(child, ast.Assign):
                continue
            pending = list(child.targets)
            while pending:
                target = pending.pop()
                if isinstance(target, ast.Name):
                    if target.id not in target_names:
                        target_names.append(target.id)
                elif isinstance(target, (ast.Tuple, ast.List)):
                    pending.extend(target.elts)
                elif isinstance(target, ast.Starred):
                    pending.append(target.value)
        return target_names

    def _generate_target_name(self, node: ast.AST, target_names):
        """
        Generate a unique target name for the value of `node` and record it in `target_names`.

        Args:
            node (ast.AST): The expression whose value will be stored in the new target.
            target_names (list): Names already in use; the generated name is appended to it.

        Returns:
            str, the generated name, further uniquified by the symbol tree when one is set.
        """
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Name):
                target_name = func.id + "_var"
            elif isinstance(func, ast.Attribute):
                target_name = func.attr + "_var"
            else:
                logger.debug("unhandled type of func of ast.Call while generating new target name: %s ", type(func))
                target_name = "function_var"
        elif isinstance(node, ast.Return):
            target_name = "return_value"
        elif isinstance(node, (ast.BinOp, ast.BoolOp, ast.UnaryOp)):
            target_name = type(node.op).__name__.lower() + "_var"
        elif isinstance(node, (ast.Tuple, ast.List)):
            target_name = type(node).__name__.lower() + "_var"
        elif isinstance(node, ast.Name):
            target_name = node.id
        elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
            target_name = f"{node.value.id}_{node.attr}"
        else:
            logger.debug("unhandled type of node while generating new target name: %s ", type(node))
            target_name = type(node).__name__.lower() + "_var"
        # avoid python built-in keyword
        if keyword.iskeyword(target_name):
            target_name = target_name + "_var"
        suffix = 0
        result = target_name
        while result in target_names:
            suffix += 1
            result = f"{target_name}_{suffix}"
        if self._symbol_tree:
            result = self._symbol_tree.unique_name(result)
        target_names.append(result)
        return result

    def _create_new_assign_node(self, node: ast.AST, target_names, father_node: ast.AST) \
            -> Tuple[Union[ast.Name, ast.Attribute, ast.AST], Optional[ast.Assign]]:
        """
        Create new assign node to be inserted into ast.FunctionDef.

        Returns:
            A tuple (replacement, assign). When `node` is left as is, the result is (node, None).
            Otherwise `replacement` is the expression to use in place of `node`, and `assign` is the new
            assignment to insert before the statement.
        """
        if isinstance(node, _UNFLATTENABLE_NODES):
            return node, None
        # ast.Attribute in ast.For will be force flatten.
        # When ast.Attribute is not in ast.For, only the root of the attribute chain is flattened,
        # and only if that root is not an ast.Name: `g().a.b` => `g_var = g()` & `g_var.a.b`.
        if isinstance(node, ast.Attribute) and not isinstance(father_node, ast.For):
            attr_chain = []  # outermost attribute first
            root = node
            while isinstance(root, ast.Attribute):
                attr_chain.append(root)
                root = root.value
            if isinstance(root, ast.Name):
                return node, None
            new_target_name = self._generate_target_name(root, target_names)
            # Rebuild the whole chain on top of the new name so no outer attribute is lost.
            new_node = ast.Name(id=new_target_name, ctx=ast.Load())
            for attr_node in reversed(attr_chain):
                new_node = ast.Attribute(value=new_node, attr=attr_node.attr, ctx=attr_node.ctx)
            return new_node, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=root)
        # flatten nodes
        new_target_name = self._generate_target_name(node, target_names)
        new_node = ast.Name(id=new_target_name, ctx=ast.Load())
        return new_node, ast.Assign(targets=[ast.Name(id=new_target_name, ctx=ast.Store())], value=node)

    def _flatten_statement(self, node: ast.AST, target_names) -> List[ast.AST]:
        """
        Flatten one level of `node` according to self._flatten_table.

        The sub-expressions of `node` are replaced in place.

        Returns:
            list[ast.Assign], the assignments to insert before the statement, in execution order.
        """
        if AstFlattener._check_flatten_black_list(node):
            return []
        flatten_config = self._flatten_table.get(type(node))
        if flatten_config is None:
            return []
        results = []
        for todo_name in flatten_config:
            todos = getattr(node, todo_name)
            if todos is None:
                # e.g. a bare `return`: there is no value to flatten.
                continue
            if isinstance(todos, list):
                new_list = []
                for idx, todo in enumerate(todos):
                    # Starred expression(e.g. *args) cannot be flatten.
                    # Conditionally evaluated operands (see _is_lazily_evaluated) must stay in place.
                    if isinstance(todo, ast.Starred) or AstFlattener._is_lazily_evaluated(node, todo_name, idx):
                        new_list.append(todo)
                        continue
                    # ast.keywords are processed individually:
                    # y = func(key=value) => new_target_name = value & y = func(key=new_target_name)
                    if isinstance(todo, ast.keyword):
                        new_node, new_assign = self._create_new_assign_node(todo.value, target_names, node)
                        if new_assign is not None:
                            todo.value = new_node
                            results.append(new_assign)
                        new_list.append(todo)
                        continue
                    new_node, new_assign = self._create_new_assign_node(todo, target_names, node)
                    new_list.append(new_node)
                    if new_assign is not None:
                        results.append(new_assign)
                setattr(node, todo_name, new_list)
            elif isinstance(todos, dict):
                new_dict = {}
                for key, value in todos.items():
                    new_node, new_assign = self._create_new_assign_node(value, target_names, node)
                    new_dict[key] = new_node
                    if new_assign is not None:
                        results.append(new_assign)
                setattr(node, todo_name, new_dict)
            else:
                new_node, new_assign = self._create_new_assign_node(todos, target_names, node)
                if new_assign is not None:
                    setattr(node, todo_name, new_node)
                    results.append(new_assign)
        return results

    def _visit_ast_bodies(self, ast_body: List[ast.AST]):
        """
        Traverse nodes in ast_body and flatten nodes recursively.

        Statements are handled from last to first. The assignments hoisted out of a statement are inserted
        in front of it and are then visited themselves, so nested expressions get flattened level by level.
        """
        # Flatten continuous assign statements in ast_body
        AstFlattener._flatten_continuous_assign(ast_body)
        # save target names, used when create new assign ast node
        target_names = AstFlattener._save_target_names(ast_body)
        index = len(ast_body) - 1
        while index >= 0:
            child = ast_body[index]
            # Record origin test ast, used to judge direction of static if control flow.
            if isinstance(child, ast.If) and child not in AstFlattener.ast_if_test_cache:
                AstFlattener.ast_if_test_cache[child] = copy.deepcopy(child.test)

            stmt = child.value if isinstance(child, (ast.Assign, ast.Expr)) else child
            results = self._flatten_statement(stmt, target_names)
            if results:
                for result in reversed(results):
                    ast_body.insert(index, result)
                    # Keep `index` pointing at `child`; the decrement below then moves on to the
                    # last inserted assignment.
                    index += 1
            index -= 1

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:  # pylint: disable=invalid-name
        """Flatten the body of `node` if its name is in _transform_functions; other functions are untouched."""
        if node.name not in self._transform_functions:
            return node
        self._visit_ast_bodies(node.body)
        return node

    def transform(self, ast_root, transform_functions=None, stree=None):
        """
        Interface of AstFlattener.

        Args:
            ast_root (ast.AST): Root of the tree to transform.
            transform_functions (list[str]): Names of the functions to flatten. Defaults to ["construct"].
            stree: Optional symbol tree whose `unique_name` is used to make generated names unique.

        Returns:
            The transformed ast root, with missing locations fixed.
        """
        self._transform_functions = transform_functions if transform_functions else ["construct"]
        self._symbol_tree = stree
        ast_root = self.visit(ast_root)
        ast_root = ast.fix_missing_locations(ast_root)
        return ast_root

    def transform_control_flow(self, ast_control_flow: Union[ast.If, ast.For, ast.While], stree=None):
        """
        Interface of AstFlattener for a single control flow node.

        Flattens the statements of `body` and, if present, of `orelse`.

        Returns:
            The same control flow node, with missing locations fixed.
        """
        self._transform_functions = []
        self._symbol_tree = stree
        self._visit_ast_bodies(ast_control_flow.body)
        if ast_control_flow.orelse:
            self._visit_ast_bodies(ast_control_flow.orelse)
        ast_control_flow = ast.fix_missing_locations(ast_control_flow)
        return ast_control_flow