# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Rewrite module api: Node."""

from typing import Union, Optional, List, Dict
from types import FunctionType

from mindspore.nn import Cell
from mindspore.ops.primitive import Primitive
from mindspore import _checkparam as Validator
from ..node.node import Node as NodeImpl
from ..symbol_tree import SymbolTree as SymbolTreeImpl
from .node_type import NodeType
from .scoped_value import ScopedValue


class Node:
    """
    A node is a data structure that expresses source code statements in a network.

    Each node usually corresponds to a statement in expanded forward evaluation process.

    Nodes can express a ``Cell`` call statement, a ``Primitive`` call statement, an arithmetic operation statement, a
    return statement, etc. of the forward calculation process.

    Args:
        node (NodeImpl): A handler of `NodeImpl`. It is recommended to call the specific methods in Node to create
            a Node, such as 'create_call_cell', rather than calling the Node's constructor directly.
            Callers need not know what `NodeImpl` is; just treat it as a handle.
    """

    def __init__(self, node: NodeImpl):
        self._node = node

    def __eq__(self, other: object) -> bool:
        # Two API-level Nodes are equal when they wrap equal handles.
        if not isinstance(other, Node):
            return False
        return self._node == other._node

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None, which made Node unhashable. Hash the wrapped handle so
        # that Nodes comparing equal (equal handles) also hash equally; this relies on NodeImpl's own
        # __hash__ being consistent with its __eq__.
        return hash(self._node)

    @staticmethod
    def create_call_cell(cell: Cell, targets: List[Union[ScopedValue, str]], args: List[ScopedValue] = None,
                         kwargs: Dict[str, ScopedValue] = None, name: str = "", is_sub_net: bool = False) -> 'Node':
        """
        Create a node that calls `cell`, which may be a ``Cell`` or a ``Primitive`` instance.

        A node is corresponding to source code like:

        ``targets = self.name(*args, **kwargs)``

        Args:
            cell (Union[Cell, Primitive]): Operator of this forward-layer.
            targets (List[Union[ScopedValue, str]]): Indicate output names. Used as targets of an assign statement in
                source code.
            args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
                source code. Default: ``None`` , which indicates the `cell` has no args inputs.
            kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
                Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
                code. Default: ``None`` , which indicates the `cell` has no kwargs inputs.
            name (str): Indicate the name of node. Used as field name in source code. Rewrite will generate a name
                from `cell` when `name` is empty. Rewrite will check and ensure the uniqueness of `name` while node
                being inserted. Default: ``""`` .
            is_sub_net (bool): Indicate whether `cell` is a network. If `is_sub_net` is true, Rewrite will try to
                parse the `cell` to a TreeNode, otherwise the `cell` is parsed to a CallCell node.
                Default: ``False`` .

        Returns:
            An instance of `Node`.

        Raises:
            TypeError: If `cell` is neither a `Cell` nor a `Primitive`.
            TypeError: If `targets` is not `list`.
            TypeError: If the type of `targets` is not in `[ScopedValue, str]`.
            TypeError: If arg in `args` is not a `ScopedValue`.
            TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not a `ScopedValue`.
            TypeError: If `name` is not a str or `is_sub_net` is not a bool.

        Examples:
            >>> from mindspore.rewrite import SymbolTree, ScopedValue
            >>> import mindspore.nn as nn
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> position = stree.after(node)
            >>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
            ...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
            >>> stree.insert(position, new_node)
            >>> print(type(new_node))
            <class 'mindspore.rewrite.api.node.Node'>
        """
        Validator.check_value_type("cell", cell, [Cell, Primitive], "Node")
        Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "Node")
        Validator.check_value_type("name", name, [str], "Node")
        Validator.check_value_type("is_sub_net", is_sub_net, [bool], "Node")
        if args is not None:
            Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
        if kwargs is not None:
            Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
        return Node(NodeImpl.create_call_op(cell, None, targets, args, kwargs, name, is_sub_net))

    @staticmethod
    def create_call_function(function: FunctionType, targets: List[Union[ScopedValue, str]],
                             args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None) -> 'Node':
        """
        Create a node that corresponds to a function call.

        Note:
            The codes inside the function will not be parsed.

        Args:
            function (FunctionType): The function to be called. A Python function, a class or a built-in function
                is accepted.
            targets (List[Union[ScopedValue, str]]): Indicate output names. Used as targets of an assign statement in
                source code.
            args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
                source code. Default: ``None`` , which indicates the `function` has no args inputs.
            kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
                Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
                code. Default: ``None`` , which indicates the `function` has no kwargs inputs.

        Returns:
            An instance of `Node`.

        Raises:
            TypeError: If `function` is not a Python function, a class or a built-in function.
            TypeError: If `targets` is not `list`.
            TypeError: If the type of `targets` is not in `[ScopedValue, str]`.
            TypeError: If arg in `args` is not a `ScopedValue`.
            TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not a `ScopedValue`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree, ScopedValue
            >>> import mindspore.nn as nn
            >>> from mindspore import ops
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> position = stree.after(node)
            >>> new_node = node.create_call_function(function=ops.abs, targets=['x'],
            ...                                      args=[ScopedValue.create_naming_value('x')])
            >>> stree.insert(position, new_node)
            >>> print(new_node.get_node_type())
            NodeType.CallFunction
        """
        # `type` admits classes and `type(abs)` admits built-in functions, in addition to plain Python functions.
        Validator.check_value_type("function", function, [FunctionType, type, type(abs)], "create_call_function")
        Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "create_call_function")
        if args is not None:
            Validator.check_element_type_of_iterable("args", args, [ScopedValue], "create_call_function")
        if kwargs is not None:
            Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "create_call_function")
        return Node(NodeImpl._create_call_function(function, targets, args, kwargs))

    @staticmethod
    def create_input(param_name: str, default: Optional[ScopedValue] = None) -> 'Node':
        """
        Create an input node for the parameter `param_name`.

        The created node is named ``input_{param_name}``.

        Args:
            param_name (str): Name of the input parameter.
            default (ScopedValue, optional): Default value of the input parameter. Default: ``None`` .

        Returns:
            An instance of `Node`.

        Raises:
            TypeError: If `param_name` is not a str.
            TypeError: If `default` is given and is not a `ScopedValue`.
        """
        Validator.check_value_type("param_name", param_name, [str], "Node")
        if default is not None:
            Validator.check_value_type("default", default, [ScopedValue], "Node")
        return Node(NodeImpl.create_input_node(None, param_name, default, name=f"input_{param_name}"))

    def get_handler(self) -> NodeImpl:
        """
        Get the underlying `NodeImpl` handle wrapped by this node.

        Returns:
            The `NodeImpl` handle passed to the constructor.
        """
        return self._node

    def get_inputs(self) -> List['Node']:
        """
        Gets a list of nodes whose output values are used as input values for the current node.

        Returns:
            A list of nodes.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv2")
            >>> inputs = node.get_inputs()
            >>> print([input.get_name() for input in inputs])
            ['max_pool2d']
        """
        return [Node(node_impl) for node_impl in self._node.get_inputs()]

    def get_users(self) -> List['Node']:
        """
        Get a list of nodes that use the output of the current node as input.

        Returns:
            A list of nodes.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> users = node.get_users()
            >>> print([user.get_name() for user in users])
            ['relu']
        """
        return [Node(node_impl) for node_impl in self._node.get_users()]

    def set_arg(self, index: int, arg: Union[ScopedValue, str]):
        """
        Set argument of current node.

        Args:
            index (int): Indicate which input being modified.
            arg (Union[ScopedValue, str]): New argument to be set.

        Raises:
            TypeError: If `index` is not an `int` number.
            TypeError: If the type of `arg` is not in [`ScopedValue`, `str`].

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("relu_3")
            >>> node.set_arg(0, "fc1")
            >>> print(node.get_args())
            [fc1]
        """
        Validator.check_value_type("index", index, [int], "Node")
        Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
        belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
        if belong_symbol_tree is None:
            # Detached node: modify the handle directly (note the (arg, index) argument order of NodeImpl.set_arg).
            self._node.set_arg(arg, index)
        else:
            # Node inside a symbol tree: let the tree perform the update.
            belong_symbol_tree.set_node_arg(self._node, index, arg)

    def set_arg_by_node(self, arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None):
        """
        Set argument of current node by another Node.

        Args:
            arg_idx (int): Indicate which input being modified.
            src_node (Node): A `Node` as new input.
            out_idx (int, optional): Indicate which output of `src_node` as new input of current node.
                Default: ``None`` , which means the first output of `src_node` is used as new input
                (see Raises for the multi-output case).

        Raises:
            TypeError: If `arg_idx` is not an `int` number.
            ValueError: If `arg_idx` is out of range.
            TypeError: If `src_node` is not a `Node` instance.
            TypeError: If `out_idx` is not an `int` number.
            ValueError: If `out_idx` is out of range.
            ValueError: If `src_node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> src_node = stree.get_node("fc1")
            >>> dst_node = stree.get_node("relu_3")
            >>> dst_node.set_arg_by_node(0, src_node, 0)
            >>> print(dst_node.get_args())
            [fc1_var]
        """
        Validator.check_value_type("arg_idx", arg_idx, [int], "Node")
        Validator.check_value_type("src_node", src_node, [Node], "Node")
        if out_idx is not None:
            Validator.check_value_type("out_idx", out_idx, [int], "Node")
        src_handle = src_node.get_handler()
        belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
        if belong_symbol_tree is None:
            # Detached node: modify the handle directly.
            self._node.set_arg_by_node(arg_idx, src_handle, out_idx)
        else:
            # Node inside a symbol tree: let the tree perform the update.
            belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_handle, out_idx)

    def get_targets(self) -> List[ScopedValue]:
        """
        Gets a list of output values for the current node.

        Returns:
            A list of outputs of type ``ScopedValue`` .
        """
        return self._node.get_targets()

    def get_name(self) -> str:
        """
        Get the name of current node.

        When node has been inserted into `SymbolTree`, the name of node should be unique in `SymbolTree`.

        Returns:
            A string as name of node.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> name = node.get_name()
            >>> print(name)
            conv1
        """
        return self._node.get_name()

    def get_node_type(self) -> NodeType:
        """
        Get the node_type of current node. See :class:`mindspore.rewrite.NodeType` for details on node types.

        Returns:
            A NodeType as node_type of node.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> node_type = node.get_node_type()
            >>> print(node_type)
            NodeType.CallCell
        """
        return self._node.get_node_type()

    def get_instance_type(self) -> type:
        """
        Gets the instance type called in the code corresponding to the current node.

        - When `node_type` of current node is `CallCell`, the code for that node calls an instance of type ``Cell`` .
        - When `node_type` of current node is `CallPrimitive`, the code for that node calls an instance of
          type ``Primitive`` .
        - When `node_type` of current node is `Tree`, the code for that node calls an instance of network type.
        - When `node_type` of current node is `Python`, `Input`, `Output` or `CallMethod`, the instance type
          is ``NoneType`` .

        Returns:
            The type of instance called in the statement corresponding to the current node.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> instance_type = node.get_instance_type()
            >>> print(instance_type)
            <class 'mindspore.nn.layer.conv.Conv2d'>
        """
        return self._node.get_instance_type()

    def get_instance(self):
        """
        Get the instance stored in the current node, as returned by the underlying handle's ``get_instance``.

        Returns:
            The stored instance; its type is presumably the one reported by `get_instance_type` (TODO confirm).
        """
        return self._node.get_instance()

    def get_args(self) -> List[ScopedValue]:
        """
        Get arguments of current node.

        Returns:
            A list of arguments of type ``ScopedValue`` .

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> print(node.get_args())
            [x]
        """
        return self._node.get_args()

    def get_symbol_tree(self) -> Optional['SymbolTree']:
        """
        Get the symbol tree which current node belongs to.

        Returns:
            SymbolTree, None if current node does not belong to any SymbolTree.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> print(type(node.get_symbol_tree()))
            <class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
        """
        # Imported locally to avoid a circular import between the api node and symbol_tree modules.
        from .symbol_tree import SymbolTree
        stree_impl = self._node.get_belong_symbol_tree()
        if not stree_impl:
            return None
        return SymbolTree(stree_impl)

    def get_sub_tree(self) -> 'SymbolTree':
        """
        Get the sub symbol tree stored in node with type of `NodeType.Tree` .
        See :class:`mindspore.rewrite.NodeType` for details on node types.

        Returns:
            SymbolTree stored in Tree node.

        Raises:
            TypeError: If current node is not type of `NodeType.Tree` .
            AttributeError: If no symbol tree is stored in Tree node.

        Examples:
            >>> import mindspore.nn as nn
            >>> from mindspore.rewrite import SymbolTree
            >>>
            >>> class SubNet(nn.Cell):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.relu = nn.ReLU()
            ...
            ...     def construct(self, x):
            ...         x = self.relu(x)
            ...         return x
            ...
            >>> class Net(nn.Cell):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.subnet = SubNet()
            ...
            ...     def construct(self, x):
            ...         x = self.subnet(x)
            ...         return x
            >>>
            >>> net = Net()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("subnet")
            >>> print(type(node.get_sub_tree()))
            <class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
        """
        node_type = self.get_node_type()
        if node_type != NodeType.Tree:
            raise TypeError("For get_sub_tree, the type of node should be 'NodeType.Tree', "
                            f"but got {node_type}")
        subtree: SymbolTreeImpl = self.get_handler().symbol_tree
        if subtree is None:
            raise AttributeError(
                f"For get_sub_tree, no symbol tree is stored in node {self.get_name()}.")
        # Imported locally to avoid a circular import between the api node and symbol_tree modules.
        from .symbol_tree import SymbolTree
        return SymbolTree(subtree)

    def get_kwargs(self) -> Dict[str, ScopedValue]:
        """
        Get keyword arguments of current node.

        Returns:
            A dict of keyword arguments, where key is of type str, and value is of type ``ScopedValue`` .

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from mindspore import nn
            >>>
            >>> class ReLUNet(nn.Cell):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.relu = nn.ReLU()
            ...
            ...     def construct(self, input):
            ...         output = self.relu(x=input)
            ...         return output
            >>>
            >>> net = ReLUNet()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("relu")
            >>> print(node.get_kwargs())
            {'x': input}
        """
        return self._node.get_kwargs()

    def set_attribute(self, key: str, value):
        """
        Set the attribute `key` of the current node to `value`.

        Args:
            key (str): Name of the attribute.
            value: Value of the attribute.

        Raises:
            TypeError: If `key` is not a str.
        """
        Validator.check_value_type("key", key, [str], "Node attribute")
        self._node.set_attribute(key, value)

    def get_attributes(self) -> Dict[str, object]:
        """
        Get all attributes of the current node.

        Returns:
            A dict of attributes as returned by the underlying handle, keyed by attribute name.
        """
        return self._node.get_attributes()

    def get_attribute(self, key: str):
        """
        Get the attribute `key` of the current node.

        Args:
            key (str): Name of the attribute.

        Returns:
            The attribute value as returned by the underlying handle.

        Raises:
            TypeError: If `key` is not a str.
        """
        Validator.check_value_type("key", key, [str], "Node attribute")
        return self._node.get_attribute(key)

    def get_arg_providers(self) -> dict:
        """
        Get the providers of the current node's arguments.

        Returns:
            A dict mapping an argument index to a tuple ``(provider_node, index)``, where `provider_node` is a
            `Node` and `index` is passed through unchanged from the underlying handle (presumably the output
            index of `provider_node` that feeds the argument; TODO confirm).
        """
        return {arg_idx: (Node(provider[0]), provider[1])
                for arg_idx, provider in self._node.get_arg_providers().items()}

    def get_target_users(self, index=-1) -> Union[dict, list]:
        """
        Get the users of the current node's targets.

        Args:
            index (int): Index of the target to query. ``-1`` queries all targets. Default: ``-1`` .

        Returns:
            - When `index` is ``-1``, a dict mapping each target index to a list of ``(user_node, idx)`` tuples.
            - Otherwise, a list of ``(user_node, idx)`` tuples for the target at `index`.

            `user_node` is a `Node`; `idx` is passed through unchanged from the underlying handle (presumably the
            argument index at which `user_node` consumes the target; TODO confirm).

        Raises:
            TypeError: If `index` is not an int.
        """
        Validator.check_value_type("index", index, [int], "get_target_users")
        if index == -1:
            return {target_idx: [(Node(user[0]), user[1]) for user in users]
                    for target_idx, users in self._node.get_target_users().items()}
        return [(Node(user[0]), user[1]) for user in self._node.get_target_users(index)]