# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Symbol implementation."""

__all__ = ['Symbol']


class Symbol:
    r"""
    Symbol is a data structure to indicate the symbolic info of shape.

    For dynamic shape networks, compared with only setting the unknown dimensions ( ``None`` ) in `Tensor` , providing
    more symbolic shape info can help the framework better optimize the computation graph, to improve the performance of
    network execution.

    Args:
        max (int): The maximum length of this dimension, which is valid when it's greater than `min`. Default: ``0`` .
        min (int): The minimum length of this dimension. Default: ``1`` .
        divisor (int): The divisor( :math:`d` ). When `remainder` is 0, it means this dimension can be divided by
            :math:`d` . Default: ``1`` .
        remainder (int): The remainder( :math:`r` ) when symbol is represented by :math:`d * N + r, N \ge 1` .
            Default: ``0`` .
        unique (bool): When the symbol object is used multiple times, if `unique` is ``True`` , the shape items of this
            symbol are considered to be same length, otherwise only symbol info is shared by multiple dimensions.
            Default: ``False`` .

    Outputs:
        Symbol.

    Raises:
        TypeError: If `max`, `min`, `divisor` or `remainder` is not an int (``bool`` values are not accepted).
        TypeError: If `unique` is not a bool.
        ValueError: If `min` is not a positive value.
        ValueError: If `divisor` is not a positive value.
        ValueError: If `remainder` is not in the range :math:`[0, d)` .

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> from mindspore import nn, Tensor, Symbol
        >>>
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.abs = ms.ops.Abs()
        ...     def construct(self, x):
        ...         return self.abs(x)
        ...
        >>> net = Net()
        >>> s1 = Symbol(divisor=8, remainder=1)
        >>> s2 = Symbol(max=32, unique=True)
        >>> dyn_t = Tensor(shape=(None, s1, s1, s2, s2), dtype=ms.float32)
        >>> net.set_inputs(dyn_t)
        >>> # the shape values of last two dimensions must be equal, because "s2" is set to "unique"
        >>> net(Tensor(np.random.randn(1, 9, 17, 32, 32), dtype=ms.float32)).shape
        (1, 9, 17, 32, 32)
        >>> net(Tensor(np.random.randn(8, 25, 9, 30, 30), dtype=ms.float32)).shape
        (8, 25, 9, 30, 30)
    """

    def __init__(self, max=0, min=1, divisor=1, remainder=0, unique=False, **kwargs):
        # `max`/`min` shadow builtins, but they are the public keyword names of this API and must stay.
        # Extra keyword arguments are accepted and ignored, as before, so existing callers keep working.
        # NOTE(review): confirm whether any caller relies on this; otherwise consider rejecting them.
        Symbol._check_args_type(max, min, divisor, remainder, unique)
        if min <= 0:
            raise ValueError("For 'Symbol', the 'min' value should be positive, but got {}".format(min))
        if divisor <= 0:
            raise ValueError("For 'Symbol', the 'divisor' value should be positive, but got {}".format(divisor))
        # `divisor` is already known to be positive here, so [0, divisor) is a non-empty range.
        if remainder < 0 or remainder >= divisor:
            raise ValueError(
                "For 'Symbol', the 'remainder' value should be in the range '[0, {})', but got {}".format(
                    divisor, remainder))
        self.max = max
        self.min = min
        self.divisor = divisor
        self.remainder = remainder
        self.unique = unique
        # Identity of this object. `to_dict` exports it as "id" only when `unique` is True.
        # Python may reuse an `id()` value for objects with non-overlapping lifetimes, so it is only
        # meaningful while this Symbol is alive.
        self.id = id(self)

    def __str__(self):
        return str(self.to_dict())

    @staticmethod
    def _check_args_type(maxv, minv, divisor, remainder, unique):
        """Check the type of arguments.

        ``bool`` is a subclass of ``int`` in Python, so it is rejected explicitly for the integer arguments;
        otherwise e.g. ``Symbol(max=True)`` would be silently accepted as ``max=1``.

        Raises:
            TypeError: If an integer argument is not a (non-bool) int, or `unique` is not a bool.
        """
        int_args = (("max", maxv), ("min", minv), ("divisor", divisor), ("remainder", remainder))
        for arg_name, arg_value in int_args:
            if not isinstance(arg_value, int) or isinstance(arg_value, bool):
                raise TypeError(f"For 'Symbol', the argument '{arg_name}' must be int, but got {type(arg_value)}")
        if not isinstance(unique, bool):
            raise TypeError(f"For 'Symbol', the argument 'unique' must be bool, but got {type(unique)}")

    # pylint: disable=missing-docstring
    def to_dict(self):
        # Convert the symbolic info to dictionary.
        # This method is not necessary to show in public api document, use comment instead of docstring.
        # Only informative fields are emitted:
        #   - "max" only when it is greater than "min" (otherwise it is not a valid upper bound);
        #   - "min" only when it is tighter than the implicit lower bound d + r of "d * N + r, N >= 1";
        #   - "divisor" / "remainder" only when they differ from the defaults 1 / 0;
        #   - "id" only for unique symbols.
        res = {}
        if self.max > self.min:
            res["max"] = self.max
        if self.min > self.divisor + self.remainder:  # the symbol is "d * N + r" and N >= 1
            res["min"] = self.min
        if self.divisor != 1:
            res["divisor"] = self.divisor
        if self.remainder != 0:
            res["remainder"] = self.remainder
        if self.unique:
            res["id"] = self.id
        return res