• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020-2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Provide random seed api."""
16import numpy as np
17from mindspore._checkparam import Validator
18
# Seed constants.
# Substituted for the global seed when the configured global seed equals 0.
DEFAULT_GRAPH_SEED = 87654321
# Modulus applied by _truncate_seed (2**31 - 1).
_MAXINT32 = (1 << 31) - 1
# keyConstant[0] ^ keyConstant[2] is the amount _update_seeds adds to a stored op seed.
keyConstant = [3528531795, 2654435769, 3449720151, 3144134277]

# Module-level RNG state.
# Global seed; stays None until set_seed() assigns it.
_GLOBAL_SEED = None
# Maps (kernel_name, op_seed) to the current seed value tracked for that pair.
_KERNEL_SEED = {}
27
28
def _reset_op_seed():
    """
    Restore every entry of the kernel seed dictionary to its initial op seed.

    Each key of _KERNEL_SEED is a (kernel_name, op_seed) pair; the stored value is set back to the
    op_seed component of its key.
    """
    # Only existing keys are reassigned, so the dictionary size does not change while iterating.
    for entry in _KERNEL_SEED:
        _, initial_seed = entry
        _KERNEL_SEED[entry] = initial_seed
35
36
def set_seed(seed):
    """
    Set the global seed.

    Note:
        The global seed is used by numpy.random, mindspore.common.Initializer,
        mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution.

        If the global seed is not set, each of these packages uses its own default seed independently:
        numpy.random and mindspore.common.Initializer choose a random seed, while
        mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution use zero.

        A seed set by numpy.random.seed() is only used by numpy.random, whereas a seed set by this API is
        also applied to numpy.random, so setting all seeds through this API is recommended.

    Args:
        seed (int): The seed to be set.

    Raises:
        ValueError: If seed is invalid (< 0).
        TypeError: If seed is not an int.

    Examples:
        >>> import numpy as np
        >>> import mindspore as ms
        >>> import mindspore.ops as ops
        >>> from mindspore import Tensor, set_seed, Parameter
        >>> from mindspore.common.initializer import initializer
        >>>
        >>> # Note: (1) Please make sure the code is running in PYNATIVE MODE;
        >>> # (2) Composite-level ops need their parameters to be Tensors, so in the examples below
        >>> # the minval and maxval passed to ops.uniform are initialised as:
        >>> minval = Tensor(1.0, ms.float32)
        >>> maxval = Tensor(2.0, ms.float32)
        >>>
        >>> # 1. If the global seed is not set, numpy.random and initializer choose a random seed:
        >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
        >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A2
        >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W1
        >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W2
        >>> # Rerunning the program gives different results:
        >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A3
        >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A4
        >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W3
        >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W4
        >>>
        >>> # 2. If the global seed is set, numpy.random and initializer use it:
        >>> set_seed(1234)
        >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
        >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A2
        >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W1
        >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W2
        >>> # Rerunning the program gives the same results:
        >>> set_seed(1234)
        >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
        >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A2
        >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W1
        >>> w1 = Parameter(initializer("uniform", [2, 2], ms.float32), name="w1") # W2
        >>>
        >>> # 3. If neither the global seed nor the op seed is set, mindspore.ops.composite.random_ops and
        >>> # mindspore.nn.probability.distribution choose a random seed:
        >>> c1 = ops.uniform((1, 4), minval, maxval) # C1
        >>> c2 = ops.uniform((1, 4), minval, maxval) # C2
        >>> # Rerunning the program gives different results:
        >>> c1 = ops.uniform((1, 4), minval, maxval) # C3
        >>> c2 = ops.uniform((1, 4), minval, maxval) # C4
        >>>
        >>> # 4. If the global seed is set but the op seed is not, mindspore.ops.composite.random_ops and
        >>> # mindspore.nn.probability.distribution calculate a seed from the global seed and the
        >>> # default op seed. Each call changes the default op seed, so each call gives different
        >>> # results.
        >>> set_seed(1234)
        >>> c1 = ops.uniform((1, 4), minval, maxval) # C1
        >>> c2 = ops.uniform((1, 4), minval, maxval) # C2
        >>> # Rerunning the program gives the same results:
        >>> set_seed(1234)
        >>> c1 = ops.uniform((1, 4), minval, maxval) # C1
        >>> c2 = ops.uniform((1, 4), minval, maxval) # C2
        >>>
        >>> # 5. If both the global seed and the op seed are set, mindspore.ops.composite.random_ops and
        >>> # mindspore.nn.probability.distribution calculate a seed from the global seed and the
        >>> # op seed counter. Each call changes the op seed counter, so each call gives different
        >>> # results.
        >>> set_seed(1234)
        >>> c1 = ops.uniform((1, 4), minval, maxval, seed=2) # C1
        >>> c2 = ops.uniform((1, 4), minval, maxval, seed=2) # C2
        >>> # Rerunning the program gives the same results:
        >>> set_seed(1234)
        >>> c1 = ops.uniform((1, 4), minval, maxval, seed=2) # C1
        >>> c2 = ops.uniform((1, 4), minval, maxval, seed=2) # C2
        >>>
        >>> # 6. If the op seed is set but the global seed is not, 0 is used as the global seed. Then
        >>> # mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution behave as in
        >>> # condition 5.
        >>> c1 = ops.uniform((1, 4), minval, maxval, seed=2) # C1
        >>> c2 = ops.uniform((1, 4), minval, maxval, seed=2) # C2
        >>> # Rerunning the program gives the same results:
        >>> c1 = ops.uniform((1, 4), minval, maxval, seed=2) # C1
        >>> c2 = ops.uniform((1, 4), minval, maxval, seed=2) # C2
        >>>
        >>> # 7. Calling set_seed() again in the program resets the numpy seed and the op seed counter of
        >>> # mindspore.ops.composite.random_ops and mindspore.nn.probability.distribution.
        >>> set_seed(1234)
        >>> np_1 = np.random.normal(0, 1, [1]).astype(np.float32) # A1
        >>> c1 = ops.uniform((1, 4), minval, maxval, seed=2) # C1
        >>> set_seed(1234)
        >>> np_2 = np.random.normal(0, 1, [1]).astype(np.float32) # still get A1
        >>> c2 = ops.uniform((1, 4), minval, maxval, seed=2) # still get C1
    """
    global _GLOBAL_SEED
    if not isinstance(seed, int):
        raise TypeError("The argument 'seed' must be type of int, but got {}.".format(type(seed)))
    Validator.check_non_negative_int(seed, "seed", "global_seed")
    # The dataset package is imported only after validation succeeds, inside the function body.
    # NOTE(review): presumably a function-scope import to avoid a circular import -- confirm.
    import mindspore.dataset as dataset_module
    # Propagate the seed to numpy first, then to the dataset configuration.
    np.random.seed(seed)
    dataset_module.config.set_seed(seed)
    # Put every tracked op seed back to its initial value before recording the new global seed.
    _reset_op_seed()
    _GLOBAL_SEED = seed
154
155
def get_seed():
    """
    Return the current global seed.

    Returns:
        The value stored in the module-level _GLOBAL_SEED: the integer passed to the latest
        set_seed() call, or None if no global seed has been set.
    """
    return _GLOBAL_SEED
164
165
def _truncate_seed(seed):
    """
    Reduce a seed modulo MAXINT32 (2**31 - 1).

    Args:
        seed (int): The seed to be truncated.

    Returns:
        Integer. The remainder of `seed` divided by _MAXINT32.
    """
    remainder = seed % _MAXINT32
    return remainder
177
178
def _update_seeds(op_seed, kernel_name):
    """
    Advance the stored seed of a (kernel_name, op_seed) pair after it has been used.

    Nothing happens when `op_seed` is None. Otherwise the entry for the pair is increased by
    keyConstant[0] ^ keyConstant[2]; the entry must already exist in _KERNEL_SEED, otherwise the
    lookup raises KeyError.

    Args:
        op_seed (int): The op seed to be updated.
        kernel_name (string): The random op kernel.
    """
    if op_seed is None:
        return
    # Fixed step added on every update so consecutive calls with the same op seed see new values.
    step = keyConstant[0] ^ keyConstant[2]
    _KERNEL_SEED[(kernel_name, op_seed)] += step
190
191
def _get_op_seed(op_seed, kernel_name):
    """
    Look up the current seed tracked for a specific kernel and op seed.

    If the (kernel_name, op_seed) pair is not yet present in the kernel's dictionary, it is added
    with `op_seed` as its initial value.

    Args:
        op_seed (int): The op seed used as part of the lookup key and as the initial value.
        kernel_name (string): The random op kernel.

    Returns:
        The value currently stored in _KERNEL_SEED for (kernel_name, op_seed).
    """
    # setdefault inserts op_seed only when the pair is missing and returns the stored value.
    return _KERNEL_SEED.setdefault((kernel_name, op_seed), op_seed)
204
205
def _get_global_and_op_seed():
    """
    Get global_seed and op_seed for the "init" kernel.

    Both values start from get_seed(). A global seed of 0 is replaced by DEFAULT_GRAPH_SEED and an
    unset (None) global seed by 0; an unset op seed becomes 0. The op seed is validated as a
    non-negative int and resolved through _get_op_seed under the kernel name "init".

    Returns:
        Tuple of two values, (global_seed, op_seed), each passed through _truncate_seed.
    """
    configured_seed = get_seed()
    # The op seed mirrors the configured global seed, falling back to 0 when none is set.
    op_seed = 0 if configured_seed is None else configured_seed
    if configured_seed is None:
        global_seed = 0
    elif configured_seed == 0:
        global_seed = DEFAULT_GRAPH_SEED
    else:
        global_seed = configured_seed
    Validator.check_non_negative_int(op_seed, "seed", "init")
    current_op_seed = _get_op_seed(op_seed, "init")
    return _truncate_seed(global_seed), _truncate_seed(current_op_seed)
220
221
def _get_graph_seed(op_seed, kernel_name):
    """
    Get the graph-level seed pair for a random op.

    The graph-level (global) seed is a module-wide value that ops fall back on when no op-level seed
    is given. A global seed of 0 is replaced by DEFAULT_GRAPH_SEED, an unset global seed is treated
    as 0, and an op seed of None is treated as 0. If both end up as 0, (0, 0) is returned to let the
    kernel choose a random seed.

    Note:
        For each seed, either op-seed or graph-seed, a random sequence is generated relating to this seed.
        So, the state of the seed regarding this op is recorded and advanced after every call.
        A simple illustration:
          If a random op is called twice within one program, the two results should be different:
          minval = Tensor(1.0, mstype.float32)
          maxval = Tensor(2.0, mstype.float32)
          print(C.uniform((1, 4), minval, maxval, seed=1))  # generates 'A1'
          print(C.uniform((1, 4), minval, maxval, seed=1))  # generates 'A2'
          If the same program runs again, it repeats the results:
          print(C.uniform((1, 4), minval, maxval, seed=1))  # generates 'A1'
          print(C.uniform((1, 4), minval, maxval, seed=1))  # generates 'A2'

    Args:
        op_seed (int): The op-level seed, or None when the op did not specify one.
        kernel_name (string): The random op kernel.

    Returns:
        Tuple of two integers, (global_seed, op_seed). Either (0, 0), or both values passed through
        _truncate_seed.

    Examples:
        >>> # With no global seed set:
        >>> print(_get_graph_seed(0, 'normal'))
        (0, 0)
    """
    global_seed = get_seed()
    if global_seed is None:
        global_seed = 0
    elif global_seed == 0:
        global_seed = DEFAULT_GRAPH_SEED
    op_seed = 0 if op_seed is None else op_seed
    # Neither global seed nor op seed is set: return (0, 0) to let the kernel choose a random seed.
    if global_seed == 0 and op_seed == 0:
        return 0, 0
    Validator.check_non_negative_int(op_seed, "seed", kernel_name)
    current_op_seed = _get_op_seed(op_seed, kernel_name)
    seeds = _truncate_seed(global_seed), _truncate_seed(current_op_seed)
    # Advance the stored op seed only after the pair for this call has been computed.
    _update_seeds(op_seed, kernel_name)
    return seeds
265