1# Copyright 2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Parse ast.Assign in construct function to node of SymbolTree.""" 16from typing import Union, List, Dict 17import types 18import os 19import ast 20import sys 21import inspect 22import builtins 23from textwrap import dedent 24 25from mindspore import log as logger 26from mindspore.nn import Cell, SequentialCell, CellList 27from mindspore.ops.primitive import Primitive 28import mindspore.ops.functional as F 29from . import Parser, ParserRegister, reg_parser 30from ..symbol_tree import SymbolTree 31from ..node import Node, TreeNode, NodeManager, CallFunction, CellContainer, ControlFlow, LocalPrim 32from ..api.scoped_value import ScopedValue 33from ..ast_helpers import AstFlattener, AstConverter, AstFinder 34from ..common.error_log import error_str 35from ..common.namespace import is_subtree, is_ms_function, is_third_party 36from ..common.namer import FunctionNamer 37 38 39if sys.version_info >= (3, 9): 40 import ast as astunparse # pylint: disable=reimported, ungrouped-imports 41else: 42 import astunparse 43 44 45class AssignParser(Parser): 46 """Parse ast.Assign in construct function to node of SymbolTree.""" 47 48 # Types for creating Cell Container node 49 types_for_cell_container = [SequentialCell,] 50 # If mindspore built-in function to be parsered or skipped 51 _skip_ms_function = False 52 # Functions in black list will not be parsed 53 _function_parse_black_list = [F.arange] 54 # Share one implementation for the same instances 55 _share_one_implementation = False 56 # Implementation caches of sub SymbolTrees, CallFunction nodes and CellContainer nodes 57 # Keys are ids of the instance object 58 _cached_trees: Dict[int, SymbolTree] = {} 59 _cached_functions: Dict[int, Node] = {} 60 _cached_cell_containers: Dict[int, Node] = {} 61 62 def __init__(self): 63 super().__init__() 64 self._variables_cache = [] 65 self.stree: SymbolTree = None 66 self.ast_assign: ast.Assign = None 67 self.node_manager: NodeManager = None 68 self.targets: List[ScopedValue] = None 69 self.args: List[ScopedValue] = None 70 self.kwargs: Dict[str, ScopedValue] = None 71 72 @staticmethod 73 def _get_func_name(ast_call: ast.Call) -> str: 74 """ 75 Get the func name from ast.Call. 76 77 Args: 78 ast_call (ast.Call): Input ast.Call node. 79 80 Returns: 81 Func name. 82 """ 83 func = ast_call.func 84 if isinstance(func, ast.Name): 85 return func.id 86 if isinstance(func, ast.Attribute): 87 return func.attr 88 func_full_name = astunparse.unparse(func).strip() 89 if func_full_name.count('.') > 0: 90 return func_full_name.split('.')[-1] 91 return func_full_name 92 93 @staticmethod 94 def _get_func_scope(ast_call: ast.Call) -> str: 95 """ 96 Get the func scope from ast.Call. 97 98 Args: 99 ast_call (ast.Call): Input ast.Call node. 100 101 Returns: 102 Func scope. 103 """ 104 func = ast_call.func 105 if isinstance(func, ast.Name): 106 return "" 107 func_full_name = astunparse.unparse(func).strip() 108 if func_full_name.count('.') > 0: 109 return func_full_name.rsplit('.', 1)[0] 110 return "" 111 112 @staticmethod 113 def _create_targets(ast_target: ast.AST) -> List[ScopedValue]: 114 """Get targets from ast node.""" 115 ast_target_elems = AstConverter.get_ast_target_elems(ast_target) 116 targets = [AstConverter.create_scopedvalue(ast_node) for ast_node in ast_target_elems] 117 return targets 118 119 @staticmethod 120 def _create_kwargs(keywords: [ast.keyword]) -> Dict[str, ScopedValue]: 121 """ 122 Transfer ast.Call keywords to a dict of ScopedValue when creating a symbol tree node. 123 124 Args: 125 keywords ([ast.keyword]): Keywords of ast.Call node. 126 127 Returns: 128 A dict of ScopedValue. 129 """ 130 results = {} 131 for keyword in keywords: 132 results[keyword.arg] = AstConverter.create_scopedvalue(keyword.value) 133 return results 134 135 136 @staticmethod 137 def _get_inst_and_name(ast_node: ast.Attribute, stree: SymbolTree): 138 """ 139 Try to get instance object of ast_node from ast.Attribute. 140 """ 141 if not isinstance(ast_node, ast.Attribute): 142 return None, "" 143 scope_name = astunparse.unparse(ast_node).strip() 144 scope, name = scope_name.split('.', 1) 145 if scope != 'self': 146 return None, scope_name 147 if not hasattr(stree.get_origin_network(), name): 148 return None, scope_name 149 return getattr(stree.get_origin_network(), name), scope_name 150 151 @staticmethod 152 def _list_of_cells(cell_list: list): 153 """Check if elements in the list are all cells.""" 154 for item in cell_list: 155 if not isinstance(item, Cell): 156 return False 157 return True 158 159 @staticmethod 160 def _get_path_of_node_manager(node_manager: NodeManager): 161 """Get file path of type(instance) in NodeManager""" 162 node_manager = node_manager.get_top_manager() 163 if isinstance(node_manager, SymbolTree): 164 return inspect.getfile(type(node_manager.get_origin_network())) 165 return inspect.getfile(node_manager.get_instance()) 166 167 @staticmethod 168 def _get_module_of_node_manager(node_manager: NodeManager): 169 """Get module where the node manager is located""" 170 # get module where function object is used 171 func_path = AssignParser._get_path_of_node_manager(node_manager) 172 func_path = os.path.normcase(os.path.normpath(func_path)) 173 modules = list(sys.modules.values()) 174 for m in modules: 175 if hasattr(m, "__file__") and m.__file__ is not None and func_path == os.path.normcase(m.__file__): 176 return m, func_path 177 return None, func_path 178 179 @staticmethod 180 def _get_object_from_module(func_full_name: str, module: types.ModuleType): 181 """Get object from module according to full name of function""" 182 names = func_full_name.split('.') 183 obj = module 184 for attr in names: 185 if not hasattr(obj, attr): 186 logger.info(f"For '{func_full_name}', failed to get attr '{attr}' from '{obj}'") 187 return None 188 obj = getattr(obj, attr) 189 return obj 190 191 @staticmethod 192 def _get_local_var_provider(node_manager: NodeManager, var: str) -> Node: 193 """Get the node providing specific variable""" 194 node = node_manager.get_tail() 195 while node is not None: 196 if var in [str(target) for target in node.get_targets()]: 197 return node 198 node = node.get_prev() 199 # When node_manager is control flow, nodes in upper node_manager need to be traversed. 200 if isinstance(node_manager, ControlFlow): 201 return AssignParser._get_local_var_provider(node_manager.get_node_manager(), var) 202 return None 203 204 def target(self): 205 """Parse target type.""" 206 return ast.Assign 207 208 def store_env(self): 209 """Store current environments""" 210 self._variables_cache.append( 211 [self.stree, self.ast_assign, self.node_manager, self.targets, self.args, self.kwargs]) 212 self.stree = None 213 self.ast_assign = None 214 self.node_manager = None 215 self.targets = None 216 self.args = None 217 self.kwargs = None 218 219 def restore_env(self): 220 """Restore last environments""" 221 self.stree, self.ast_assign, self.node_manager, self.targets, self.args, self.kwargs = \ 222 self._variables_cache.pop() 223 224 def _get_cell_instance(self, func_scope, func_name): 225 """ 226 Get object instance from ast.Call with type of Cell. 227 228 Args: 229 func_scope (str): Func scope. 230 func_name (str): Func name. 231 232 Returns: 233 An instance represents operator instance. 234 """ 235 if func_scope != "self": 236 return None 237 var_dict = self.stree.get_origin_network().__dict__ 238 # Instance is of type Cell 239 for key, value in var_dict["_cells"].items(): 240 if key == func_name: 241 return value 242 # Instance is of other type. 243 return None 244 245 def _get_primitive_instance(self, func_scope, func_name): 246 """ 247 Get object instance from ast.Call with type of Primitive. 248 249 Args: 250 func_scope (str): Func scope. 251 func_name (str): Func name. 252 253 Returns: 254 An instance represents operator instance. 255 """ 256 if func_scope != "self": 257 return None 258 var_dict = self.stree.get_origin_network().__dict__ 259 # Instance is of type Primitive 260 for key, value in var_dict["_primitives"].items(): 261 if key == func_name: 262 return value 263 # Instance is of other type. 264 return None 265 266 def _get_method_object(self, func_scope, func_name): 267 """Get method object from network instance.""" 268 stree = self.stree 269 if func_scope in ('self', stree.get_opt_cls_name()) and hasattr(stree.get_origin_network(), func_name): 270 return getattr(stree.get_origin_network(), func_name) 271 return None 272 273 def _get_local_variable(self, func_scope, func_name) -> (bool, object): 274 """ 275 Get local variable 276 277 Args: 278 func_scope (str): Func scope. 279 func_name (str): Func name. 280 281 Returns: 282 bool: Indicate whether local variable is found. 283 object (Union[LocalPrim, type]): Instance of LocalPrim when calling the class, or class type 284 object when initializing the class. 285 """ 286 func_full_name = f"{func_scope}.{func_name}" if func_scope else func_name 287 # try to find func_name in class variables initializing the primitive during forward method 288 provider_node = None 289 if func_scope == "self": 290 for node in self.stree.local_prim_inits(): 291 if func_full_name in [str(target) for target in node.get_targets()]: 292 provider_node = node 293 # try to find func_name in local variables 294 if provider_node is None: 295 provider_node = AssignParser._get_local_var_provider(self.node_manager, func_full_name) 296 if provider_node: 297 # when the node providering the local variable initialized a primitive during forward method, 298 # we use LocalPrim to indicate the instance of this primitive. e.g. : 299 # abs_inst = P.Abs() -> 'abs_inst' is an instance of primitive initialized locally 300 # y = abs_inst(x) -> here we are parsing now 301 cls_init = provider_node.get_init_cls() 302 if cls_init and inspect.isclass(cls_init) and issubclass(cls_init, Primitive): 303 return True, LocalPrim(cls_init) 304 # when the node providering the local variable represent a primitive type object, we return 305 # type-object to indicate that we are initializing this primitive. e.g. : 306 # abs_ops = _get_cache_prim(P.Abs) -> 'abs_ops' is an primitive type object 307 # y = abs_ops(x) -> here we are parsing now 308 cls_type = provider_node.get_type_cls() 309 if cls_type and inspect.isclass(cls_type) and issubclass(cls_type, Primitive): 310 return True, cls_type 311 # local variable whose type is not primitive instance 312 logger.info(f"Ignore local variable: {func_full_name}") 313 return True, None 314 # other local variable 315 if AssignParser._get_local_var_provider(self.node_manager, func_full_name.split('.')[0]): 316 logger.info(f"Ignore local variable: {func_full_name}") 317 return True, None 318 return False, None 319 320 def _get_function_object(self, func_scope, func_name, ast_call) -> (object, bool): 321 """ 322 Get function object from module. 323 324 If the code represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs), 325 return primitive type object with class type flag True. 326 327 if the code represent an initializtion of a class, e.g. abs_inst = P.Abs(), 328 return primitive type object with class type flag False. 329 330 if the code represent the call of function or class instance, e.g. y = abs_inst(x)/func(x), 331 return primitive instance or function object with class type flag False. 332 333 Args: 334 func_scope (str): Func scope. 335 func_name (str): Func name. 336 ast_call (ast.Call): ast.Call of ast.Assign. 337 338 Returns: 339 object: Class type object, class instance or function object 340 bool: Flag indicate is node represent a class type object. 341 """ 342 func_full_name = f"{func_scope}.{func_name}" if func_scope else func_name 343 # get module where function object is used 344 module, func_path = AssignParser._get_module_of_node_manager(self.node_manager) 345 if module is None: 346 logger.debug(f"When getting object of '{func_full_name}', failed to find module in '{func_path}'") 347 return None, False 348 # if name of function is _get_cache_prim, return primitive type object 349 is_cls_type_obj = False 350 if func_full_name == '_get_cache_prim': 351 func_full_name = astunparse.unparse(ast_call.args[0]).strip() 352 is_cls_type_obj = True 353 # find object in module 354 obj = AssignParser._get_object_from_module(func_full_name, module) 355 return obj, is_cls_type_obj 356 357 def _update_field_in_init(self, func_name: str, sub_tree: SymbolTree) -> bool: 358 """ 359 When node is an invoking to sub-network, update value of ast.Assign of corresponding field in `__init__` method. 360 Add the code like: `self.field = SubNetwork(self.field)` 361 362 Args: 363 func_name (str): A string represents scope and name of function symbol. 364 sub_tree (SymbolTree): The SymbolTree corresponding to sub-network. 365 """ 366 init_func_ast = self.stree.get_init_func_ast() 367 sub_net_obj = sub_tree.get_origin_network() 368 sub_net_opt_name = sub_tree.get_opt_cls_name() 369 # Add .to_float(mindspore.float16) if origin subnet has this attribute 370 new_code = f"{func_name} = {sub_net_opt_name}({func_name})" 371 if hasattr(sub_net_obj, "fp16") and sub_net_obj.fp16: 372 new_code = f"{new_code}.to_float(mindspore.float16)" 373 elif hasattr(sub_net_obj, "bf16") and sub_net_obj.bf16: 374 new_code = f"{new_code}.to_float(mindspore.bfloat16)" 375 new_ast = ast.parse(new_code).body[0] 376 init_func_ast.body.append(new_ast) 377 378 def _update_cell_container_in_init(self, container_name, container_idx, subnet_opt_name): 379 """ 380 When nn.SequentialCell include sub-symboltree, the new class definition will be used to create object. 381 So the assign code will be got from origin code first, and then be modified to new class name. 382 383 Codes like: 384 385 `self.container = nn.SequentialCell([ReLU(), MyNet()])` 386 387 will be updated by add codes: 388 389 `self.container[1] = MyNetOpt(self.container[1])` 390 391 """ 392 new_code = f"{container_name}[{container_idx}] = {subnet_opt_name}({container_name}[{container_idx}])" 393 new_ast = ast.parse(new_code).body[0] 394 self.stree.get_init_func_ast().body.append(new_ast) 395 396 def _add_import(self, import_name: str): 397 """ add import to current node manager.""" 398 module, _ = AssignParser._get_module_of_node_manager(self.node_manager) 399 if module is None: 400 logger.info(f"Cannot get module where '{import_name}' is located, ignore import info") 401 return 402 node_manager = self.node_manager.get_top_manager() 403 belonging_ast = None if isinstance(node_manager, SymbolTree) else node_manager.get_manager_ast() 404 self.stree.add_import(module, import_name, belonging_ast) 405 406 def cell_container_process(self, func_name: str, node_name: str, container_obj: object): 407 """ parse cell container object.""" 408 # create unparsable node if container is already parsed when sharing one implementation 409 if AssignParser._share_one_implementation and id(container_obj) in AssignParser._cached_cell_containers: 410 cell_container = Node.create_call_buildin_op(container_obj, self.ast_assign, self.targets, 411 func_name, self.args, self.kwargs, node_name) 412 return cell_container 413 cell_container = CellContainer(self.ast_assign, self.targets, func_name, self.args, self.kwargs, 414 node_name, self.stree, container_obj) 415 for i, cell in enumerate(container_obj): 416 cell_name = type(cell).__name__ 417 # The type of cell is container of cells (e.g. SequentialCell) 418 if isinstance(cell, tuple(AssignParser.types_for_cell_container)): 419 sub_node = self.cell_container_process(f"{func_name}[{i}]", cell_name, cell) 420 elif is_subtree(cell): 421 # create unparsable node if tree node is already parsed when sharing one implementation 422 if AssignParser._share_one_implementation and id(cell) in AssignParser._cached_trees: 423 first_stree = AssignParser._cached_trees.get(id(cell)) 424 self._update_cell_container_in_init(func_name, i, first_stree.get_opt_cls_name()) 425 sub_node = Node.create_call_buildin_op(cell, None, self.targets, cell_name, self.args, 426 self.kwargs, cell_name) 427 else: 428 from ..symbol_tree import SymbolTreeBuilder 429 stb = SymbolTreeBuilder(cell) 430 new_stree = stb.build() 431 sub_node = TreeNode.create_tree_node(new_stree, None, self.targets, cell_name, self.args, 432 self.kwargs, cell_name, cell) 433 self._update_cell_container_in_init(func_name, i, new_stree.get_opt_cls_name()) 434 # save symbol tree if it is firstly parsed when sharing one implementation 435 if AssignParser._share_one_implementation: 436 AssignParser._cached_trees[id(cell)] = new_stree 437 else: 438 sub_node = Node.create_call_buildin_op(cell, None, self.targets, cell_name, self.args, 439 self.kwargs, cell_name) 440 # add sub node to cell_container 441 cell_container.append(sub_node, False) 442 # save the node if container is firstly parsed when sharing one implementation 443 if AssignParser._share_one_implementation: 444 AssignParser._cached_cell_containers[id(container_obj)] = cell_container 445 return cell_container 446 447 def process_cell(self, func_scope_name: ScopedValue, node_name: str, cell_inst: Cell): 448 """Create CallCell node with instance of cell.""" 449 # The type of cell is container of cells (e.g. SequentialCell) 450 if isinstance(cell_inst, tuple(AssignParser.types_for_cell_container)): 451 node = self.cell_container_process(func_scope_name, node_name, cell_inst) 452 # The type of cell is user custom network, then we create sub-symboltree 453 elif is_subtree(cell_inst): 454 # create unparsable node if tree node is already parsed when sharing one implementation 455 if AssignParser._share_one_implementation and id(cell_inst) in AssignParser._cached_trees: 456 first_stree = AssignParser._cached_trees.get(id(cell_inst)) 457 self._update_field_in_init(str(func_scope_name), first_stree) 458 node = Node.create_call_buildin_op(cell_inst, self.ast_assign, self.targets, func_scope_name, 459 self.args, self.kwargs, node_name) 460 else: 461 from ..symbol_tree import SymbolTreeBuilder 462 stb = SymbolTreeBuilder(cell_inst) 463 new_stree = stb.build() 464 self._update_field_in_init(str(func_scope_name), new_stree) 465 node = TreeNode.create_tree_node(new_stree, self.ast_assign, self.targets, func_scope_name, 466 self.args, self.kwargs, node_name, new_stree.get_origin_network()) 467 # save symbol tree if it is firstly parsed when sharing one implementation 468 if AssignParser._share_one_implementation: 469 AssignParser._cached_trees[id(cell_inst)] = new_stree 470 else: 471 # The type of cell is built-in cells 472 node = Node.create_call_buildin_op(cell_inst, self.ast_assign, self.targets, func_scope_name, self.args, 473 self.kwargs, node_name) 474 self.stree.append_origin_field(node, self.node_manager) 475 476 def process_primitive(self, func_scope_name: ScopedValue, node_name: str, primitive_inst: Primitive): 477 """Create CallPrimitive node with instance of primitive.""" 478 node = Node.create_call_buildin_op(primitive_inst, self.ast_assign, self.targets, func_scope_name, 479 self.args, self.kwargs, node_name) 480 self.stree.append_origin_field(node, self.node_manager) 481 482 def process_class_method(self, func_scope_name: ScopedValue, node_name: str, method_object: object): 483 """Create CallFunction node for class method function.""" 484 func_name = func_scope_name.value 485 # get ast.FunctionDef 486 ast_functiondef = None 487 for body in self.stree.get_class_ast().body: 488 if isinstance(body, ast.FunctionDef) and func_name == body.name: 489 ast_functiondef = body 490 if ast_functiondef is None: 491 # method of child class may be called and will be ignored now. 492 logger.info(error_str(f"Find ast of function '{func_name}' in network '{self.stree.get_ori_cls_name()}' " 493 f"failed", child_node=self.ast_assign)) 494 self.insert_callfunction_node(func_scope_name, node_name, None, None, False) 495 else: 496 # create CallFunction node 497 self.insert_callfunction_node(func_scope_name, node_name, ast_functiondef, method_object, True) 498 499 def process_function(self, func_scope_name: ScopedValue, node_name: str, function_object: object, 500 is_cls_type_obj: bool): 501 """Create node for function.""" 502 # Ignore functions in _function_parse_black_list 503 if function_object in AssignParser._function_parse_black_list: 504 logger.debug(f"'{func_scope_name}' is in the _function_parse_black_list and will not be parsed") 505 if not func_scope_name.scope: 506 self._add_import(func_scope_name.value) 507 self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False) 508 return 509 # break loop function 510 node_manager = self.node_manager 511 while node_manager and isinstance(node_manager, Node): 512 if isinstance(node_manager, CallFunction) and node_manager.get_instance() == function_object: 513 logger.info(f"loop function detected in '{func_scope_name}', stop parsing function.") 514 self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False) 515 return 516 node_manager = node_manager.get_node_manager() 517 # process primitive instances: 518 # (global/local) _ops_func = P.FUNC() 519 # (here) y = _ops_func(x) <- (process: _ops_func) 520 if isinstance(function_object, Primitive): 521 # when primitive instance is not a local variable, it will be a global object which need to be imported 522 if not isinstance(function_object, LocalPrim): 523 import_name = str(func_scope_name).split('.')[0] 524 self._add_import(import_name) 525 # create CallPrimitive node 526 self.process_primitive(func_scope_name, func_scope_name.value, function_object) 527 return 528 # process primitive object: 529 # (here) _ops_func = P.FUNC() <- (process: P.FUNC) 530 # (later) y = _ops_func(x) 531 if inspect.isclass(function_object): 532 node = self.insert_callfunction_node(func_scope_name, node_name, None, None, False) 533 if is_cls_type_obj: 534 # represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs) 535 node.set_type_cls(function_object) 536 # add import 537 if str(func_scope_name) == '_get_cache_prim': 538 import_name = astunparse.unparse(self.ast_assign.value.args[0]).strip() 539 if '.' not in import_name: 540 self._add_import(import_name) 541 else: 542 # represent the initialize of a class type, e.g. abs_inst = P.Abs() 543 node.set_init_cls(function_object) 544 # record local primitive objects 545 if func_scope_name.scope == 'self' and issubclass(function_object, Primitive): 546 self.stree.local_prim_inits.append(node) 547 return 548 # process third party functions 549 is_ms_func = is_ms_function(function_object) 550 if not is_ms_func and is_third_party(function_object): 551 logger.info(f"Ignore third party function '{func_scope_name}'.") 552 self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False) 553 return 554 # process mindspore functions 555 if is_ms_func and AssignParser._skip_ms_function: 556 logger.info(f"Ignore mindspore function '{func_scope_name}'.") 557 self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False) 558 return 559 # get ast.FunctionDef 560 source_code = inspect.getsource(function_object) 561 ast_functiondef = ast.parse(dedent(source_code)).body[0] 562 if not isinstance(ast_functiondef, ast.FunctionDef): 563 logger.info(error_str(f"Get ast.FunctionDef of function {str(func_scope_name)} failed, the type of " 564 f"ast node is {type(ast_functiondef)}", child_node=self.ast_assign)) 565 self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False) 566 return 567 if [n for n in ast_functiondef.body if isinstance(n, ast.FunctionDef)]: 568 logger.info(error_str(f"closure syntax is not supported now, {str(func_scope_name)} will not be parsed.", 569 child_node=ast_functiondef)) 570 if not func_scope_name.scope: 571 self._add_import(func_scope_name.value) 572 self.insert_callfunction_node(func_scope_name, node_name, None, function_object, False) 573 return 574 # update func_name, and remove scope 575 new_name = ast_functiondef.name 576 # when func_scope_name(e.g. 'C.uniform') is not the name in ast.FunctionDef(e.g. 'uniform'), this name may be 577 # already used as variable(e.g. uniform = C.uniform(x)). 578 # To avoid new function's name being duplicated with existed variable, an suffix '_opt' will be added. 579 if new_name != str(func_scope_name): 580 new_name = f"{new_name}_opt" 581 new_name = FunctionNamer().instance().get_name(new_name) 582 # create unparsable node if function is already parsed when sharing one implementation 583 if AssignParser._share_one_implementation and id(function_object) in AssignParser._cached_functions: 584 first_node = AssignParser._cached_functions.get(id(function_object)) 585 ast_call: ast.Call = self.ast_assign.value 586 ast_call.func = ast.Name(id=str(first_node.get_func_name()), ctx=ast.Load()) 587 self.insert_callfunction_node(func_scope_name, new_name, None, function_object, False) 588 return 589 ast_functiondef.name = new_name 590 ast_call: ast.Call = self.ast_assign.value 591 ast_call.func = ast.Name(id=new_name, ctx=ast.Load()) 592 # save ast.FunctionDef into stree._external_ast 593 self.stree.get_external_ast()[ast_functiondef] = [] 594 # import module which function defined in 595 func_file_path = inspect.getabsfile(function_object) 596 self.stree.save_imports_from_file(func_file_path, ast_functiondef) 597 # create CallFunction node 598 func_scope_name = ScopedValue.create_naming_value(new_name, "") 599 node = self.insert_callfunction_node(func_scope_name, new_name, ast_functiondef, function_object, False) 600 # save function node if it is firstly parsed when sharing one implementation 601 if AssignParser._share_one_implementation: 602 AssignParser._cached_functions[id(function_object)] = node 603 604 def insert_callfunction_node(self, func_name: ScopedValue, node_name: str, ast_functiondef: ast.FunctionDef, 605 func_obj: object, is_method: bool) -> Node: 606 """Create CallFunction node for function.""" 607 if ast_functiondef is None: 608 node = Node.inner_create_call_function(node_name, self.ast_assign, func_name, func_obj, 609 self.targets, self.args, self.kwargs) 610 self.stree.append_origin_field(node, self.node_manager) 611 return node 612 # create CallFunction node 613 node = CallFunction(self.targets, func_name, self.args, self.kwargs, node_name, self.ast_assign, 614 ast_functiondef, self.stree, func_obj, is_method) 615 self.stree.append_origin_field(node, self.node_manager) 616 # expand ast codes 617 ast_functiondef = AstFlattener().transform(ast_functiondef, [func_name.value], self.stree) 618 # parse ast codes into CallFunction Node 619 parser = ParserRegister.instance().get_parser(ast.FunctionDef) 620 parser.process(self.stree, ast_functiondef, node_manager=node) 621 return node 622 623 def process_ast_call(self, ast_call: ast.Call): 624 """ 625 Convert ast.Call to a symbol tree node. 626 627 Args: 628 ast_call (ast.Call): An ast.Call of assign node in construct. 629 """ 630 self.targets = AssignParser._create_targets(self.ast_assign.targets[0]) 631 self.args = [AstConverter.create_scopedvalue(arg) for arg in ast_call.args] 632 self.kwargs = AssignParser._create_kwargs(ast_call.keywords) 633 func_name = AssignParser._get_func_name(ast_call) 634 func_scope = AssignParser._get_func_scope(ast_call) 635 func_scope_name = ScopedValue.create_naming_value(func_name, func_scope) 636 func_full_name = str(func_scope_name) 637 # y = func(xxx)(xxx) / y = func1(xxx).func2(xxx) is not supported, and should be flattened before parsing. 638 if AstFinder(ast_call.func).find_all(ast.Call): 639 logger.info(error_str("ast.Call in func name of ast.Call is not supported.", ast_call, self.ast_assign)) 640 self.insert_callfunction_node(func_scope_name, func_name, None, None, False) 641 return 642 # Ignore built-in functions 643 if func_full_name in dir(builtins): 644 logger.info(f"Ignore built-in function: {func_scope_name}") 645 self.insert_callfunction_node(func_scope_name, func_name, None, None, False) 646 return 647 # Ignore function name is target of for loop 648 if isinstance(self.node_manager, ControlFlow) and func_full_name in self.node_manager.loop_vars: 649 logger.info(f"Ignore function of loop variable: {func_scope_name}") 650 self.insert_callfunction_node(func_scope_name, func_name, None, None, False) 651 return 652 # Instance with type of Cell 653 cell_inst = self._get_cell_instance(func_scope, func_name) 654 if cell_inst is not None: 655 self.process_cell(func_scope_name, func_name, cell_inst) 656 return 657 # Instance with type of Primitive 658 primitive_inst = self._get_primitive_instance(func_scope, func_name) 659 if primitive_inst is not None: 660 self.process_primitive(func_scope_name, func_name, primitive_inst) 661 return 662 # Class method object 663 method_object = self._get_method_object(func_scope, func_name) 664 if method_object is not None: 665 if inspect.ismethod(method_object): 666 self.process_class_method(func_scope_name, func_name, method_object) 667 elif isinstance(inspect.getattr_static(self.stree.get_origin_network(), func_name), staticmethod): 668 self.insert_callfunction_node(func_scope_name, func_name, None, None, False) 669 else: 670 self.process_function(func_scope_name, func_name, method_object, False) 671 return 672 # Local variable 673 is_local_var, primitive_obj = self._get_local_variable(func_scope, func_name) 674 if primitive_obj is not None: 675 self.process_function(func_scope_name, func_name, primitive_obj, False) 676 return 677 if is_local_var: 678 # for a variable whose type is not primitive instance, create normal node for it 679 self.insert_callfunction_node(func_scope_name, func_name, None, None, False) 680 return 681 # Function object 682 function_object, is_cls_type_obj = self._get_function_object(func_scope, func_name, ast_call) 683 if function_object is not None: 684 self.process_function(func_scope_name, func_name, function_object, is_cls_type_obj) 685 return 686 logger.info(error_str("Failed to get instance or object of ast.Call.", ast_call, self.ast_assign)) 687 self.insert_callfunction_node(func_scope_name, func_name, None, None, False) 688 689 def process_ast_mathops(self, ast_op: Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): 690 """ 691 Convert ast node of math operations(ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare) to 692 a symbol tree node. 693 694 Args: 695 ast_op (Union[ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare]): An assign node with mathematival 696 operation in construct function. 697 698 Raises: 699 TypeError: The type of parameter 'ast_op' is not in (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare). 700 701 """ 702 if not isinstance(ast_op, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)): 703 raise TypeError("The type of parameter 'ast_op' must be one of (ast.BinOp, ast.UnaryOp, " 704 "ast.BoolOp, ast.Compare), but got ", type(ast_op)) 705 706 targets = AssignParser._create_targets(self.ast_assign.targets[0]) 707 args = [] 708 op_type_str = type(ast_op).__name__ 709 op_type = ScopedValue.create_naming_value(op_type_str) 710 name = op_type_str 711 if isinstance(ast_op, ast.BinOp): 712 op = type(ast_op.op).__name__ 713 name = f'{name}_{op}' 714 args.append(AstConverter.create_scopedvalue(ast_op.left)) 715 args.append(AstConverter.create_scopedvalue(ast_op.right)) 716 elif isinstance(ast_op, ast.UnaryOp): 717 op = type(ast_op.op).__name__ 718 name = f'{name}_{op}' 719 args.append(AstConverter.create_scopedvalue(ast_op.operand)) 720 elif isinstance(ast_op, ast.BoolOp): 721 op = type(ast_op.op).__name__ 722 name = f'{name}_{op}' 723 for value in ast_op.values: 724 args.append(AstConverter.create_scopedvalue(value)) 725 elif isinstance(ast_op, ast.Compare): 726 args.append(AstConverter.create_scopedvalue(ast_op.left)) 727 for idx, ast_cmp_op in enumerate(ast_op.ops): 728 op = type(ast_cmp_op).__name__ 729 name = f'{name}_{op}' 730 args.append(AstConverter.create_scopedvalue(ast_op.comparators[idx])) 731 name = name.lower() 732 node = Node.create_mathops_node(self.ast_assign, targets, op_type, args, name) 733 self.stree.append_origin_field(node, self.node_manager) 734 735 def process_ast_constant(self, ast_constant: Union[ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str]): 736 """ 737 Convert ast node of constant types (ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str) to 738 a symbol tree node. 739 """ 740 node_name = f"{type(ast_constant).__name__.lower()}_assign" 741 targets = AssignParser._create_targets(self.ast_assign.targets[0]) 742 args = [AstConverter.create_scopedvalue(ast_constant)] 743 node = Node.create_call_method(self.ast_assign, targets, "pass_through", args, {}, node_name) 744 self.stree.append_origin_field(node, self.node_manager) 745 746 def process_ast_name(self, ast_node: Union[ast.Name, ast.Attribute]): 747 """ 748 Convert ast node of ast.Name and ast.Attribute to a symbol tree node. 749 """ 750 self.targets = AssignParser._create_targets(self.ast_assign.targets[0]) 751 inst, scope_name = AssignParser._get_inst_and_name(ast_node, self.stree) 752 if inst is not None and (isinstance(inst, CellList) or 753 isinstance(inst, list) and AssignParser._list_of_cells(inst)): 754 node = self.cell_container_process(scope_name, scope_name, inst) 755 else: 756 node_name = f"{type(ast_node).__name__.lower()}_assign" 757 args = [AstConverter.create_scopedvalue(ast_node)] 758 node = Node.create_call_method(self.ast_assign, self.targets, "pass_through", args, {}, node_name) 759 self.stree.append_origin_field(node, self.node_manager) 760 761 def process_ast_tuple(self, ast_node: Union[ast.Tuple, ast.List]): 762 """ 763 Convert ast node of ast.Tuple or ast.List to a symbol tree node. 764 """ 765 # ensure that each element's type in tuple is supported by scopled value 766 if AstConverter.ast_tuple_elts_support_scopledvalue(ast_node): 767 targets = AssignParser._create_targets(self.ast_assign.targets[0]) 768 args = [] 769 for elt in ast_node.elts: 770 args.append(AstConverter.create_scopedvalue(elt)) 771 func_name = "tuple" if isinstance(ast_node, ast.Tuple) else "list" 772 node = Node.create_call_method(self.ast_assign, targets, func_name, args, {}, func_name) 773 self.stree.append_origin_field(node, self.node_manager) 774 else: 775 logger.info(f"some elements in assign({astunparse.unparse(self.ast_assign)}) are not supported " 776 "in rewrite, fallback to python") 777 self.stree.try_append_python_node(self.ast_assign, self.ast_assign, self.node_manager) 778 779 def process_ast_dict(self, ast_dict: ast.Dict): 780 """ 781 Convert ast node of ast.Dict to a symbol tree node. 782 """ 783 # ensure that each element's type in dict is supported by scopled value 784 if AstConverter.ast_dict_support_scopledvalue(ast_dict): 785 targets = AssignParser._create_targets(self.ast_assign.targets[0]) 786 kwargs = {} 787 for idx, key in enumerate(ast_dict.keys): 788 kwargs[key.value] = AstConverter.create_scopedvalue(ast_dict.values[idx]) 789 func_name = ScopedValue.create_naming_value("dict") 790 node = Node.create_call_method(self.ast_assign, targets, func_name, [], kwargs, "dict") 791 self.stree.append_origin_field(node, self.node_manager) 792 else: 793 logger.info(f"some elements in assign({astunparse.unparse(self.ast_assign)}) are not supported " 794 "in rewrite, fallback to python") 795 self.stree.try_append_python_node(self.ast_assign, self.ast_assign, self.node_manager) 796 797 def process_ast_subscript(self, ast_subscript: ast.Subscript): 798 """ 799 Convert ast node of ast.Subscript to a symbol tree node. 800 """ 801 targets = AssignParser._create_targets(self.ast_assign.targets[0]) 802 args = [AstConverter.create_scopedvalue(ast_subscript)] 803 node = Node.create_call_method(self.ast_assign, targets, "pass_through", args, {}, "subscript_var") 804 self.stree.append_origin_field(node, self.node_manager) 805 806 def process(self, stree: SymbolTree, node: ast.Assign, node_manager: NodeManager): 807 """ 808 Parse ast.Assign and create a node in symbol tree. 809 810 - Create node when value of ast.Assign is in [ast.Call, ast.Name, ast.Constant, ast.Attribute]. 811 - Create python node when value of ast.Assign is in [ast.BinOp, ast.BoolOp, ast.Subscript, ast.List, ast.Tuple, 812 ast.Dict]. 813 - Other value types are not supported. 814 815 Args: 816 stree ([SymbolTree]): Symbol Tree under parsing. 817 node ([ast.Assign]): An ast.Assign node. 818 node_manager (NodeManager): NodeManager those asts belong to. 819 """ 820 if len(node.targets) != 1: 821 logger.info(error_str(f"Continuous assignment statement(e.g. 'a = b = 1') should be flatten before.", 822 child_node=node)) 823 stree.try_append_python_node(node, node, node_manager) 824 return 825 826 self.store_env() 827 self.stree = stree 828 self.ast_assign = node 829 self.node_manager = node_manager 830 value = node.value 831 if isinstance(value, ast.Call): 832 self.process_ast_call(value) 833 elif isinstance(value, (ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare)): 834 self.process_ast_mathops(value) 835 elif isinstance(value, ast.Subscript): 836 self.process_ast_subscript(value) 837 elif isinstance(value, (ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str)): 838 self.process_ast_constant(value) 839 elif isinstance(value, (ast.Name, ast.Attribute)): 840 self.process_ast_name(value) 841 elif isinstance(value, (ast.Tuple, ast.List)): 842 self.process_ast_tuple(value) 843 elif isinstance(value, ast.Dict): 844 self.process_ast_dict(value) 845 else: 846 logger.info(f"ops-call({astunparse.unparse(node).strip()}) in assign will be supported in near feature, " 847 f"ignored as a python node now") 848 stree.try_append_python_node(node, node, node_manager) 849 self.restore_env() 850 851 852g_assign_parser = reg_parser(AssignParser()) 853