• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
"""Symbol implementation.

This module defines :class:`Symbol`, the data structure that carries the symbolic
information (range, divisor/remainder and uniqueness) of a dynamic shape dimension.
"""

__all__ = ["Symbol"]
18
19
class Symbol:
    r"""
    Symbol is a data structure to indicate the symbolic info of shape.

    For dynamic shape networks, compared with only setting the unknown dimensions ( ``None`` ) in `Tensor` , providing
    more symbolic shape info can help the framework better optimize the computation graph, to improve the performance of
    network execution.

    Args:
        max (int): The maximum length of this dimension, which is valid when it's greater than `min`. Default: ``0`` .
        min (int): The minimum length of this dimension. Default: ``1`` .
        divisor (int): The divisor( :math:`d` ). When `remainder` is 0, it means this dimension can be divided by
            :math:`d` . Default: ``1`` .
        remainder (int): The remainder( :math:`r` ) when symbol is represented by :math:`d * N + r, N \ge 1` .
            Default: ``0`` .
        unique (bool): When the symbol object is used multiple times, if `unique` is ``True`` , the shape items of this
            symbol are considered to be the same length, otherwise only symbol info is shared by multiple dimensions.
            Default: ``False`` .

    Outputs:
        Symbol.

    Raises:
        TypeError: If `max`, `min`, `divisor` or `remainder` is not an int (a bool is not accepted as an int).
        TypeError: If `unique` is not a bool.
        ValueError: If `min` is not a positive value.
        ValueError: If `divisor` is not a positive value.
        ValueError: If `remainder` is not in the range :math:`[0, d)` .

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn, Tensor, Symbol
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.abs = ms.ops.Abs()
        ...     def construct(self, x):
        ...         return self.abs(x)
        ...
        >>> net = Net()
        >>> s1 = Symbol(divisor=8, remainder=1)
        >>> s2 = Symbol(max=32, unique=True)
        >>> dyn_t = Tensor(shape=(None, s1, s1, s2, s2), dtype=ms.float32)
        >>> net.set_inputs(dyn_t)
        >>> # the shape values of last two dimensions must be equal, because "s2" is set to "unique"
        >>> net(Tensor(np.random.randn(1, 9, 17, 32, 32), dtype=ms.float32)).shape
        (1, 9, 17, 32, 32)
        >>> net(Tensor(np.random.randn(8, 25, 9, 30, 30), dtype=ms.float32)).shape
        (8, 25, 9, 30, 30)
    """

    def __init__(self, max=0, min=1, divisor=1, remainder=0, unique=False, **kwargs):
        # Extra keyword arguments are accepted but never read by this class.
        # Types are validated first so that the range comparisons below only ever see real ints.
        Symbol._check_args_type(max, min, divisor, remainder, unique)
        if min <= 0:
            raise ValueError("For 'Symbol', the 'min' value should be positive, but got {}".format(min))
        if divisor <= 0:
            raise ValueError("For 'Symbol', the 'divisor' value should be positive, but got {}".format(divisor))
        if remainder < 0 or remainder >= divisor:
            raise ValueError(
                "For 'Symbol', the 'remainder' value should be in the range '[0, {})', but got {}".format(
                    divisor, remainder))
        self.max = max
        self.min = min
        self.divisor = divisor
        self.remainder = remainder
        self.unique = unique
        # Identity of this object; only exported by `to_dict` when `unique` is True.
        self.id = id(self)

    def __str__(self):
        # The string form is the string of the dictionary built by `to_dict`.
        return str(self.to_dict())

    @staticmethod
    def _check_args_type(maxv, minv, divisor, remainder, unique):
        """
        Check the type of arguments.

        Args:
            maxv: Value passed as `max`, must be an int.
            minv: Value passed as `min`, must be an int.
            divisor: Value passed as `divisor`, must be an int.
            remainder: Value passed as `remainder`, must be an int.
            unique: Value passed as `unique`, must be a bool.

        Raises:
            TypeError: If any argument has the wrong type. A bool is rejected where an int is required.
        """
        int_args = (("max", maxv), ("min", minv), ("divisor", divisor), ("remainder", remainder))
        for name, value in int_args:
            # bool is a subclass of int, so a plain isinstance(value, int) would let True/False through.
            if isinstance(value, bool) or not isinstance(value, int):
                raise TypeError(f"For 'Symbol', the argument '{name}' must be int, but got {type(value)}")
        if not isinstance(unique, bool):
            raise TypeError(f"For 'Symbol', the argument 'unique' must be bool, but got {type(unique)}")

    # pylint: disable=missing-docstring
    def to_dict(self):
        # Convert the symbolic info to dictionary.
        # This method is not necessary to show in public api document, use comment instead of docstring.
        # Only the items that differ from their implicit value are written to the result.
        res = {}
        if self.max > self.min:  # "max" is only valid when it is greater than "min"
            res["max"] = self.max
        if self.min > self.divisor + self.remainder:  # the symbol is "d * N + r" and N >= 1
            res["min"] = self.min
        if self.divisor != 1:
            res["divisor"] = self.divisor
        if self.remainder != 0:
            res["remainder"] = self.remainder
        if self.unique:
            res["id"] = self.id
        return res
123