• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Rewrite module api: Node."""
16
17from typing import Union, Optional, List, Dict
18from types import FunctionType
19
20from mindspore.nn import Cell
21from mindspore.ops.primitive import Primitive
22from mindspore import _checkparam as Validator
23from ..node.node import Node as NodeImpl
24from ..symbol_tree import SymbolTree as SymbolTreeImpl
25from .node_type import NodeType
26from .scoped_value import ScopedValue
27
28
class Node:
    """
    A node is a data structure that expresses source code statements in a network.

    Each node usually corresponds to a statement in expanded forward evaluation process.

    Nodes can express a ``Cell`` call statement, a ``Primitive`` call statement, an arithmetic operation statement, a
    return statement, etc. of the forward calculation process.

    Args:
        node (NodeImpl): A handler of `NodeImpl`. It is recommended to call the specific methods in Node to create
            a Node, such as 'create_call_cell', rather than calling the Node's constructor directly.
            Don't care what `NodeImpl` is, just treat it as a handle.
    """

    def __init__(self, node: NodeImpl):
        self._node = node

    def __eq__(self, other: object) -> bool:
        """Two `Node` wrappers are equal when the handles they wrap compare equal."""
        if not isinstance(other, Node):
            return False
        return self._node == other._node

    def __hash__(self) -> int:
        """
        Hash by the wrapped handle so that hashing stays consistent with `__eq__`.

        Defining `__eq__` alone would make instances unhashable; this keeps `Node` usable in sets and as
        dict keys, and two wrappers around the same handle hash identically.
        """
        return hash(self._node)

    @staticmethod
    def create_call_cell(cell: Union[Cell, Primitive], targets: List[Union[ScopedValue, str]],
                         args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None, name: str = "",
                         is_sub_net: bool = False) -> 'Node':
        """
        Create a node from a `Cell` (a `Primitive` instance is accepted as well).

        A node is corresponding to source code like:

        ``targets = self.name(*args, **kwargs)``

        Args:
            cell (Union[Cell, Primitive]): Cell-operator of this forward-layer.
            targets (List[Union[ScopedValue, str]]): Indicate output names. Used as targets of an assign statement in
                source code.
            args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
                source code. Default: ``None`` , which indicates the `cell` has no args inputs.
            kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
                Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
                code. Default: ``None`` , which indicates the `cell` has no kwargs inputs.
            name (str): Indicate the name of node. Used as field name in source code. Rewrite will
                generate name from `cell` when `name` is not specified. Rewrite will check and ensure the
                uniqueness of `name` while node being inserted. Default: ``""`` .
            is_sub_net (bool): Indicate that is `cell` a network. If `is_sub_net` is true, Rewrite will try to parse
                the `cell` to a TreeNode, otherwise the `cell` is parsed to a CallCell node. Default: ``False`` .

        Returns:
            An instance of `Node`.

        Raises:
            TypeError: If `cell` is not a `Cell` or a `Primitive`.
            TypeError: If `targets` is not `list`.
            TypeError: If the type of `targets` is not in `[ScopedValue, str]`.
            TypeError: If `name` is not a `str`.
            TypeError: If `is_sub_net` is not a `bool`.
            TypeError: If arg in `args` is not a `ScopedValue`.
            TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not a `ScopedValue`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree, ScopedValue
            >>> import mindspore.nn as nn
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> position = stree.after(node)
            >>> new_node = node.create_call_cell(cell=nn.ReLU(), targets=['x'],
            ...                                  args=[ScopedValue.create_naming_value('x')], name='new_relu')
            >>> stree.insert(position, new_node)
            >>> print(type(new_node))
            <class 'mindspore.rewrite.api.node.Node'>
        """
        Validator.check_value_type("cell", cell, [Cell, Primitive], "Node")
        Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "Node")
        Validator.check_value_type("name", name, [str], "Node")
        Validator.check_value_type("is_sub_net", is_sub_net, [bool], "Node")
        if args is not None:
            Validator.check_element_type_of_iterable("args", args, [ScopedValue], "Node")
        if kwargs is not None:
            Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "Node")
        return Node(NodeImpl.create_call_op(cell, None, targets, args, kwargs, name, is_sub_net))

    @staticmethod
    def create_call_function(function: FunctionType, targets: List[Union[ScopedValue, str]],
                             args: List[ScopedValue] = None, kwargs: Dict[str, ScopedValue] = None) -> 'Node':
        """
        Create a node that corresponds to a function call.

        Note:
            The codes inside the function will not be parsed.

        Args:
            function (FunctionType): The function to be called. Classes (instances of `type`) and builtin
                functions also pass the type check.
            targets (List[Union[ScopedValue, str]]): indicates output names. Used as targets of an assign statement in
                source code.
            args (List[ScopedValue]): Indicate input names. Used as args of a call expression of an assign statement in
                source code. Default: ``None`` , which indicates the `function` has no args inputs.
            kwargs (Dict[str, ScopedValue]): Type of key must be `str` and type of value must be `ScopedValue`.
                Indicate keyword input names. Used as kwargs of a call expression of an assign statement in source
                code. Default: ``None`` , which indicates the `function` has no kwargs inputs.

        Returns:
            An instance of `Node`.

        Raises:
            TypeError: If `function` is not a `FunctionType`, a `type` or a builtin function.
            TypeError: If `targets` is not `list`.
            TypeError: If the type of `targets` is not in `[ScopedValue, str]`.
            TypeError: If arg in `args` is not a `ScopedValue`.
            TypeError: If key of `kwarg` is not a str or value of kwarg in `kwargs` is not a `ScopedValue`.

        Examples:
            >>> from mindspore.rewrite import SymbolTree, ScopedValue
            >>> import mindspore.nn as nn
            >>> from mindspore import ops
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> position = stree.after(node)
            >>> new_node = node.create_call_function(function=ops.abs, targets=['x'],
            ...                                      args=[ScopedValue.create_naming_value('x')])
            >>> stree.insert(position, new_node)
            >>> print(new_node.get_node_type())
            NodeType.CallFunction
        """
        # `type(abs)` is the builtin-function type, so builtins such as `abs` are accepted too.
        Validator.check_value_type("function", function, [FunctionType, type, type(abs)], "create_call_function")
        Validator.check_element_type_of_iterable("targets", targets, [ScopedValue, str], "create_call_function")
        if args is not None:
            Validator.check_element_type_of_iterable("args", args, [ScopedValue], "create_call_function")
        if kwargs is not None:
            Validator.check_element_type_of_dict("kwargs", kwargs, [str], [ScopedValue], "create_call_function")
        return Node(NodeImpl._create_call_function(function, targets, args, kwargs))

    @staticmethod
    def create_input(param_name: str, default: Optional[ScopedValue] = None) -> 'Node':
        """
        Create an input node for the parameter `param_name`.

        The node is created through `NodeImpl.create_input_node` and is named ``input_<param_name>``.

        Args:
            param_name (str): Name of the input parameter.
            default (ScopedValue, optional): Default value forwarded to `NodeImpl.create_input_node`.
                Default: ``None`` .

        Returns:
            An instance of `Node`.

        Raises:
            TypeError: If `param_name` is not a `str`.
            TypeError: If `default` is neither ``None`` nor a `ScopedValue`.
        """
        Validator.check_value_type("param_name", param_name, [str], "Node")
        if default is not None:
            Validator.check_value_type("default", default, [ScopedValue], "Node")
        return Node(NodeImpl.create_input_node(None, param_name, default, name=f"input_{param_name}"))

    def get_handler(self) -> NodeImpl:
        """Return the `NodeImpl` handle wrapped by this node."""
        return self._node

    def get_inputs(self) -> List['Node']:
        """
        Gets a list of nodes whose output values are used as input values for the current node.

        Returns:
            A list of nodes.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv2")
            >>> inputs = node.get_inputs()
            >>> print([input.get_name() for input in inputs])
            ['max_pool2d']
        """
        return [Node(node_impl) for node_impl in self._node.get_inputs()]

    def get_users(self) -> List['Node']:
        """
        Get a list of nodes that use the output of the current node as input.

        Returns:
            A list of nodes.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> users = node.get_users()
            >>> print([user.get_name() for user in users])
            ['relu']
        """
        return [Node(node_impl) for node_impl in self._node.get_users()]

    def set_arg(self, index: int, arg: Union[ScopedValue, str]):
        """
        Set argument of current node.

        Args:
            index (int): Indicate which input being modified.
            arg (Union[ScopedValue, str]): New argument to been set.

        Raises:
            TypeError: If `index` is not a `int` number.
            TypeError: If the type of `arg` is not in [`ScopedValue`, `str`].

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("relu_3")
            >>> node.set_arg(0, "fc1")
            >>> print(node.get_args())
            [fc1]
        """
        Validator.check_value_type("index", index, [int], "Node")
        Validator.check_value_type("arg", arg, [ScopedValue, str], "Node")
        belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
        if belong_symbol_tree is None:
            # Node is not attached to any tree yet: update the handle directly.
            # NOTE(review): `NodeImpl.set_arg` is called as (arg, index) here - confirm against NodeImpl.
            self._node.set_arg(arg, index)
        else:
            # Node belongs to a tree: let the tree perform the update.
            belong_symbol_tree.set_node_arg(self._node, index, arg)

    def set_arg_by_node(self, arg_idx: int, src_node: 'Node', out_idx: Optional[int] = None):
        """
        Set argument of current node by another Node.

        Args:
            arg_idx (int): Indicate which input being modified.
            src_node (Node): A `Node` as new input.
            out_idx (int, optional): Indicate which output of `src_node` as new input of current node.
                Default: ``None`` ,
                which means use first output of `src_node` as new input.

        Raises:
            TypeError: If `arg_idx` is not a `int` number.
            ValueError: If `arg_idx` is out of range.
            TypeError: If `src_node` is not a `Node` instance.
            TypeError: If `out_idx` is not a `int` number.
            ValueError: If `out_idx` is out of range.
            ValueError: If `src_node` has multi-outputs while `out_idx` is None or `out_idx` is not offered.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> src_node = stree.get_node("fc1")
            >>> dst_node = stree.get_node("relu_3")
            >>> dst_node.set_arg_by_node(0, src_node, 0)
            >>> print(dst_node.get_args())
            [fc1_var]
        """
        Validator.check_value_type("arg_idx", arg_idx, [int], "Node")
        Validator.check_value_type("src_node", src_node, [Node], "Node")
        if out_idx is not None:
            Validator.check_value_type("out_idx", out_idx, [int], "Node")
        src_handler = src_node.get_handler()
        belong_symbol_tree: SymbolTreeImpl = self._node.get_belong_symbol_tree()
        if belong_symbol_tree is None:
            # Node is not attached to any tree yet: update the handle directly.
            self._node.set_arg_by_node(arg_idx, src_handler, out_idx)
        else:
            # Node belongs to a tree: let the tree perform the update.
            belong_symbol_tree.set_node_arg_by_node(self._node, arg_idx, src_handler, out_idx)

    def get_targets(self) -> List[ScopedValue]:
        """
        Gets a list of output values for the current node.

        Returns:
            A list of outputs of type ``ScopedValue`` .
        """
        return self._node.get_targets()

    def get_name(self) -> str:
        """
        Get the name of current node.

        When node has been inserted into `SymbolTree`, the name of node should be unique in `SymbolTree`.

        Returns:
            A string as name of node.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> name = node.get_name()
            >>> print(name)
            conv1
        """
        return self._node.get_name()

    def get_node_type(self) -> NodeType:
        """
        Get the node_type of current node. See :class:`mindspore.rewrite.NodeType` for details on node types.

        Returns:
            A NodeType as node_type of node.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> node_type = node.get_node_type()
            >>> print(node_type)
            NodeType.CallCell
        """
        return self._node.get_node_type()

    def get_instance_type(self) -> type:
        """
        Gets the instance type called in the code corresponding to the current node.

        - When `node_type` of current node is `CallCell`, the code for that node calls an instance of type ``Cell`` .
        - When `node_type` of current node is `CallPrimitive`, the code for that node calls an instance of
          type ``Primitive`` .
        - When `node_type` of current node is `Tree`, the code for that node calls an instance of network type.
        - When `node_type` of current node is `Python`, `Input`, `Output` or `CallMethod`, the instance type
          is ``NoneType`` .

        Returns:
            The type of instance called in the statement corresponding to the current node.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> instance_type = node.get_instance_type()
            >>> print(instance_type)
            <class 'mindspore.nn.layer.conv.Conv2d'>
        """
        return self._node.get_instance_type()

    def get_instance(self):
        """Return whatever `NodeImpl.get_instance` reports for the wrapped handle."""
        return self._node.get_instance()

    def get_args(self) -> List[ScopedValue]:
        """
        Get arguments of current node.

        Returns:
            A list of arguments of type ``ScopedValue`` .

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> print(node.get_args())
            [x]
        """
        return self._node.get_args()

    def get_symbol_tree(self) -> Optional['SymbolTree']:
        """
        Get the symbol tree which current node belongs to.

        Returns:
            SymbolTree, None if current node does not belong to any SymbolTree.

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> # Define the network structure of LeNet5. Refer to
            >>> # https://gitee.com/mindspore/docs/blob/master/docs/mindspore/code/lenet.py
            >>> net = LeNet5()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("conv1")
            >>> print(type(node.get_symbol_tree()))
            <class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
        """
        # Imported at call time rather than at module level (presumably to avoid a circular import).
        from .symbol_tree import SymbolTree
        stree_impl = self._node.get_belong_symbol_tree()
        if not stree_impl:
            return None
        return SymbolTree(stree_impl)

    def get_sub_tree(self) -> 'SymbolTree':
        """
        Get the sub symbol tree stored in node with type of `NodeType.Tree` .
        See :class:`mindspore.rewrite.NodeType` for details on node types.

        Returns:
            SymbolTree stored in Tree node.

        Raises:
            TypeError: If current node is not type of `NodeType.Tree` .
            AttributeError: If no symbol tree is stored in Tree node.

        Examples:
            >>> import mindspore.nn as nn
            >>> from mindspore.rewrite import SymbolTree
            >>>
            >>> class SubNet(nn.Cell):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.relu = nn.ReLU()
            ...
            ...     def construct(self, x):
            ...         x = self.relu(x)
            ...         return x
            ...
            >>> class Net(nn.Cell):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.subnet = SubNet()
            ...
            ...     def construct(self, x):
            ...         x = self.subnet(x)
            ...         return x
            >>>
            >>> net = Net()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("subnet")
            >>> print(type(node.get_sub_tree()))
            <class 'mindspore.rewrite.api.symbol_tree.SymbolTree'>
        """
        node_type = self.get_node_type()
        if node_type != NodeType.Tree:
            raise TypeError("For get_sub_tree, the type of node should be 'NodeType.Tree', "
                            f"but got {node_type}")
        subtree: SymbolTreeImpl = self.get_handler().symbol_tree
        if subtree is None:
            raise AttributeError(
                f"For get_sub_tree, no symbol tree is stored in node {self.get_name()}.")
        # Imported at call time rather than at module level (presumably to avoid a circular import).
        from .symbol_tree import SymbolTree
        return SymbolTree(subtree)

    def get_kwargs(self) -> Dict[str, ScopedValue]:
        """
        Get keyword arguments of current node.

        Returns:
            A dict of keyword arguments, where key is of type str, and value is of type ``ScopedValue`` .

        Examples:
            >>> from mindspore.rewrite import SymbolTree
            >>> from mindspore import nn
            >>>
            >>> class ReLUNet(nn.Cell):
            ...     def __init__(self):
            ...         super().__init__()
            ...         self.relu = nn.ReLU()
            ...
            ...     def construct(self, input):
            ...         output = self.relu(x=input)
            ...         return output
            >>>
            >>> net = ReLUNet()
            >>> stree = SymbolTree.create(net)
            >>> node = stree.get_node("relu")
            >>> print(node.get_kwargs())
            {'x': input}
        """
        return self._node.get_kwargs()

    def set_attribute(self, key: str, value):
        """
        Set an attribute on the wrapped handle.

        Args:
            key (str): Name of the attribute.
            value: Value forwarded unchanged to `NodeImpl.set_attribute`.

        Raises:
            TypeError: If `key` is not a `str`.
        """
        Validator.check_value_type("key", key, [str], "Node attribute")
        self._node.set_attribute(key, value)

    def get_attributes(self) -> Dict[str, object]:
        """Return the attributes reported by `NodeImpl.get_attributes` for the wrapped handle."""
        return self._node.get_attributes()

    def get_attribute(self, key: str):
        """
        Get one attribute of the wrapped handle.

        Args:
            key (str): Name of the attribute.

        Returns:
            Whatever `NodeImpl.get_attribute` returns for `key`.

        Raises:
            TypeError: If `key` is not a `str`.
        """
        Validator.check_value_type("key", key, [str], "Node attribute")
        return self._node.get_attribute(key)

    def get_arg_providers(self) -> dict:
        """
        Get the providers of the arguments of current node.

        Returns:
            A dict keyed like the dict returned by `NodeImpl.get_arg_providers`. Each value is a tuple whose
            first element is the provider wrapped as a `Node` and whose second element is passed through
            unchanged (presumably the index of the provider's output - confirm against `NodeImpl`).
        """
        return {
            arg_idx: (Node(provider[0]), provider[1])
            for arg_idx, provider in self._node.get_arg_providers().items()
        }

    def get_target_users(self, index: int = -1) -> Union[dict, list]:
        """
        Get the users of the targets of current node.

        Args:
            index (int): Which target to query. Default: ``-1`` , which means query all targets.

        Returns:
            When `index` is ``-1`` , a dict keyed like the dict returned by `NodeImpl.get_target_users`, whose
            values are lists of tuples. Otherwise a list of tuples for the target selected by `index`.
            In both cases the first tuple element is the user wrapped as a `Node` and the second element is
            passed through unchanged (presumably the index of the user's argument - confirm against
            `NodeImpl`).

        Raises:
            TypeError: If `index` is not a `int` number.
        """
        Validator.check_value_type("index", index, [int], "get_target_users")
        if index == -1:
            return {
                target_idx: [(Node(user[0]), user[1]) for user in users]
                for target_idx, users in self._node.get_target_users().items()
            }
        return [(Node(user[0]), user[1]) for user in self._node.get_target_users(index)]
520