1# Copyright 2022 Huawei Technologies Co., Ltd 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================ 15"""Node class define of Rewrite. See detail in Node class docstring.""" 16from typing import Optional, Union, List, Dict 17import ast 18import inspect 19from types import FunctionType 20import sys 21 22from mindspore.nn import Cell 23from mindspore.ops import Primitive 24from mindspore import log as logger 25from ..api.scoped_value import ScopedValue, ValueType 26from ..api.node_type import NodeType 27from ..common.namespace import is_subtree 28from ..common.error_log import error_str 29from ..ast_helpers import AstModifier, AstReplacer, AstConverter 30from ... import _checkparam as Validator 31 32 33if sys.version_info >= (3, 9): 34 import ast as astunparse # pylint: disable=reimported, ungrouped-imports 35else: 36 import astunparse 37 38 39class LocalPrim(Primitive): 40 """This class is used to indicate a local primitive instance""" 41 def __init__(self, prim_obj: type): 42 super().__init__("rewrite_local_prim") 43 self.prim_obj = prim_obj 44 45 46class Node: 47 """ 48 Node is a data structure represents a source code line in network. For the most part, Node represents an operator 49 invoking in forward which could be an instance of Cell, an instance of Primitive or a callable method. Fields of 50 Node has different meaning in different type of node: 51 52 - CallCell: a call-cell node represents an assign statement whose value is a calling to cell in mindspore. 53 `targets` is corresponding to targets of ast.Assign which means return values of this cell-op. `args` and 54 `kwargs` are corresponding to args and keywords of ast.Call which mean arguments to invoke cell-op's forward 55 method. `func` is corresponding to func of call expression which means symbol of the cell-op. 56 - CallPrimitive: a call-primitive node represents an ast.Assign whose value is a calling to operator in mindspore. 57 `targets`, `args`, `kwargs` and `func_name` are as previous. 58 - CallMethod: a call-method node represents an ast.Assign whose value is a calling to python-method such as `len`. 59 `targets` is corresponding to targets of ast.Assign which means return values of this method. `func_name` 60 represents the string name of method. `args` and `kwargs` are corresponding to args and keywords to invoke the 61 method. When value of ast.Assign is an ast.Name or ast.Attribute, it means a simplest assign which would also be 62 mapped to CallMethod node whose `func_name` is "PassThrough". 63 - Python: a python node holds an ast-node which is not parsed. a python node means some python statement is not 64 supported by Rewrite or ignored by Rewrite. `targets`, `args`, `kwargs` and `func_name` are don't-care. 65 - Input: an input node represents an input of current network which also a parameter of forward method of Cell. 66 `targets` is corresponding to arg-name of parameter of forward function. `args` means default-value of parameter 67 of forward function. `kwargs` and `func_name` are don't-care. 68 - Output: an output node represents the output of current network which is corresponding to return statement of 69 forward method of Cell. `args` represents return values. `func_name` are always be "return". `targets` and 70 `kwargs` are don't-care. 71 - Tree: a tree node represents a sub-network call in current network. A sub-network is also a Cell in mindspore, so 72 `targets`, `args`, `kwargs` and `func_name` are same as a call-cell node. `symbol_tree` is a handler of a 73 SymbolTree instance. 74 """ 75 76 def __init__(self, node_type: NodeType, ast_node: Optional[ast.AST], targets: [ScopedValue], 77 func_name: Optional[ScopedValue], args: List[ScopedValue], kwargs: Dict[str, ScopedValue], name: str, 78 instance): 79 """ 80 Constructor of Node. Rewrite recommend invoking class method of Node to instantiate an instance of Node such 81 as `create_call_op`, `create_call_method`, `create_python_node`, `create_input_node` and 82 `create_output_node`, etc. rather than invoking constructor of Node directly. 83 84 Args: 85 node_type (NodeType): A NodeType as type of Node. 86 ast_node (ast.AST, optional): An instance of ast.AST represents corresponding node in ast. `ast_node` should 87 not be None except when node type is Unknown. 88 targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 89 func_name (ScopedValue, optional): An instance of ScopedValue. See detail in docstring of Node class. 90 args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 91 kwargs (Dict[str, ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 92 name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. 93 Name of node also used as field name in network class. 94 instance: Object in network corresponding to this node. 95 """ 96 self._node_type: NodeType = node_type 97 self._ast_node: Optional[ast.AST] = ast_node 98 self._attribute: {str, object} = {} 99 if node_type in (NodeType.CallModule, NodeType.CallCell, NodeType.CallPrimitive): 100 self._attribute = Node._get_cell_or_prim_op_attribute(instance) 101 self._instance = instance 102 self._name = name 103 self._func_name: Optional[ScopedValue] = func_name 104 self._targets: [ScopedValue] = targets if targets is not None else [] 105 self._args_num = len(args) if args is not None else 0 106 self._kwargs_num = len(kwargs) if kwargs is not None else 0 107 self._normalized_args_keys = [] # for saving args' order 108 self._normalized_args = self._get_normalized_args(args, kwargs) 109 # position in graph nodes list 110 # it will affect code-order of python code 111 self._prev: Optional[Node] = None 112 self._next: Optional[Node] = None 113 # A handler of SymbolTree current node belonging to 114 self._belong_tree = None 115 # A handler of NodeManager current node belonging to 116 self._node_manager = None 117 # A dict that records which target of which Node current Node's argument come from 118 self._arg_providers: {int: (Node, int)} = {} 119 # A dict that records which argument of which Node uses current Node's target 120 self._target_users: {int: [(Node, int)]} = {} 121 # Indicate this node represent a class type object, e.g. abs_ops = _get_cache_prim(P.Abs) 122 self._type_cls = None 123 # Indicate this node represent the initialize of a class type, e.g. abs_inst = P.Abs() 124 self._init_cls = None 125 126 @classmethod 127 def create_call_method(cls, ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]], 128 func_name: Union[ScopedValue, str], args: [ScopedValue] = None, 129 kwargs: {str: ScopedValue}=None, name: str = ""): 130 """ 131 Class method of Node. Instantiate an instance of node whose type is CallCell. A CallCell node represents an 132 invoking to cell-op. 133 134 Args: 135 ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. `ast_node` 136 should not be None currently. 137 targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 138 func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class. 139 args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 140 kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. 141 name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. 142 Name of node also used as field name in network class. 143 """ 144 if args is None: 145 args = [] 146 if kwargs is None: 147 kwargs = {} 148 if isinstance(func_name, str): 149 func_name = ScopedValue.create_naming_value(func_name) 150 new_targets = Node._handle_targets(targets) 151 if ast_node is None: 152 raise RuntimeError("Input ast_node is None") 153 return cls(NodeType.CallMethod, ast_node, new_targets, func_name, args, kwargs, name, None) 154 155 @classmethod 156 def create_python_node(cls, ast_node: ast.AST, name: str = "", instance=None): 157 """ 158 Class method of Node. Instantiate an instance of node whose type is Python. A Python node represents some python 159 statement is not supported by Rewrite or ignored by Rewrite. 160 161 Args: 162 ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. 163 name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. 164 Name of node also used as field name in network class. 165 instance: An object corresponding to this node in network. 166 """ 167 return cls(NodeType.Python, ast_node, None, None, [], {}, name, instance) 168 169 @classmethod 170 def create_input_node(cls, ast_node: Optional[ast.AST], arg_name: str, default: Optional[ScopedValue] = None, 171 name: str = ""): 172 """ 173 Class method of Node. Instantiate an instance of node whose type is Input. An Input node represents input of 174 SymbolTree which is corresponding to parameters of forward function. 175 176 Args: 177 ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. 178 arg_name (str): A string represents name of parameter. 179 default ([ScopedValue, optional]): An instance of ScopedValue represents default value of parameter. 180 name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. 181 Name of node also used as field name in network class. 182 """ 183 target = ScopedValue.create_naming_value(arg_name) 184 if default is None: 185 args = [] 186 else: 187 args = [default] 188 if ast_node is None: 189 ast_node = ast.arg(arg_name, annotation="") 190 return cls(NodeType.Input, ast_node, [target], None, args, {}, name, None) 191 192 @classmethod 193 def create_output_node(cls, ast_node: ast.AST, return_value: [ScopedValue], name: str = "return"): 194 """ 195 Class method of Node. Instantiate an instance of node whose type is Output. An Output node represents output of 196 SymbolTree which is corresponding to return statement of forward function. 197 198 Args: 199 ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. 200 return_values (list[str]): A list of string represents name of return values. 201 name (ScopedValue): An instance of ScopedValue represents name of node. 202 """ 203 return cls(NodeType.Output, ast_node, None, ScopedValue.create_naming_value("return"), return_value, {}, 204 name, None) 205 206 @classmethod 207 def create_mathops_node(cls, ast_node: ast.AST, targets: [ScopedValue], 208 op_type: ScopedValue, args: [ScopedValue], name: str = ""): 209 """ 210 Class method of Node. Instantiate an instance of node whose type is `MathOps` . 211 A mathops node is used to represent a node with mathematical operations, such as 212 `y = a + b` , `y = not a` , `y = 0 < a < 1`, `y = a or b` , etc. 213 214 Args: 215 ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. The type of 216 node is ast.Assign, and the type of ast_node.value is one of ast.BinOp, ast.UnaryOp, ast.BoolOp and 217 ast.Compare. 218 targets (list[ScopedValue]): Targets of mathematical operations. A list of instance of `ScopedValue`. 219 See detail in docstring of Node class. 220 op_type (ScopedValue): The type of ast_node.value saved by string. A ScopedValue with NamingValue type. 221 args (list[ScopedValue]): Values participating in the mathematical operations. All values are saved 222 sequentially in the list. 223 name (str): A string represents name of node. Name of node will be unique when inserted into `SymbolTree`. 224 Name of node also used as field name in network class. The format of mathops node name 225 is 'AstNodeName_AstOpName_n'. 226 """ 227 return cls(NodeType.MathOps, ast_node, targets, op_type, args, None, name, None) 228 229 @staticmethod 230 def _create_call_function(function: FunctionType, targets: [Union[ScopedValue, str]], args: [ScopedValue] = None, 231 kwargs: {str: ScopedValue}=None): 232 """ 233 Create a node that corresponds to a function call. 234 235 Args: 236 function (FunctionType): The function to be called. 237 targets (list[str]): indicates output names. Used as targets of an assign statement in source code. 238 args (list[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in 239 source code. Default: ``None`` , which indicates the `function` has no args inputs. 240 kwargs (dict): Type of key must be `str` and type of value must be `ScopedValue`. 241 Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source 242 code. Default: ``None`` , which indicates the `function` has no kwargs inputs. 243 244 Returns: 245 An instance of `Node`. 246 """ 247 if args is None: 248 args = [] 249 if kwargs is None: 250 kwargs = {} 251 targets = Node._handle_targets(targets) 252 func_name = function.__name__ 253 func_scope_name = ScopedValue.create_naming_value(func_name) 254 node = Node.inner_create_call_function(func_name, None, func_scope_name, function, targets, args, kwargs) 255 return node 256 257 @classmethod 258 def inner_create_call_function(cls, node_name: str, ast_node: ast.Assign, func_name: ScopedValue, func_obj: object, 259 targets: List[ScopedValue], args: List[ScopedValue], kwargs: Dict[str, ScopedValue]): 260 ''' 261 Instantiate an instance of node whose type is `CallFunction`. 262 263 Args: 264 node_name (str): Name of node. 265 func_name (ScopedValue): Name of function. 266 ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. 267 func_obj (Object): An object of function. See detail in docstring of Node class. 268 targets (List[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class. 269 args (List[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class. 270 kwargs (Dict[str, ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of `Node` 271 class. 272 ''' 273 from . import CallFunction 274 # create CallFunction node 275 return CallFunction(targets, func_name, args, kwargs, node_name, ast_node, None, None, func_obj, False) 276 277 @staticmethod 278 def create_call_op(op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [Union[ScopedValue, str]], 279 args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, node_name: str = "", 280 is_sub_net: bool = False): 281 """ 282 Static method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`. 283 If op is custom defined, it is treated by TreeNode. 284 A `CallCell` node represents an invoking to cell-op. 285 A `CallPrimitive` node represents an invoking to primitive-op. 286 287 Args: 288 op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node. 289 ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. 290 targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class. 291 args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class. 292 kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node` 293 class. 294 node_name (str): A string represents name of node. Name of node will be unique when inserted into 295 `SymbolTree`. Name of node also used as field name in network class. 296 is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse the 297 `cell` to a TreeNode, else a CallCell Node. Default is a False. 298 """ 299 Validator.check_value_type("op", op, [Cell, Primitive], "Node") 300 if ast_node is not None: 301 Validator.check_value_type("ast_node", ast_node, [ast.AST], "Node") 302 Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "Node") 303 if args is not None: 304 Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node") 305 if kwargs is not None: 306 Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node") 307 if args is None: 308 args = [] 309 if kwargs is None: 310 kwargs = {} 311 Validator.check_value_type("node_name", node_name, [str], "Node") 312 new_targets = Node._handle_targets(targets) 313 if isinstance(node_name, str): 314 func_name = ScopedValue.create_naming_value(node_name) 315 else: 316 func_name = node_name 317 if is_sub_net and is_subtree(op): 318 from ..symbol_tree import SymbolTreeBuilder 319 stb = SymbolTreeBuilder(op) 320 stree = stb.build() 321 replacer = AstReplacer(stree.get_class_ast()) 322 replacer.replace_all(stree.get_ori_cls_name(), stree.get_opt_cls_name()) 323 return TreeNode.create_tree_node(stree, ast_node, new_targets, func_name, args, kwargs, node_name, op) 324 325 return Node.create_call_buildin_op(op, ast_node, new_targets, func_name, args, kwargs, node_name) 326 327 @classmethod 328 def create_call_buildin_op(cls, op: Union[Cell, Primitive], ast_node: Optional[ast.AST], targets: [ScopedValue], 329 func_name: ScopedValue, args: [ScopedValue] = None, kwargs: {str: ScopedValue}=None, 330 node_name: str = ""): 331 """ 332 Class method of Node. Instantiate an instance of node whose type is `CallCell` or `CallPrimitive`. 333 A `CallCell` node represents an invoking to cell-op. 334 A `CallPrimitive` node represents an invoking to primitive-op. 335 336 Args: 337 op (Union[Cell, Primitive]): An instance of `Cell` or `Primitive` corresponding to this node. 338 ast_node ([ast.AST, optional]): An instance of ast.AST represents corresponding node in ast. 339 targets (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class. 340 func_name ([ScopedValue, optional]): An instance of `ScopedValue`. See detail in docstring of Node class. 341 args (list[ScopedValue]): A list of instance of `ScopedValue`. See detail in docstring of Node class. 342 kwargs (dict{str: ScopedValue}): A list of instance of `ScopedValue`. See detail in docstring of `Node` 343 class. 344 node_name (str): A string represents name of node. Name of node will be unique when inserted into 345 `SymbolTree`. Name of node also used as field name in network class. 346 """ 347 348 if not isinstance(op, (Cell, Primitive)): 349 raise ValueError("Input op is not a buildin op(Cell or Primitive): ", type(op)) 350 if isinstance(op, Cell): 351 node_type = NodeType.CallCell 352 else: 353 node_type = NodeType.CallPrimitive 354 return cls(node_type, ast_node, targets, func_name, args, kwargs, node_name, op) 355 356 @staticmethod 357 def _get_construct_arg_names(parameters): 358 """ 359 Static method of `Node`. Get parameters' names of the construct function. 360 361 Args: 362 parameters (MappingProxyType): An ordered mapping of parameters' names to the corresponding Parameter 363 objects. 364 365 Raises: 366 RuntimeError: Invalid parameter kind. 367 368 Returns: 369 - arg_names, Parameters' names, contain parameters of types in [POSITIONAL_ONLY, POSITIONAL_OR_KEYWORD]. 370 - var_positional_name, Name of VAR_POSITIONAL parameters. 371 - var_keyword_name, Name of VAR_KEYWORD parameters. 372 """ 373 position_only_names: [str] = [] 374 positional_or_keyword_names: [str] = [] 375 var_positional_name = None 376 keyword_only_names: [str] = [] 377 var_keyword_name = None 378 for name, para in parameters.items(): 379 if para.kind == inspect.Parameter.POSITIONAL_ONLY: # parameters which appear before a '/' 380 position_only_names.append(name) 381 elif para.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: # parameters which appear before '*' or '*args' 382 positional_or_keyword_names.append(name) 383 elif para.kind == inspect.Parameter.VAR_POSITIONAL: # corresponds to a '*args' 384 var_positional_name = name 385 elif para.kind == inspect.Parameter.KEYWORD_ONLY: # parameters which appear after '*' and before '**' 386 keyword_only_names.append(name) 387 elif para.kind == inspect.Parameter.VAR_KEYWORD: # corresponds to a '**kwargs' 388 var_keyword_name = name 389 else: 390 raise RuntimeError("invalid parameter kind:", para.kind) 391 if "self" in position_only_names: 392 position_only_names.remove("self") 393 if "self" in positional_or_keyword_names: 394 positional_or_keyword_names.remove("self") 395 names = (position_only_names, positional_or_keyword_names, var_positional_name, keyword_only_names, 396 var_keyword_name) 397 return names 398 399 @staticmethod 400 def _map_args_names(names: tuple, args: [ScopedValue], kwargs: {str: ScopedValue}, 401 normalized_args_keys: [str], normalized_args: {str: ScopedValue}): 402 """ 403 Fill in normalized_args according to the order of parameters of construct func. 404 405 Args: 406 names (tuple): Parameters' name got from construct func. 407 args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 408 kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. 409 normalized_args (dict{str: ScopedValue}): The normalized args to be filled. 410 411 Raises: 412 RuntimeError: Input args are invalid. 413 RuntimeError: Arg name already exist in kwargs. 414 RuntimeError: Input kwargs invalid. 415 """ 416 position_only_names, positional_or_keyword_names, var_positional_name, keyword_only_names, var_keyword_name = \ 417 names 418 for arg_index, arg in enumerate(args): 419 if arg_index < len(position_only_names): 420 arg_key = position_only_names[arg_index] 421 elif arg_index < len(position_only_names) + len(positional_or_keyword_names): 422 arg_key = positional_or_keyword_names[arg_index - len(position_only_names)] 423 elif var_positional_name: 424 arg_key = "{}_{}".format(var_positional_name, arg_index) 425 else: 426 raise RuntimeError("Input args are invalid.") 427 428 if arg_key in kwargs.keys(): 429 raise RuntimeError("Arg name already exist in kwargs.") 430 normalized_args[arg_key] = arg 431 normalized_args_keys.append(arg_key) 432 433 # add kwargs according to parameters' order 434 parameters_order: [str] = [] 435 parameters_order.extend(position_only_names) 436 parameters_order.extend(positional_or_keyword_names) 437 parameters_order.append(var_keyword_name) 438 parameters_order.extend(keyword_only_names) 439 parameters_order.append(var_keyword_name) 440 441 sorted_kwargs = [] 442 var_keyword_count = len(parameters_order) 443 for arg_key, value in kwargs.items(): 444 if arg_key not in parameters_order and not var_keyword_name: 445 raise RuntimeError("Input kwargs invalid.") 446 if arg_key in parameters_order: 447 sorted_kwargs.append([arg_key, value, parameters_order.index(arg_key)]) 448 else: 449 sorted_kwargs.append([arg_key, value, var_keyword_count]) 450 var_keyword_count += 1 451 452 sorted_kwargs.sort(key=lambda x: x[2]) 453 for sorted_kwarg in sorted_kwargs: 454 normalized_args[sorted_kwarg[0]] = sorted_kwarg[1] 455 normalized_args_keys.append(sorted_kwarg[0]) 456 457 @staticmethod 458 def _handle_custom_obj_in_args(args: [ScopedValue]) -> [ScopedValue]: 459 """ 460 Convert CustomObjValue type argument to NamingValue type argument. 461 462 Args: 463 args (list[ScopedValue]): A list of instance of ScopedValue to be converted. 464 465 Returns: 466 A list of instance of ScopedValue which have been converted. 467 """ 468 result = [] 469 for arg in args: 470 if not isinstance(arg, ScopedValue): 471 raise TypeError("arg should be ScopedValue, got: ", type(arg)) 472 if arg.type == ValueType.CustomObjValue: 473 logger.info("custom-object exist in args, should be replace before compile") 474 result.append(ScopedValue.create_naming_value("custom-object", "self")) 475 else: 476 result.append(arg) 477 return result 478 479 @staticmethod 480 def _handle_custom_obj_in_kwargs(kwargs: {str: ScopedValue}) -> {str: ScopedValue}: 481 """ 482 Convert CustomObjValue type argument to NamingValue type argument. 483 484 Args: 485 kwargs (dict{str: ScopedValue}): A str to instance of ScopedValue dict whose value to be converted. 486 487 Returns: 488 A str to instance of ScopedValue dict whose value has be converted. 489 """ 490 result: {str, ScopedValue} = {} 491 for arg, value in kwargs.items(): 492 if not isinstance(value, ScopedValue): 493 raise TypeError("value should be ScopedValue, got: ", type(value)) 494 if value.type == ValueType.CustomObjValue: 495 result[arg] = ScopedValue.create_naming_value("custom-object", "self") 496 else: 497 result[arg] = value 498 return result 499 500 @staticmethod 501 def _handle_targets(targets: [Union[ScopedValue, str]]) -> [ScopedValue]: 502 """ 503 Normalize targets to be a list of ScopedValue. If target is a str, it will be converted to NamingValue type 504 ScopedValue. 505 506 Args: 507 targets (Union[ScopedValue, str]]): A list whose element could be a ScopedValue or a str to be normalized. 508 509 Returns: 510 A list of instance of ScopedValue which have been converted. 511 """ 512 if not isinstance(targets, list): 513 raise TypeError("targets should be list, got: ", type(targets)) 514 results = [] 515 for target in targets: 516 if isinstance(target, str): 517 scope = "" 518 name = target 519 if target.count('.') > 0: 520 scope, name = target.rsplit('.', 1) 521 results.append(ScopedValue.create_naming_value(name, scope)) 522 elif isinstance(target, ScopedValue): 523 results.append(target) 524 else: 525 raise RuntimeError("Invalid symbol type: ", target) 526 return results 527 528 @staticmethod 529 def _get_cell_or_prim_op_attribute(obj) -> dict: 530 """ 531 Find attributes of cell-op or primitive-op. 532 533 Args: 534 obj: A cell-op or a primitive-op. 535 536 Returns: 537 A dict represents attributes of input 'obj'. 538 """ 539 attributes = {} 540 if obj is None: 541 return attributes 542 for k, v in obj.__dict__.items(): 543 if k.startswith("_"): 544 continue 545 attributes[k] = v 546 attributes["cls"] = obj.__class__ 547 return attributes 548 549 def get_type_cls(self) -> object: 550 """Get the class type object this node represented, e.g. abs_ops = _get_cache_prim(P.Abs)""" 551 return self._type_cls 552 553 def set_type_cls(self, x): 554 """Set the class type object this node represented, e.g. abs_ops = _get_cache_prim(P.Abs)""" 555 self._type_cls = x 556 557 def get_init_cls(self) -> object: 558 """Get the class type object initialized by this node, e.g. abs_inst = P.Abs()""" 559 return self._init_cls 560 561 def set_init_cls(self, x): 562 """Set the class type object initialized by this node""" 563 self._init_cls = x 564 565 def get_prev(self) -> 'Node': 566 """ 567 Get previous node of current node in source code order. 568 569 Returns: 570 An instance of Node as previous node. 571 """ 572 return self._prev 573 574 def get_next(self) -> 'Node': 575 """ 576 Get next node of current node in source code order. 577 578 Returns: 579 An instance of Node as next node. 580 """ 581 return self._next 582 583 def set_prev(self, node: 'Node'): 584 """ 585 Set previous node of current node. 586 587 Args: 588 node (Node): Node to be set as previous node of current node. 589 """ 590 self._prev = node 591 592 def set_next(self, node: 'Node'): 593 """ 594 Set next node of current node. 595 596 Args: 597 node (Node): Node to be set as next node of current node. 598 """ 599 self._next = node 600 601 def get_ast(self) -> Optional[ast.AST]: 602 """ 603 Getter of _ast_node. 604 605 Returns: 606 An instance of ast.AST if self._ast_node if not None else None. 607 """ 608 return self._ast_node 609 610 def set_ast(self, ast_node: ast.AST): 611 """ 612 Setter of _ast_node. 613 614 Args: 615 ast_node (ast.AST): An instance of ast.AST as new value for _ast_node. 616 """ 617 if not isinstance(ast_node, ast.AST): 618 raise TypeError("ast_node should be ast.AST, got: ", type(ast_node)) 619 self._ast_node = ast_node 620 621 def get_belong_symbol_tree(self): 622 """Get the symbol tree to which node belongs.""" 623 return self._belong_tree 624 625 def set_belong_symbol_tree(self, symbol_tree): 626 """Set the symbol tree to which node belongs.""" 627 self._belong_tree = symbol_tree 628 629 def get_node_manager(self): 630 """Get the NodeManager current node belongs to.""" 631 return self._node_manager 632 633 def set_node_manager(self, node_manager): 634 """Set NodeManager current node belongs.""" 635 self._node_manager = node_manager 636 637 def isolate(self): 638 """Link prev node to next node and isolate node from source code order list.""" 639 origin_prev: Optional[Node] = self.get_prev() 640 origin_next: Optional[Node] = self.get_next() 641 if origin_prev is not None: 642 origin_prev.set_next(origin_next) 643 if origin_next is not None: 644 origin_next.set_prev(origin_prev) 645 self.set_prev(None) 646 self.set_next(None) 647 648 def insert_before(self, node: 'Node'): 649 """ 650 Insert a node before current node in source code list. Note that topological order is not determined here. 651 652 Args: 653 node (Node): An instance of node to be inserted in. 654 """ 655 node.isolate() 656 origin_prev: Optional[Node] = self.get_prev() 657 if origin_prev is not None: 658 origin_prev.set_next(node) 659 node.set_prev(origin_prev) 660 node.set_next(self) 661 self.set_prev(node) 662 663 def insert_after(self, node: 'Node'): 664 """ 665 Insert a node after current node in source code list. Note that topological order is not determined here. 666 667 Args: 668 node (Node): An instance of node to be inserted in. 669 """ 670 node.isolate() 671 origin_next: Optional[Node] = self.get_next() 672 self.set_next(node) 673 node.set_prev(self) 674 node.set_next(origin_next) 675 if origin_next is not None: 676 origin_next.set_prev(node) 677 678 def get_inputs(self) -> ['Node']: 679 """ 680 Get input nodes of current node in topological order. 681 682 Returns: 683 A list of instances of Node as input nodes. 684 """ 685 inputs = [] 686 for arg_provider in self.get_arg_providers().values(): 687 if not arg_provider: 688 continue 689 inputs.append(arg_provider[0]) 690 return inputs 691 692 def get_users(self) -> ['Node']: 693 """ 694 Get user nodes of current node in topological order. 695 696 Returns: 697 A list of instances of Node as user nodes. 698 """ 699 users = [] 700 for target_users in self.get_target_users().values(): 701 if not target_users: 702 continue 703 for (user, _) in target_users: 704 if user not in users: 705 users.append(user) 706 return users 707 708 def get_targets(self) -> [ScopedValue]: 709 """ 710 Getter of _targets. 711 712 - When node_type of current node is CallCell or CallPrimitive or CallMethod or Tree, `targets` are strings 713 represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets of 714 ast.Assign. 715 - When node_type of current node is Input, `targets` should have only one element which is a string represents 716 name of parameter of function. 717 - When node_type of current node is Python or Output, `targets` are don't-care. 718 719 Returns: 720 A list of instances of ScopedValue as targets of node. 721 """ 722 return self._targets 723 724 def set_targets(self, targets: [ScopedValue]): 725 """ 726 Setter of _targets. 727 728 Note: 729 This interface can only be called before node been inserted into symbol-tree because target will be unique 730 while insert into symbol-tree, in other word, set_targets is not a user-interface. 731 732 When `_targets` is updated, corresponding ast node would be updated also. 733 734 When node_type of current node is CallCell or CallPrimitive or CallMethod or Tree, `targets` are strings 735 represents invoke result of the cell-op or primitive-op or function-call which are corresponding to targets 736 of ast.Assign. 737 738 When node_type of current node is Input, `targets` should have only one element which is a string represents 739 name of parameter of function. 740 741 When node_type of current node is Python or Output, `targets` are don't-care. 742 743 Args: 744 targets ([ScopedValue]): A list of instances of ScopedValue as new targets. 745 """ 746 self._targets = targets 747 if self._node_type in (NodeType.CallCell, NodeType.CallMethod, NodeType.CallPrimitive, 748 NodeType.Tree, NodeType.CallFunction, NodeType.CellContainer, 749 NodeType.MathOps): 750 self._sync_assign_targets_to_ast() 751 752 def get_func_name(self) -> ScopedValue: 753 """ 754 Getter of `_func_name`. See detail in docstring of Node class for meaning of func. 755 756 Returns: 757 An instance of ScopedValue. 758 """ 759 return self._func_name 760 761 def set_func_name(self, func_name: ScopedValue): 762 """ 763 Setter of `_func_name`. See detail in docstring of Node class for meaning of func. 764 765 Note: 766 When `_func_name` is updated, corresponding ast node would be updated also. 767 768 Args: 769 func (ScopedValue): An instance of ScopedValue as new func. 770 """ 771 self._func_name = func_name 772 if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive): 773 self._sync_assign_func_name_to_ast() 774 775 def get_name(self) -> str: 776 """ 777 Getter of `_name`. 778 779 Returns: 780 A str represents name of node. 781 """ 782 return self._name 783 784 def set_name(self, name: str): 785 """ 786 Setter of `_name`. 787 788 Args: 789 name (str): A str as new name of node. 790 """ 791 self._name = name 792 793 def get_node_type(self) -> NodeType: 794 """ 795 Get the node_type of current node. 796 797 Returns: 798 A NodeType as node_type of node. 799 """ 800 return self._node_type 801 802 def get_instance_type(self) -> type: 803 """ 804 Get the instance_type of current node. 805 806 - When node_type of current node is CallCell, instance_type is type of cell-op. 807 - When node_type of current node is CallPrimitive, instance_type is type of primitive-op. 808 - When node_type of current node is Tree, instance_type is type of network-cell. 809 - When node_type of current node is Python, Input, Output or CallMethod, instance_type should be NoneType 810 811 Returns: 812 A type. 813 """ 814 if isinstance(self._instance, LocalPrim): 815 return self._instance.prim_obj 816 if inspect.isfunction(self._instance): 817 return self._instance 818 return type(self._instance) 819 820 def get_instance(self): 821 """ 822 Get the instance of current node. 823 824 - When node_type of current node is CallCell, instance is an instance of Cell. 825 - When node_type of current node is CallPrimitive, instance is an instance of primitive. 826 - When node_type of current node is Tree, instance is an instance of network-cell. 827 - When node_type of current node is Python, Input, Output or CallMethod, instance should be None 828 829 Returns: 830 A object. 831 """ 832 return self._instance 833 834 def set_arg_by_node(self, arg_idx: int, node: 'Node', out_idx: Optional[int] = None): 835 """ 836 Set argument by another Node. 837 Note that when _normalized_args is updated, corresponding ast node would be updated also. 838 839 Args: 840 arg_idx (int): Indicate which input being modified. 841 node (Node): Node as new input. Can be a node or name of node. 842 out_idx ([int, optional]): Indicate which output of `node` as new argument. Default is None which means use 843 first output of `node_to_link` as new input. 844 845 Raises: 846 ValueError: If `arg_idx` is out of range. 847 ValueError: If `node` has multi-outputs while `out_idx` is None or `out_idx` is not offered. 848 """ 849 Validator.check_value_type("node", node, [Node], "Node") 850 Validator.check_int_range(arg_idx, 0, self._args_num, Validator.INC_LEFT, "arg_idx") 851 if out_idx is None: 852 if len(node.get_targets()) != 1: 853 raise ValueError("node should has one output when out_idx is not provided") 854 out_idx = 0 855 Validator.check_int_range(out_idx, 0, len(node.get_targets()), Validator.INC_LEFT, "arg_idx") 856 new_arg = node.get_targets()[out_idx] 857 self._normalized_args[self._normalized_args_keys[arg_idx]] = new_arg 858 self._sync_arg() 859 860 def set_arg(self, arg: Union[ScopedValue, str], index: int) -> (ScopedValue, ScopedValue): 861 """ 862 Set argument of `node`. 863 Note that when _normalized_args is updated, corresponding ast node would be updated also. 864 865 Args: 866 index (int): Indicate which input being modified. 867 arg (Union[ScopedValue, str]): New argument to been set. 868 869 Raises: 870 ValueError: If `index` is out of range. 871 """ 872 Validator.check_int_range(index, 0, self._args_num, Validator.INC_LEFT, "index") 873 Validator.check_value_type("arg", arg, [ScopedValue, str], "Node") 874 if isinstance(arg, str): 875 arg = ScopedValue.create_naming_value(arg) 876 old_arg = self._normalized_args.get(self._normalized_args_keys[index]) 877 self._normalized_args[self._normalized_args_keys[index]] = arg 878 self._sync_arg() 879 return arg, old_arg 880 881 def set_args(self, args: [ScopedValue]): 882 """ 883 Set arguments of `node`. 884 Note that when _normalized_args is updated, corresponding ast node would be updated also. 885 886 Args: 887 args (list[ScopedValue]): New arguments to been set. 888 889 Raises: 890 TypeError: Element of new argument is not an instance of ScopedValue. 891 """ 892 Validator.check_int_range(len(args), 0, self._args_num, Validator.INC_LEFT, "Length of args") 893 Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node") 894 for arg_index, arg in enumerate(args): 895 if not isinstance(arg, ScopedValue): 896 raise TypeError("arg should be ScopedValue, got: ", type(arg)) 897 self._normalized_args[self._normalized_args_keys[arg_index]] = arg 898 self._sync_arg() 899 900 def set_kwargs(self, kwargs: {str: ScopedValue}): 901 """ 902 Set keywords arguments of 'node'. 903 Note that when _normalized_args is updated, corresponding ast node would be updated also. 904 905 Args: 906 kwargs (dict{str: ScopedValue}): New arguments to been set. 907 908 Raises: 909 TypeError: Value of new argument is not an instance of ScopedValue. 910 RuntimeError: Length of new arguments is not equal to length of old arguments. 911 """ 912 Validator.check_int_range(len(kwargs), 0, self._kwargs_num, Validator.INC_LEFT, "Length of kwargs") 913 Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node") 914 for key, arg in kwargs.items(): 915 if key not in self._normalized_args.keys() or key not in self._normalized_args_keys: 916 raise RuntimeError("Input key is not exist, ", key) 917 if not isinstance(arg, ScopedValue): 918 raise TypeError("arg should be ScopedValue, got: ", type(arg)) 919 self._normalized_args[key] = arg 920 self._sync_arg() 921 922 def set_kwarg(self, key: str, arg: ScopedValue): 923 """ 924 Set keyword argument of 'node'. 925 Note that when _normalized_args is updated, corresponding ast node would be updated also. 926 927 Args: 928 key (str): A str represents key of new argument. 929 arg (ScopedValue): An instance of ScopedValue represents argument. 930 931 Raises: 932 RuntimeError: If 'key' is not in original kwargs' keys. 933 """ 934 if key not in self._normalized_args_keys[self._args_num:] or key not in self._normalized_args.keys(): 935 raise RuntimeError("Input key is not exist, ", key) 936 self._normalized_args[key] = arg 937 self._sync_arg() 938 939 def get_args(self): 940 """ 941 Get the arguments of current node. 942 943 - When node_type of current node is CallCell, CallPrimitive or Tree, arguments are corresponding to args of 944 ast.Call which represents arguments to invoke cell-op's forward method or primitive-op's `call()` method. 945 - When node_type of current node is Input, arguments represents default-value of argument of function. 946 - When node_type of current node is Output, arguments represents return values. 947 - When node_type of current node is Python, arguments are don't-care. 948 949 Returns: 950 A list of instances of ScopedValue. 951 """ 952 args = [] 953 for arg_index in range(self._args_num): 954 args.append(self._normalized_args.get(self._normalized_args_keys[arg_index])) 955 return args 956 957 def get_kwargs(self): 958 """ 959 Get the keyword arguments of current node. 960 961 - When node_type of current node is CallCell, CallPrimitive or Tree, keyword arguments are corresponding to 962 kwargs of ast.Call which represents arguments to invoke cell-op's forward method or primitive-op's `call()` 963 method. 964 - When node_type of current node is Python, Input or Output, keyword arguments are don't-care. 965 966 Returns: 967 A dict of str to instance of ScopedValue. 968 """ 969 kwargs: {str, ScopedValue} = {} 970 for arg_index in range(self._args_num, self._args_num + self._kwargs_num): 971 key = self._normalized_args_keys[arg_index] 972 kwargs[key] = self._normalized_args.get(key) 973 return kwargs 974 975 def get_normalized_args(self) -> {str: ScopedValue}: 976 """ 977 Get the normalized keyword arguments of current node. 978 Normalized arguments combine arguments and keyword arguments into keyword arguments by using parameter name as 979 key of arguments. 980 981 Returns: 982 A dict of str to instance of ScopedValue. 983 """ 984 output = {} 985 for key in self._normalized_args_keys: 986 output[key] = self._normalized_args.get(key) 987 return output 988 989 def set_normalized_args(self, args: {str, ScopedValue}): 990 """ 991 Set the normalized keyword arguments of current node. 992 Normalized arguments combine arguments and keyword arguments into keyword arguments by using parameter name as 993 key of arguments. 994 995 Args: 996 args ({str, ScopedValue}): A dict of str to instance of ScopedValue represents new normalized_args. 997 """ 998 if len(args.values()) != len(self._normalized_args_keys): 999 raise RuntimeError("Length of args.values() should be equal to length of _normalized_args_keys, ", 1000 len(args.values()), " vs ", len(self._normalized_args_keys)) 1001 for key, arg in args.items(): 1002 self._normalized_args[key] = arg 1003 self._sync_arg() 1004 1005 def set_attribute(self, key: str, value): 1006 """ 1007 Set attribute of current node. 1008 1009 Args: 1010 key (str): Key of new attribute. 1011 value (object): Value of new attribute. 1012 """ 1013 self._attribute[key] = value 1014 1015 def set_attributes(self, attributes): 1016 """ 1017 Set attributes of current node. 1018 1019 Args: 1020 attributes (dict): A dict represents new attributes. 1021 """ 1022 self._attribute = attributes 1023 1024 def get_attributes(self): 1025 """ 1026 Get all attributes of current node. 1027 1028 Returns: 1029 A dict of str to instance of object as attributes. 1030 """ 1031 return self._attribute 1032 1033 def get_attribute(self, key: str): 1034 """ 1035 Get attribute of current node by key. 1036 1037 Args: 1038 key (str): A str represents key of attribute you want to get. 1039 1040 Returns: 1041 A object as attribute. 1042 """ 1043 return self._attribute.get(key) 1044 1045 def get_arg_providers(self) -> dict: 1046 """ 1047 Getter of _arg_providers. 1048 1049 Return: 1050 dict, key is type of int indicating the index of args, and value is type of tuple, which includes 1051 the node and the index of node's targets who provides the argument. 1052 """ 1053 return self._arg_providers 1054 1055 def set_arg_providers(self, index: int, provider: tuple): 1056 """ 1057 Setter of _arg_providers. 1058 1059 Args: 1060 index (int): Indicating provider of which argument need to be set. 1061 provider (tuple): A tuple includes the node and the index of node's targets who provides the argument. 1062 """ 1063 self._arg_providers[index] = provider 1064 1065 def get_target_users(self, index=-1) -> Union[dict, list]: 1066 """ 1067 Getter of _target_users. 1068 1069 Args: 1070 index (int): Indicating users of which target need to be got. Default: -1, means all targets's users will 1071 be returned. 1072 1073 Return: 1074 Union[dict, list]. When index is not -1, a list of users of specified target will be returned. 1075 The type of elements in list is tuple, which includes the user node and the index of node's arguments 1076 who uses the target. When index is -1, a dict will be returned. The key is index of targets, and the 1077 value is list of users of corresponding target. 1078 """ 1079 if index == -1: 1080 return self._target_users 1081 if index not in self._target_users.keys(): 1082 self._target_users[index] = [] 1083 return self._target_users.get(index, None) 1084 1085 def append_target_users(self, index: int, provider: tuple): 1086 """ 1087 Setter of _target_users. 1088 1089 Args: 1090 index (int): Indicating users of which target need to be append. 1091 provider (tuple): A tuple includes the node and the index of node's argument who uses the target. 1092 1093 """ 1094 if index not in self._target_users.keys(): 1095 self._target_users[index] = [] 1096 self._target_users.get(index).append(provider) 1097 1098 def update_ast_node(self) -> ast.AST: 1099 """Update node's ast_node by current targets, func_name, args and kwargs.""" 1100 ast_assign = AstModifier.create_call_assign(self.get_targets(), self.get_func_name(), 1101 self.get_args(), self.get_kwargs()) 1102 self.set_ast(ast_assign) 1103 return ast_assign 1104 1105 def get_source_code(self) -> str: 1106 """Get source code of node from ast of node.""" 1107 return astunparse.unparse(self._ast_node).strip() 1108 1109 def append_kwarg(self, kwarg: Dict[str, ScopedValue]): 1110 """ 1111 Append a new keyword arg to node. 1112 1113 Args: 1114 kwarg (Dict[str, ScopedValue]): The new keyword arg. 1115 1116 """ 1117 if self.get_node_type() not in [NodeType.Tree, NodeType.CallFunction]: 1118 raise TypeError(f"For append_new_kwarg, the type of node can only be one of [Tree, CallFunction], " 1119 f"but got {self.get_node_type()}") 1120 Validator.check_element_type_of_dict("kwarg", kwarg, [str], [ScopedValue], "append_new_kwarg") 1121 for arg_key, value in kwarg.items(): 1122 # add keyword into _normalized_args 1123 self._normalized_args[arg_key] = value 1124 self._normalized_args_keys.append(arg_key) 1125 self._kwargs_num += 1 1126 # add keyword ast into ast.Call 1127 ast_assign: ast.Assign = self._ast_node 1128 ast_call: ast.Call = ast_assign.value 1129 new_keyword = ast.keyword(arg=arg_key, value=AstModifier.get_ast_by_value(value, None)) 1130 ast_call.keywords.append(new_keyword) 1131 1132 def _get_normalized_args(self, args: [ScopedValue], kwargs: {str: ScopedValue}) -> dict: 1133 """ 1134 Merge args and kwargs to normalized args. 1135 The keys of args are obtained from the construct function of type(self._instance). 1136 1137 Args: 1138 args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 1139 kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. 1140 1141 Raises: 1142 RuntimeError: Input args are invalid. 1143 RuntimeError: Arg name already exist in kwargs. 1144 1145 Returns: 1146 The normalized args. 1147 """ 1148 if not args: 1149 args = [] 1150 if not kwargs: 1151 kwargs = {} 1152 normalized_args: dict = dict() 1153 if (args or kwargs) and self._instance and hasattr(type(self._instance), "construct"): 1154 parameters = inspect.signature(type(self._instance).construct).parameters 1155 names = Node._get_construct_arg_names(parameters) 1156 Node._map_args_names(names, args, kwargs, self._normalized_args_keys, normalized_args) 1157 else: 1158 logger.debug("fail to get arg name from op, using arg_xx for args' name") 1159 arg_temp_name, suffix = "arg", 0 1160 for arg in args: 1161 arg_key = "{}_{}".format(arg_temp_name, suffix) 1162 while arg_key in kwargs.keys() or arg_key in normalized_args.keys(): 1163 suffix += 1 1164 arg_key = "{}_{}".format(arg_temp_name, suffix) 1165 normalized_args[arg_key] = arg 1166 self._normalized_args_keys.append(arg_key) 1167 for arg_key, value in kwargs.items(): 1168 normalized_args[arg_key] = value 1169 self._normalized_args_keys.append(arg_key) 1170 return normalized_args 1171 1172 # Synchronize rewrite node args to ast node 1173 def _sync_assign_func_name_to_ast(self): 1174 """Sync func_name of ast.Call of ast.Assign from self._name when NodeType is CallCell or CallPrimitive.""" 1175 if self._ast_node is None: 1176 return 1177 assign_ast = self._ast_node 1178 if not isinstance(assign_ast, ast.Assign): 1179 raise TypeError("assign_ast should be ast.Assign, got: ", type(assign_ast)) 1180 call_ast = assign_ast.value 1181 if not isinstance(call_ast, ast.Call): 1182 raise TypeError("call_ast should be ast.Call, got: ", type(call_ast)) 1183 if self._func_name.type == ValueType.UnsupportedValue: 1184 return 1185 func_ast = call_ast.func 1186 if not self._func_name.scope: 1187 if isinstance(func_ast, ast.Name): 1188 func_ast.id = self._func_name.value 1189 else: 1190 call_ast.func = ast.Name(self._func_name.value, ast.Store()) 1191 else: 1192 if isinstance(func_ast, ast.Attribute): 1193 if not isinstance(func_ast.value, ast.Name): 1194 func_ast.value = ast.Name(self._func_name.scope, ast.Load()) 1195 else: 1196 func_ast.value.id = self._func_name.scope 1197 func_ast.attr = self._func_name.value 1198 else: 1199 call_ast.func = ast.Attribute(ast.Name(self._func_name.scope, ast.Load()), 1200 self._func_name.value, ast.Store()) 1201 ast.fix_missing_locations(assign_ast) 1202 1203 def _sync_assign_targets_to_ast(self): 1204 """Sync targets of ast.Assign from self._targets when NodeType is CallCell, CallPrimitive or CallMethod.""" 1205 if self._ast_node is None: 1206 return 1207 assign_ast = self._ast_node 1208 if not isinstance(assign_ast, ast.Assign): 1209 raise TypeError(error_str(f"assign_ast should be ast.Assign, but got: {type(assign_ast)}", 1210 father_node=assign_ast)) 1211 # update targets 1212 target_ast_elems = AstConverter.get_ast_target_elems(assign_ast.targets[0]) 1213 if len(self._targets) != len(target_ast_elems): 1214 raise ValueError(error_str(f"The number of targets should be {len(target_ast_elems)}, " 1215 f"but got {len(self._targets)}", father_node=assign_ast)) 1216 for i, target_ast in enumerate(target_ast_elems): 1217 target_ast_elems[i] = AstModifier.get_ast_by_value(self._targets[i], target_ast) 1218 1219 def _sync_call_args_to_ast(self): 1220 """Sync args of ast.Call from self._normalized_args.""" 1221 if self._ast_node is None: 1222 return 1223 assign_ast = self._ast_node 1224 if not isinstance(assign_ast, ast.Assign): 1225 raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node should be " 1226 f"ast.Assign, but got: {type(assign_ast)}") 1227 assign_value = assign_ast.value 1228 if not isinstance(assign_value, ast.Call): 1229 if isinstance(assign_value, ast.Attribute) and self._node_type in (NodeType.CellContainer, 1230 NodeType.CallCell): 1231 # CellContainers in control flow may be flatten to ast.Attribute: blocks_var = self.blocks 1232 # In this case, no args exist in node, so we don't need to sync. 1233 # CellContainers may be type of CallCell when share one implementation 1234 return 1235 raise TypeError(f"When synchronizing args for '{self._name}'({self._node_type}), _ast_node.value should " 1236 f"be ast.Call, but got: {type(assign_value)}") 1237 keywords_ast = assign_value.keywords 1238 args_ast = assign_value.args 1239 if len(self._normalized_args_keys) != (len(keywords_ast) + len(args_ast)): 1240 raise ValueError("ast keywords plus args len is not equal to self._normalized_args value") 1241 for arg_index in range(self._args_num): 1242 arg_ast = args_ast[arg_index] 1243 args_ast[arg_index] = \ 1244 AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[arg_index]), arg_ast) 1245 1246 # the order of kwargs may not the same as that in keywords_ast 1247 keyword_map_index = {} 1248 for index, keyword_ast in enumerate(keywords_ast): 1249 keyword_map_index[keyword_ast.arg] = index 1250 for keyword_index in range(self._kwargs_num): 1251 key = self._normalized_args_keys[keyword_index + self._args_num] 1252 keywords_ast[keyword_map_index.get(key)].value = \ 1253 AstModifier.get_ast_by_value(self._normalized_args.get(key), 1254 keywords_ast[keyword_map_index.get(key)].value) 1255 1256 def _sync_call_method_args_to_ast(self): 1257 """ 1258 Sync args to value of ast.Assign from self._normalized_args when NodeType is CallMethod. 1259 For node with type of CallMethod, the value of ast.Assign is one of: 1260 | func_name | data_type | value of ast.Assign | 1261 |:---------------|:------------|:------------------------| 1262 | 'pass_through' | constants | ast.Constant, ast.NameConstant, ast.Num, ast.Bytes, ast.Str | 1263 | 'pass_through' | variables | ast.Name, ast.Attribute | 1264 | 'tuple' | tuple | ast.Tuple | 1265 | 'list' | list | ast.List | 1266 | 'dict' | dict | ast.Dict | 1267 """ 1268 if self._ast_node is None: 1269 return 1270 assign_ast = self._ast_node 1271 if not isinstance(assign_ast, ast.Assign): 1272 raise TypeError(f"For node '{self.get_name()}', assign_ast should be ast.Assign, got: ", type(assign_ast)) 1273 assign_value = assign_ast.value 1274 if self._func_name.value == "pass_through": 1275 # update constants/variables 1276 assign_ast.value = \ 1277 AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), assign_value) 1278 elif self._func_name.value in ("tuple", "list", "dict"): 1279 # update tuple/list/dict 1280 ast_elts = assign_value.values if isinstance(assign_value, ast.Dict) else assign_value.elts 1281 if len(self._normalized_args_keys) != len(ast_elts): 1282 raise ValueError(f"For node '{self.get_name()}', size of self._normalized_args_keys" 1283 f"({len(self._normalized_args_keys)}) should be equal to size of elements of " 1284 f"ast_elts({len(ast_elts)})") 1285 for index, elt in enumerate(ast_elts): 1286 scoped_value: ScopedValue = self._normalized_args.get(self._normalized_args_keys[index]) 1287 ast_elts[index] = AstModifier.get_ast_by_value(scoped_value, elt) 1288 else: 1289 raise TypeError(f"For node '{self.get_name()}', only support (pass_through, tuple or dict method) as " 1290 f"call_method, but got {self._func_name.value}") 1291 1292 def _sync_return_node_to_ast(self): 1293 """ 1294 Sync args to value of ast.Return from self._normalized_args when NodeType is Output. 1295 1296 For node with type of CallMethod, the value of ast.Assign is one of: 1297 (ast.Name, ast.Attribute) 1298 """ 1299 if self._ast_node is None: 1300 return 1301 return_ast = self._ast_node 1302 if not isinstance(return_ast, ast.Return): 1303 raise TypeError(f"For node '{self.get_name()}', return_ast should be ast.Return, got: {type(return_ast)}") 1304 return_value_ast = return_ast.value 1305 return_ast.value = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), 1306 return_value_ast) 1307 1308 def _sync_mathops_node_args_to_ast(self): 1309 """ 1310 Sync values from self._normalized_args to the ast node for mathematical operations. 1311 """ 1312 if self._ast_node is None: 1313 return 1314 if not isinstance(self._ast_node, ast.Assign): 1315 raise TypeError(f"type of node should be ast.Assign, but got {type(self._ast_node)}") 1316 mathops_node = self._ast_node.value 1317 if isinstance(mathops_node, ast.BinOp): 1318 left = mathops_node.left 1319 right = mathops_node.right 1320 mathops_node.left = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), 1321 left) 1322 mathops_node.right = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[1]), 1323 right) 1324 elif isinstance(mathops_node, ast.UnaryOp): 1325 operand = mathops_node.operand 1326 mathops_node.operand = \ 1327 AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), operand) 1328 elif isinstance(mathops_node, ast.BoolOp): 1329 values = mathops_node.values 1330 for arg_index in range(self._args_num): 1331 arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index]) 1332 values[arg_index] = AstModifier.get_ast_by_value(arg_value, values[arg_index]) 1333 elif isinstance(mathops_node, ast.Compare): 1334 left = mathops_node.left 1335 mathops_node.left = AstModifier.get_ast_by_value(self._normalized_args.get(self._normalized_args_keys[0]), 1336 left) 1337 comparators = mathops_node.comparators 1338 for arg_index in range(1, self._args_num): 1339 arg_value = self._normalized_args.get(self._normalized_args_keys[arg_index]) 1340 comparators[arg_index - 1] = AstModifier.get_ast_by_value(arg_value, comparators[arg_index - 1]) 1341 else: 1342 raise TypeError("The type of 'mathops_node' must be one of (ast.BinOp, ast.UnaryOp, " 1343 "ast.BoolOp, ast.Compare), but got ", type(mathops_node)) 1344 1345 def _sync_control_flow_args_to_ast(self): 1346 """ 1347 Sync values from self._normalized_args to the ast node of control flow. 1348 """ 1349 if self._ast_node is None: 1350 return 1351 normalized_args_num = len(self._normalized_args_keys) 1352 if normalized_args_num == 0: 1353 return 1354 if normalized_args_num > 1: 1355 raise ValueError("self._normalized_args_keys should have less than 1 elements") 1356 arg_value = self._normalized_args.get(self._normalized_args_keys[0]) 1357 if isinstance(self._ast_node, (ast.If, ast.IfExp, ast.While)): 1358 self._ast_node.test = AstModifier.get_ast_by_value(arg_value, self._ast_node.test) 1359 elif isinstance(self._ast_node, ast.For): 1360 self._ast_node.iter = AstModifier.get_ast_by_value(arg_value, self._ast_node.iter) 1361 else: 1362 raise ValueError(f"For Control Flow, ast_node should be one of [ast.If, ast.IfExp, " 1363 f"ast.While, ast.For], but got {type(self._ast_node)}") 1364 1365 def _sync_arg(self): 1366 """Sync _normalized_args to corresponding ast node when updated.""" 1367 if self._node_type in (NodeType.CallCell, NodeType.CallPrimitive, NodeType.Tree, \ 1368 NodeType.CellContainer, NodeType.CallFunction): 1369 self._sync_call_args_to_ast() 1370 elif self._node_type == NodeType.Output: 1371 self._sync_return_node_to_ast() 1372 elif self._node_type == NodeType.CallMethod: 1373 self._sync_call_method_args_to_ast() 1374 elif self._node_type == NodeType.MathOps: 1375 self._sync_mathops_node_args_to_ast() 1376 elif self._node_type == NodeType.ControlFlow: 1377 self._sync_control_flow_args_to_ast() 1378 1379 1380# Child classes 1381class TreeNode(Node): 1382 """Tree type Node who holds a handler of SymbolTree.""" 1383 1384 def __init__(self, tree, ast_node: ast.AST, targets: [ScopedValue], func: ScopedValue, 1385 args: [ScopedValue], kwargs: {str: ScopedValue}, name: str, instance): 1386 """ 1387 Constructor of TreeNode. Rewrite recommend to invoking class method of Node to instantiate an instance of 1388 TreeNode such as `create_tree_node` rather than invoking constructor of Node directly. 1389 1390 Args: 1391 tree: An instance of SymbolTree represents a handler of sub-symbol-tree. 1392 ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. 1393 targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 1394 func ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class. 1395 args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 1396 kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. 1397 name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. 1398 Name of node also used as field name in network class. 1399 instance: Object in network corresponding to this node. 1400 """ 1401 if isinstance(func, str): 1402 func = ScopedValue.create_naming_value(func) 1403 super().__init__(NodeType.Tree, ast_node, targets, func, args, kwargs, name, instance) 1404 self.symbol_tree = tree 1405 1406 @classmethod 1407 def create_tree_node(cls, tree, ast_node: ast.AST, targets: Union[ScopedValue, str], 1408 func_name: Union[ScopedValue, str], args: [ScopedValue], kwargs: {str: ScopedValue}, 1409 name: str = "", instance=None): 1410 """ 1411 Class method of TreeNode. Instantiate an instance of node whose type is Tree. A Tree node represents an invoking 1412 to sub-network. 1413 1414 Args: 1415 tree: An instance of SymbolTree represents a handler of sub-symbol-tree. 1416 ast_node (ast.AST): An instance of ast.AST represents corresponding node in ast. 1417 targets (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 1418 func_name ([ScopedValue, optional]): An instance of ScopedValue. See detail in docstring of Node class. 1419 args (list[ScopedValue]): A list of instance of ScopedValue. See detail in docstring of Node class. 1420 kwargs (dict{str: ScopedValue}): A list of instance of ScopedValue. See detail in docstring of Node class. 1421 name (str): A string represents name of node. Name of node will be unique when inserted into SymbolTree. 1422 Name of node also used as field name in network class. 1423 instance: Object in network corresponding to this node. 1424 """ 1425 new_targets = Node._handle_targets(targets) 1426 if isinstance(func_name, str): 1427 func_name = ScopedValue.create_naming_value(func_name) 1428 return cls(tree, ast_node, new_targets, func_name, args, kwargs, name, instance) 1429