• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Defines parameter operators with functional form."""
16
17from __future__ import absolute_import
18import numpy as np
19
20from mindspore import context
21from mindspore.ops import operations as P
22from mindspore.ops import functional as F
23from mindspore.ops.primitive import constexpr, _primexpr
24from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
25from mindspore.common import dtype as mstype
26from mindspore.common.seed import _get_graph_seed
27from mindspore.common.tensor import Tensor
28from mindspore.ops.operations.random_ops import RandomShuffle, RandomChoiceWithMask
29from mindspore.common.api import _function_forbid_reuse
30from mindspore.ops.auto_generate import randperm
31from mindspore.common.generator import default_generator
32from mindspore.ops.auto_generate import UniformExt, NormalTensorTensor, \
33    NormalTensorFloat, NormalFloatTensor, NormalFloatFloat, RandExt, RandLikeExt
34
# Module-level primitive instances shared by the functional wrappers in this module.
# They are constructed once, at import time, in the same order as listed here.
(normal_tensor_tensor_op, normal_tensor_float_op,
 normal_float_tensor_op, normal_float_float_op) = (
     NormalTensorTensor(), NormalTensorFloat(), NormalFloatTensor(), NormalFloatFloat())
# General-purpose operator primitives.
cast_, log_, real_div_ = P.Cast(), P.Log(), P.RealDiv()
reshape_, shape_, top_k_ = P.Reshape(), P.Shape(), P.TopK()
# Sampling primitives that take an explicit (seed, offset) pair.
uniform_, rand_ext_, rand_like_ext_ = UniformExt(), RandExt(), RandLikeExt()
# Step value handed to `generator._step(...)` by the generator-based wrappers
# (presumably how far the generator state advances per call -- confirm in Generator).
generator_step_ = Tensor(10, mstype.int64)
49
50
@constexpr
def _set_prim_op_user_data(prim, key, value):
    """Add the attribute ``key=value`` to the primitive `prim` and hand the same primitive back.

    The wrappers in this module use it to set ``random_cache`` to ``False`` on freshly
    built random primitives.

    Args:
        prim: The primitive instance to annotate.
        key (str): Attribute name.
        value: Attribute value.

    Returns:
        The same `prim` object that was passed in.
    """
    annotated = prim
    annotated.add_prim_attr(key, value)
    return annotated
55
56
@_function_forbid_reuse
def random_gamma(shape, alpha, seed=None):
    r"""
    Outputs random values from the Gamma distribution(s) described by alpha.

    Args:
        shape (Tensor): The shape of random tensor to be generated.
            Must be one of the following types: int32, int64. 1-D integer tensor.
        alpha (Tensor): The :math:`\alpha` distribution parameter.
            A Tensor. Must be one of the following types: half, float32, float64.
        seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
            Default: ``None`` , which will be treated as 0.

    Returns:
        Tensor. Its shape is the concatenation of the input `shape` and the broadcast shape
        of `alpha`, and its dtype is the same as that of `alpha`.

    Raises:
        TypeError: If `shape` is not a Tensor.
        TypeError: If `alpha` is not a Tensor.
        TypeError: If `seed` is not an int.
        TypeError: If dtype of `alpha` is not half, float32 or float64.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> shape = Tensor(np.array([7, 5]), mindspore.int32)
        >>> alpha = Tensor(np.array([0.5, 1.5]), mindspore.float32)
        >>> output = ops.random_gamma(shape, alpha, seed=5)
        >>> result = output.shape
        >>> print(result)
        (7, 5, 2)
    """
    seed1, seed2 = _get_seed(seed, "random_gamma")
    # A fresh primitive per call, with random_cache disabled.
    gamma_sampler = _set_prim_op_user_data(P.RandomGamma(seed1, seed2), "random_cache", False)
    return gamma_sampler(shape, alpha)
102
103
@constexpr(reuse_result=False)
def _get_seed(op_seed, kernel_name):
    """Resolve the pair of seeds handed to a random primitive.

    Delegates to `_get_graph_seed`; callers in this module unpack the result as
    ``seed1, seed2``. Declared with ``reuse_result=False`` so its result is not reused
    between calls.

    Args:
        op_seed (Union[int, None]): The operator-level seed supplied by the user.
        kernel_name (str): Name of the calling operator.

    Returns:
        Whatever `_get_graph_seed` returns for these arguments (a two-element seed pair).
    """
    seed_pair = _get_graph_seed(op_seed, kernel_name)
    return seed_pair
108
109
@_function_forbid_reuse
def standard_laplace(shape, seed=None):
    r"""
    Generates random numbers according to the Laplace random number distribution (mean=0, lambda=1).
    Its probability density function is:

    .. math::
        \text{f}(x) = \frac{1}{2}\exp(-|x|)

    Args:
        shape (Union[tuple, Tensor]): The shape of random tensor to be generated. Only constant value is allowed
          when the input type is tuple. And the operator supports dynamic shape only when the input type is Tensor.
        seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
          Default: ``None`` , which will be treated as 0.

    Returns:
        Tensor. The shape that the input 'shape' denotes. The dtype is float32.

    Raises:
        TypeError: If shape is neither a tuple nor a Tensor.
        ValueError: If shape is a tuple containing non-positive items.
        ValueError: If shape is a Tensor, and the rank of the Tensor is not equal to 1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> shape = (4, 4)
        >>> output = ops.standard_laplace(shape)
        >>> result = output.shape
        >>> print(result)
        (4, 4)
    """
    seed1, seed2 = _get_seed(seed, "standard_laplace")
    laplace_sampler = _set_prim_op_user_data(
        P.StandardLaplace(seed=seed1, seed2=seed2), "random_cache", False)
    return laplace_sampler(shape)
149
150
@_function_forbid_reuse
def random_categorical(logits, num_sample, seed=0, dtype=mstype.int64):
    r"""
    Generates random samples from a given categorical distribution tensor.

    Args:
        logits (Tensor): The input tensor. 2-D Tensor with shape :math:`(batch\_size, num\_classes)`.
        num_sample (int): Number of samples to draw. Only a constant value is allowed.
        seed (int): Random seed. Only a constant value is allowed. Default: ``0`` .
        dtype (mindspore.dtype): The type of output. Its value must be one of mindspore.int16,
            mindspore.int32 and mindspore.int64. Default: ``mstype.int64`` .

    Returns:
        Tensor, The output Tensor with shape :math:`(batch\_size, num\_samples)`.

    Raises:
        TypeError: If `dtype` is not one of the following: mindspore.int16, mindspore.int32, mindspore.int64.
        TypeError: If `logits` is not a Tensor.
        TypeError: If either `num_sample` or `seed` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> from mindspore import Tensor
        >>> import mindspore.common.dtype as mstype
        >>> import numpy as np
        >>> logits = Tensor(np.random.random((10, 5)).astype(np.float32), mstype.float32)
        >>> net = ops.random_categorical(logits, 8)
        >>> result = net.shape
        >>> print(result)
        (10, 8)
    """
    categorical_sampler = _set_prim_op_user_data(
        P.RandomCategorical(dtype), "random_cache", False)
    # num_sample and seed are passed as call inputs, not as constructor attributes.
    return categorical_sampler(logits, num_sample, seed)
189
190
@_function_forbid_reuse
def multinomial_with_replacement(x, seed, offset, numsamples, replacement=False):
    r"""
    Returns a tensor where each row contains numsamples indices sampled from the
    multinomial distribution with replacement. It is different from `multinomial` in that it allows
    the same outcome to be chosen multiple times.

    Note:
        The rows of input do not need to sum to one (in which case the values are used as weights),
        but must be non-negative, finite and have a non-zero sum.

    Args:
        x (Tensor): the input tensor containing the cumsum of probabilities, must be 1 or 2
          dimensions. Must be one of the following types: float16, float32, float64.
        seed (Union[int, Tensor]): If seed is set to be -1, and offset is set to be 0, the random number
          generator is seeded by a random seed. Otherwise, it is seeded by the given seed.
          A Python int is converted to an int64 Tensor before being passed to the operator.
        offset (Union[int, Tensor]): Offset used to avoid seed collision. A Python int is converted
          to an int64 Tensor before being passed to the operator.
        numsamples (int): the number of samples to draw.
        replacement (bool, optional): Whether to draw with replacement or not. Default: ``False`` .

    Returns:
        Tensor with the same rows as `x`, each row has `numsamples` sampled indices.

    Raises:
        TypeError: If `seed` or `offset` is neither a Tensor nor an int.
        TypeError: If `x` is not a 1D or 2D Tensor.
        TypeError: If dtype of `x` is not float16, float32 or float64.
        TypeError: If `numsamples` is not an int.
        TypeError: If `replacement` is not a bool.
        ValueError: If `replacement` is False and `numsamples` is greater than x_shape[-1].
        ValueError: If the sum of one row of `x` is less than 0.
        ValueError: If any element of `x` is less than 0.
        ValueError: If `numsamples` is equal to or less than 0.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> from mindspore import Tensor, ops
        >>> from mindspore import dtype as mstype
        >>> x = Tensor([[0., 9., 4., 0.]], mstype.float32)
        >>> output = ops.multinomial_with_replacement(x, 2, 5, 2, True)
        >>> print(output)
        [[1 1]]
    """
    # Python ints are wrapped into int64 Tensors; any other non-Tensor input is rejected.
    # The message fragments are joined by implicit string concatenation so the exception
    # carries one readable string. Previously they were passed as two separate arguments,
    # which made the message a tuple of fragments.
    if not isinstance(seed, Tensor):
        if not isinstance(seed, int):
            raise TypeError("For multinomial_with_replacement, "
                            f"the input[seed] must be int, but got {type(seed)}.")
        seed = Tensor(seed, dtype=mstype.int64)
    if not isinstance(offset, Tensor):
        if not isinstance(offset, int):
            raise TypeError("For multinomial_with_replacement, "
                            f"the input[offset] must be int, but got {type(offset)}.")
        offset = Tensor(offset, dtype=mstype.int64)
    multinomial_with_replacement_ = P.MultinomialWithReplacement(numsamples=numsamples,
                                                                 replacement=replacement)
    multinomial_with_replacement_ = _set_prim_op_user_data(
        multinomial_with_replacement_, "random_cache", False)
    return multinomial_with_replacement_(x, seed, offset)
250
251
@_function_forbid_reuse
def uniform_ext(tensor, a, b, generator=None):
    """
    Generates random numbers in the half-open interval [a, b).

    Args:
        tensor (Tensor): The origin input tensor.
        a (number): The lower bound of the interval.
        b (number): The upper bound of the interval.
        generator (Generator, optional): The random number generator that supplies the seed and
            offset for this call. Default: None, which uses the module's default generator.

    Raises:
        TypeError: If `a` is larger than `b`.

    Returns:
        Tensor, with the same shape as tensor.

    Examples:
        >>> import mindspore
        >>> from mindspore import ops
        >>> x = ops.ones((4, 2))
        >>> generator = mindspore.Generator()
        >>> generator.manual_seed(100)
        >>> result = ops.function.random_func.uniform_ext(x, 1., 2., generator)
        >>> print(result.shape)
        (4, 2)
    """
    source = default_generator if generator is None else generator
    # Advance the chosen generator and obtain the (seed, offset) pair for the kernel.
    seed, offset = source._step(generator_step_)  # pylint: disable=protected-access
    return uniform_(tensor, a, b, seed, offset)
283
284
@_function_forbid_reuse
def uniform(shape, minval, maxval, seed=None, dtype=mstype.float32):
    """
    Generates random numbers according to the Uniform random number distribution.

    Note:
        The number in tensor minval should be strictly less than maxval at any position after broadcasting.

    Args:
        shape (Union[tuple, Tensor]): The shape of random tensor to be generated.
        minval (Tensor): The distribution parameter `a`.
          It defines the minimum possible generated value, with int32 or float32 data type.
          If dtype is int32, only one number is allowed.
        maxval (Tensor): The distribution parameter `b`.
          It defines the maximum possible generated value, with int32 or float32 data type.
          If dtype is int32, only one number is allowed.
        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers,
          must be non-negative. Default: ``None`` , which will be treated as 0.
        dtype (mindspore.dtype): Type of the Uniform distribution. If it is int32, it generates numbers from discrete
          uniform distribution; if it is float32, it generates numbers from continuous uniform distribution. It only
          supports these two data types. Default: mstype.float32.

    Returns:
        Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
        of `minval` and `maxval`.
        The dtype is designated as the input `dtype`.

    Raises:
        TypeError: If `shape` is neither a tuple nor a Tensor.
        TypeError: If 'minval' or 'maxval' is neither int32 nor float32
            and dtype of 'minval' is not the same as 'maxval'.
        TypeError: If `seed` is not an int.
        TypeError: If 'dtype' is neither int32 nor float32.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import Tensor, ops
        >>> import mindspore
        >>> import numpy as np
        >>> # For discrete uniform distribution, only one number is allowed for both minval and maxval:
        >>> shape = (4, 2)
        >>> minval = Tensor(1, mindspore.int32)
        >>> maxval = Tensor(2, mindspore.int32)
        >>> output = ops.uniform(shape, minval, maxval, seed=5, dtype=mindspore.int32)
        >>>
        >>> # For continuous uniform distribution, minval and maxval can be multi-dimensional:
        >>> shape = (3, 1, 2)
        >>> minval = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
        >>> maxval = Tensor([8.0, 10.0], mindspore.float32)
        >>> output = ops.uniform(shape, minval, maxval, seed=5)
        >>> result = output.shape
        >>> print(result)
        (3, 2, 2)
    """
    if not (isinstance(minval, Tensor) and isinstance(maxval, Tensor)):
        raise TypeError(
            f"For functional operator[uniform], the input[minval] and input[maxval] must be a Tensor.")

    low_dtype, high_dtype = F.dtype(minval), F.dtype(maxval)
    # Validate the requested dtype first, then require both bounds to match it.
    const_utils.check_type_valid(dtype, [mstype.int32, mstype.float32], 'uniform')
    const_utils.check_tensors_dtype_same(low_dtype, dtype, "uniform")
    const_utils.check_tensors_dtype_same(high_dtype, dtype, "uniform")
    seed1, seed2 = _get_seed(seed, "uniform")

    if const_utils.is_same_type(dtype, mstype.int32):
        # Discrete case: the integer kernel samples directly between the bounds.
        int_sampler = _set_prim_op_user_data(P.UniformInt(seed1, seed2), "random_cache", False)
        return int_sampler(shape, minval, maxval)

    # Continuous case: UniformReal samples from the output of P.UniformReal, which the
    # affine step below then scales and shifts.
    real_sampler = _set_prim_op_user_data(P.UniformReal(seed1, seed2), "random_cache", False)
    unit_sample = real_sampler(shape)
    return unit_sample * (maxval - minval) + minval
364
365
@_function_forbid_reuse
def standard_normal(shape, seed=None):
    r"""
    Generates random numbers according to the standard Normal (or Gaussian) random number distribution.

    The returned tensor has the given shape, and its entries are drawn from a normal distribution
    with mean 0 and standard deviation 1:

    .. math::
        f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}

    Args:
        shape (Union[tuple, Tensor]): The shape of random tensor to be generated. Only constant value is allowed
          when the input type is tuple. And the operator supports dynamic shape only when the input type is Tensor.
        seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
          Default: ``None`` , which will be treated as 0.

    Returns:
        Tensor. The shape that the input 'shape' denotes. The dtype is float32.

    Raises:
        TypeError: If `shape` is neither a tuple nor a Tensor.
        ValueError: If `shape` is a tuple containing non-positive items.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> shape = (4, 4)
        >>> output = ops.standard_normal(shape)
        >>> result = output.shape
        >>> print(result)
        (4, 4)
    """
    seed1, seed2 = _get_seed(seed, "standard_normal")
    gaussian_sampler = _set_prim_op_user_data(
        P.StandardNormal(seed=seed1, seed2=seed2), "random_cache", False)
    return gaussian_sampler(shape)
406
407
@_function_forbid_reuse
def uniform_candidate_sampler(true_classes,
                              num_true,
                              num_sampled,
                              unique,
                              range_max,
                              seed=0,
                              remove_accidental_hits=False):
    r"""
    Uniform candidate sampler.

    This function samples a set of classes (sampled_candidates) from [0, range_max-1] based on a uniform
    distribution. If `unique` is True, candidates are drawn without replacement; otherwise they are drawn
    with replacement.

    Args:
        true_classes (Tensor): A Tensor. The target classes with a Tensor shape of :math:`(batch\_size, num\_true)` .
        num_true (int): The number of target classes in each training example.
        num_sampled (int): The number of classes to randomly sample. The sampled_candidates will have a shape
            of num_sampled. If unique=True, num_sampled must be less than or equal to range_max.
        unique (bool): Whether all sampled classes in a batch are unique.
        range_max (int): The number of possible classes, must be positive.
        seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
            the seed will be replaced with a randomly generated value. Default: ``0`` .
        remove_accidental_hits (bool): Whether accidental hits are removed.
            An accidental hit is when one of the true classes matches one of the sampled classes.
            Set ``True`` to remove sampled classes that accidentally match a true class. Default: ``False`` .

    Returns:
        - **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
          shape: :math:`(num\_sampled, )` .
        - **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
          of true_classes. shape: :math:`(batch\_size, num\_true)` .
        - **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
          each of sampled_candidates. shape: :math:`(num\_sampled, )` .

    Raises:
        TypeError: If either `num_true` or `num_sampled` is not an int.
        TypeError: If either `unique` or `remove_accidental_hits` is not a bool.
        TypeError: If either `range_max` or `seed` is not an int.
        TypeError: If `true_classes` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> data = Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int64))
        >>> output1, output2, output3 = ops.uniform_candidate_sampler(data, 1, 3, False, 4, 1)
        >>> print(output1.shape)
        (3,)
        >>> print(output2.shape)
        (5, 1)
        >>> print(output3.shape)
        (3,)
    """
    candidate_op = P.UniformCandidateSampler(num_true, num_sampled, unique, range_max,
                                             seed=seed, remove_accidental_hits=remove_accidental_hits)
    candidate_op = _set_prim_op_user_data(candidate_op, "random_cache", False)
    # Unpack and re-pack so exactly three outputs are returned as a tuple.
    candidates, true_counts, sampled_counts = candidate_op(true_classes)
    return candidates, true_counts, sampled_counts
474
475
@_function_forbid_reuse
def random_poisson(shape, rate, seed=None, dtype=mstype.float32):
    r"""
    Generates random number Tensor with shape `shape` according to a Poisson distribution with mean `rate`.

    .. math::

        \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}

    Args:
        shape (Tensor): The shape of random tensor to be sampled from each poisson distribution, 1-D `Tensor` whose
            dtype is mstype.int32 or mstype.int64.
        rate (Tensor): The :math:`μ` parameter the distribution is constructed with.
            It represents the mean of the distribution
            and also the variance of the distribution. It should be a `Tensor` whose dtype is mstype.int64,
            mstype.int32, mstype.float64, mstype.float32 or mstype.float16.
        seed (int, optional): Seed is used as entropy source for the random number engines to generate pseudo-random
            numbers and must be non-negative. Default: ``None`` , which will be treated as 0.
        dtype (mindspore.dtype): The data type of output: ``mstype.int64``, ``mstype.int32``,
            ``mstype.float64``, ``mstype.float32`` or ``mstype.float16``. Default: ``mstype.float32``.

    Returns:
        A Tensor whose shape is `mindspore.concat(['shape', mindspore.shape('rate')], axis=0)` and data type is equal to
        argument `dtype`.

    Raises:
        TypeError: If `shape` is not a Tensor.
        TypeError: If datatype of `shape` is neither mstype.int64 nor mstype.int32.
        ValueError: If shape of `shape` is not 1-D.
        TypeError: If `rate` is neither a Tensor nor a scalar.
        TypeError: If datatype of `rate` is not in [mstype.int64, mstype.int32,
            mstype.float64, mstype.float32, mstype.float16].
        TypeError: If `seed` is not a non-negative int.
        TypeError: If `dtype` is not in [mstype.int64, mstype.int32, mstype.float64,
            mstype.float32, mstype.float16].
        ValueError: If any element of input `shape` tensor is not positive.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> # case 1: 1-D shape, 2-D rate, float64 output
        >>> shape = Tensor(np.array([2, 2]), mindspore.int64)
        >>> rate = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32)
        >>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.float64)
        >>> print(output.shape, output.dtype)
        (2, 2, 2, 2) Float64
        >>> # case 2: 1-D shape, scalar rate, int64 output
        >>> shape = Tensor(np.array([2, 2]), mindspore.int64)
        >>> rate = Tensor(5.0, mindspore.float64)
        >>> output = ops.random_poisson(shape, rate, seed=5, dtype=mindspore.int64)
        >>> print(output.shape, output.dtype)
        (2, 2) Int64
    """
    seed1, seed2 = _get_seed(seed, "random_poisson")
    poisson_sampler = _set_prim_op_user_data(
        P.RandomPoisson(seed1, seed2, dtype), "random_cache", False)
    return poisson_sampler(shape, rate)
540
541
@_function_forbid_reuse
def shuffle(x, seed=None):
    r"""
    Randomly shuffles a Tensor along its first dimension.

    Args:
        x (Tensor): The Tensor to be shuffled.
        seed (int, optional): Random seed used for random number generation, must be non-negative. If `seed` is 0,
            it will be replaced with a randomly generated value. Default: ``None`` , which will be treated as 0.

    Returns:
        Tensor. The shape and type are the same as the input `x`.

    Raises:
        TypeError: If data type of `seed` is not None or non-negative int.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> from mindspore import dtype as mstype
        >>> x = Tensor(np.array([1, 2, 3, 4]), mstype.float32)
        >>> output = ops.shuffle(x, seed=1)
        >>> print(output)
        [3. 4. 2. 1.]
    """
    first_seed, second_seed = _get_seed(seed, "shuffle")
    shuffle_op = _set_prim_op_user_data(
        RandomShuffle(seed=first_seed, seed2=second_seed), "random_cache", False)
    return shuffle_op(x)
576
577
@_function_forbid_reuse
def log_uniform_candidate_sampler(true_classes, num_true=1, num_sampled=5, unique=True, range_max=5, seed=0):
    r"""
    Generates random labels with a log-uniform distribution for sampled_candidates.

    Randomly samples a tensor of sampled classes from the range of integers [0, range_max).

    Args:
        true_classes (Tensor): The target classes. With data type of int64 and
          shape :math:`(batch\_size, num\_true)` .
        num_true (int): The number of target classes per training example. Default: ``1`` .
        num_sampled (int): The number of classes to randomly sample. Default: ``5`` .
        unique (bool): Determines whether sample with rejection. If `unique` is ``True`` ,
          all sampled classes in a batch are unique. Default: ``True`` .
        range_max (int): The number of possible classes. When `unique` is ``True`` ,
          `range_max` must be greater than or equal to `num_sampled`. Default: ``5`` .
        seed (int): Random seed, must be non-negative. Default: ``0`` .

    Returns:
        Tuple of 3 Tensors.

        - **sampled_candidates** (Tensor) - A Tensor with shape :math:`(num\_sampled,)`
          and the same type as `true_classes`.
        - **true_expected_count** (Tensor) - A Tensor with the same shape as `true_classes` and type float32.
        - **sampled_expected_count** (Tensor) - A Tensor with the same shape as `sampled_candidates` and type float32.

    Raises:
        TypeError: If either `num_true` or `num_sampled` is not an int.
        TypeError: If `unique` is not a bool.
        TypeError: If either `range_max` or `seed` is not an int.
        TypeError: If `true_classes` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> output1, output2, output3 = ops.log_uniform_candidate_sampler(
        ... Tensor(np.array([[1, 7], [0, 4], [3, 3]])), 2, 5, True, 5)
        >>> print(output1, output2, output3)
        [3 2 0 4 1]
        [[0.92312991 0.49336370]
         [0.99248987 0.65806371]
         [0.73553443 0.73553443]]
        [0.73553443 0.82625800 0.99248987 0.65806371 0.92312991]

    """
    log_uniform_op = _set_prim_op_user_data(
        P.LogUniformCandidateSampler(num_true, num_sampled, unique, range_max, seed),
        "random_cache", False)
    return log_uniform_op(true_classes)
631
632
@_function_forbid_reuse
def choice_with_mask(input_x, count=256, seed=None):
    """
    Generates a random sample as index tensor with a mask tensor from a given tensor.

    The `input_x` must be a tensor whose dimension is not less than 1. If its dimension is greater than or equal to 2,
    the first dimension specifies the number of samples.
    The returned index tensor denotes the index of the nonzero
    sample, the mask tensor denotes which elements in the index tensor are valid.

    Args:
        input_x (Tensor[bool]): The input tensor.
            The input tensor rank must be greater than or equal to 1 and less than or equal to 5.
        count (int, optional): Number of items expected to get and the number must be greater than 0. Default: ``256`` .
        seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
            Default: ``None`` , which will be treated as 0.

    Returns:
        Two tensors, the first one is the index tensor and the other one is the mask tensor.

        - **index** (Tensor) - The output shape is 2-D.
        - **mask** (Tensor) - The output shape is 1-D.

    Raises:
        TypeError: If `count` is not an int.
        TypeError: If `seed` is not an int.
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> input_x = Tensor(np.ones(shape=[240000, 4]).astype(np.bool_))
        >>> output_y, output_mask = ops.choice_with_mask(input_x)
        >>> result = output_y.shape
        >>> print(result)
        (256, 2)
        >>> result = output_mask.shape
        >>> print(result)
        (256,)
    """
    seed1, seed2 = _get_seed(seed, "choice_with_mask")
    masked_chooser = _set_prim_op_user_data(
        RandomChoiceWithMask(count=count, seed=seed1, seed2=seed2), "random_cache", False)
    return masked_chooser(input_x)
683
684
@constexpr
def is_cpu_backend():
    """Return whether the context's ``device_target`` is set to ``'CPU'``."""
    device_target = context.get_context('device_target')
    return device_target == 'CPU'
689
690
@_function_forbid_reuse
def normal_ext(mean=0.0, std=1.0, size=None, generator=None):
    r"""
    Generates random numbers according to the Normal (or Gaussian) random number distribution
    with the given `mean` and standard deviation `std`.

    Args:
        mean (Union[float, Tensor], optional): Mean value of each element, the shape of the 'mean' tensor
            should be the same as that of the 'std' tensor. Default: ``0.0``.
        std (Union[float, Tensor], optional): Standard deviation for each element, the shape of the 'std' tensor
            should be the same as that of the 'mean' tensor. The value of std should be greater than or equal to 0.
            Default: ``1.0``.
        size (tuple, optional): Output size. Only used when both `mean` and `std` are numbers
            (not Tensors); it is ignored otherwise. Default: ``None``.
        generator (Generator, optional): MindSpore generator that supplies the seed and offset.
            Default: ``None``, which uses the default generator.

    Returns:
        Tensor. When `mean` and/or `std` is a Tensor, the output follows that Tensor's shape
        (the two are expected to match); when both are numbers, the output has shape `size`.

    Raises:
        TypeError: If `mean` or `std` is not Union[float, Tensor].

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import ops
        >>> from mindspore import Tensor
        >>> mean = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
        >>> std = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32)
        >>> output = ops.function.random_func.normal_ext(mean, std)
        >>> print(output.shape)
        (3,)
    """
    # Decorated with _function_forbid_reuse like every other sampling wrapper in this module
    # (including the generator-based uniform_ext), which it was previously missing.
    if generator is None:
        generator = default_generator
    seed, offset = generator._step(generator_step_)  # pylint: disable=protected-access

    is_mean_tensor = isinstance(mean, Tensor)
    is_std_tensor = isinstance(std, Tensor)

    # Dispatch on which of mean/std are Tensors; only the number/number kernel takes `size`.
    if is_mean_tensor and is_std_tensor:
        return normal_tensor_tensor_op(mean, std, seed, offset)
    if is_mean_tensor and not is_std_tensor:
        return normal_tensor_float_op(mean, std, seed, offset)
    if not is_mean_tensor and is_std_tensor:
        return normal_float_tensor_op(mean, std, seed, offset)
    return normal_float_float_op(mean, std, size, seed, offset)
739
740
@_function_forbid_reuse
def normal(shape, mean, stddev, seed=None):
    """
    Draws samples from a Normal (Gaussian) distribution parameterized by `mean` and `stddev`.

    A standard-normal tensor of shape `shape` is generated and then scaled by `stddev` and shifted by `mean`,
    so the result follows broadcasting rules between `shape`, `mean` and `stddev`.

    Args:
        shape (tuple): The shape of random tensor to be generated.
          The format is :math:`(N,*)`, where :math:`*` means any number of additional dimensions.
        mean (Union[Tensor, int, float]): The mean μ distribution parameter, which specifies the location of the peak.
        stddev (Union[Tensor, int, float]): The deviation σ distribution parameter. It should be greater than 0.
        seed (int): Seed is used as entropy source for the Random number engines to generate pseudo-random numbers.
          The value must be non-negative. Default: ``None`` , which will be treated as 0.

    Returns:
        Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
        of `mean` and `stddev`.
        The dtype is [float32, float64].

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> shape = (3, 1, 2)
        >>> mean = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
        >>> stddev = Tensor(1.0, mindspore.float32)
        >>> output = ops.normal(shape, mean, stddev, seed=5)
        >>> result = output.shape
        >>> print(result)
        (3, 2, 2)
        >>> shape = (3, 1, 3)
        >>> mean = Tensor(np.array([[3, 4, 3], [3, 5, 6]]), mindspore.float32)
        >>> stddev = Tensor(1.0, mindspore.float32)
        >>> output = ops.normal(shape, mean, stddev, seed=5)
        >>> result = output.shape
        >>> print(result)
        (3, 2, 3)
        >>> shape = (3, 1, 3)
        >>> mean = Tensor(np.array([[1, 2, 3], [3, 4, 3], [3, 5, 6]]), mindspore.float32)
        >>> stddev = Tensor(1.0, mindspore.float32)
        >>> output = ops.normal(shape, mean, stddev, seed=5)
        >>> result = output.shape
        >>> print(result)
        (3, 3, 3)
    """
    _check_param("normal", "mean", mean)
    _check_param("normal", "stddev", stddev)
    # Promote Python scalars to Tensors so the arithmetic below broadcasts uniformly.
    mean = mean if isinstance(mean, Tensor) else Tensor(mean)
    stddev = stddev if isinstance(stddev, Tensor) else Tensor(stddev)
    seed1, seed2 = _get_seed(seed, "normal")
    std_normal_op = _set_prim_op_user_data(P.StandardNormal(seed1, seed2), "random_cache", False)
    _check_shape(shape)
    samples = std_normal_op(shape)
    return samples * stddev + mean
801
802
@_function_forbid_reuse
def laplace(shape, mean, lambda_param, seed=None):
    r"""
    Draws samples from the Laplace distribution, whose density is:

    .. math::
        \text{f}(x;μ,λ) = \frac{1}{2λ}\exp(-\frac{|x-μ|}{λ}),

    A standard-Laplace tensor of shape `shape` is generated, then scaled by `lambda_param` and shifted by `mean`.

    Args:
        shape (tuple): The shape of random tensor to be generated.
          The format is :math:`(N,*)`, where :math:`*` means any number of additional dimensions.
        mean (Tensor): The mean μ distribution parameter, which specifies the location of the peak.
          With float32 data type.
        lambda_param (Tensor): The parameter used for controlling the variance of this random distribution. The
          variance of Laplace distribution is equal to twice the square of lambda_param. With float32 data type.
        seed (int, optional): Seed is used as entropy source for Random number engines generating pseudo-random numbers.
          Default: ``None`` , which will be treated as 0.

    Returns:
        Tensor. The shape should be the broadcasted shape of input `shape` and shapes of `mean` and `lambda_param`.
        The dtype is float32.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor
        >>> from mindspore import ops as ops
        >>> shape = (2, 3)
        >>> mean = Tensor(1.0, mindspore.float32)
        >>> lambda_param = Tensor(1.0, mindspore.float32)
        >>> output = ops.laplace(shape, mean, lambda_param, seed=5)
        >>> print(output.shape)
        (2, 3)
    """
    # Both distribution parameters are required to be float32 tensors.
    mean_type = F.dtype(mean)
    lambda_type = F.dtype(lambda_param)
    const_utils.check_tensors_dtype_same(mean_type, mstype.float32, "laplace")
    const_utils.check_tensors_dtype_same(lambda_type, mstype.float32, "laplace")
    seed1, seed2 = _get_seed(seed, "laplace")
    laplace_op = _set_prim_op_user_data(P.StandardLaplace(seed1, seed2), "random_cache", False)
    _check_shape(shape)
    return laplace_op(shape) * lambda_param + mean
852
853
@_function_forbid_reuse
def gamma(shape, alpha, beta, seed=None):
    r"""
    Draws samples from the Gamma distribution parameterized by `alpha` and `beta`.

    Args:
        shape (tuple): The shape of random tensor to be generated.
        alpha (Tensor): The :math:`\alpha` distribution parameter. It should be greater than 0 with float32 data type.
        beta (Tensor): The :math:`\beta` distribution parameter. It should be greater than 0 with float32 data type.
        seed (int, optional): Seed is used as entropy source for the random number engines to generate
            pseudo-random numbers, must be non-negative. Default: ``None`` , which will be treated as ``0`` .

    Returns:
        Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes
        of `alpha` and `beta`.
        The dtype is float32.

    Raises:
        TypeError: If `shape` is not a tuple.
        TypeError: If neither `alpha` nor `beta` is a Tensor.
        TypeError: If `seed` is not an int.
        TypeError: If dtype of `alpha` and `beta` is not float32.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> # case 1: alpha_shape is (2, 2)
        >>> shape = (3, 1, 2)
        >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
        >>> beta = Tensor(np.array([1.0]), mindspore.float32)
        >>> output = ops.gamma(shape, alpha, beta, seed=5)
        >>> result = output.shape
        >>> print(result)
        (3, 2, 2)
        >>> # case 2: alpha_shape is (2, 3), so shape is (3, 1, 3)
        >>> shape = (3, 1, 3)
        >>> alpha = Tensor(np.array([[1, 3, 4], [2, 5, 6]]), mindspore.float32)
        >>> beta = Tensor(np.array([1.0]), mindspore.float32)
        >>> output = ops.gamma(shape, alpha, beta, seed=5)
        >>> result = output.shape
        >>> print(result)
        (3, 2, 3)
        >>> # case 3: beta_shape is (1, 2), the output is different.
        >>> shape = (3, 1, 2)
        >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
        >>> beta = Tensor(np.array([1.0, 2]), mindspore.float32)
        >>> output = ops.gamma(shape, alpha, beta, seed=5)
        >>> print(output)
        [[[ 2.2132034  5.8855834]
          [ 3.8825176  8.6066265]]
         [[ 3.3981476  7.5805717]
          [ 3.7190282 19.941492 ]]
         [[ 2.9512358  2.5969937]
          [ 3.786061   5.160872 ]]]
        >>> # case 4: beta_shape is (2, 1), the output is different.
        >>> shape = (3, 1, 2)
        >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mindspore.float32)
        >>> beta = Tensor(np.array([[1.0], [2.0]]), mindspore.float32)
        >>> output = ops.gamma(shape, alpha, beta, seed=5)
        >>> print(output)
        [[[ 5.6085486  7.8280783]
         [ 15.97684  16.116285]]
        [[ 1.8347423  1.713663]
         [ 3.2434065 15.667398]]
        [[ 4.2922077  7.3365674]
         [ 5.3876944  13.159832 ]]]
    """
    seed1, seed2 = _get_seed(seed, "gamma")
    gamma_op = _set_prim_op_user_data(P.Gamma(seed1, seed2), "random_cache", False)
    return gamma_op(shape, alpha, beta)
930
931
@_primexpr
def _generate_shapes(shape):
    """Normalize the variadic ``*size`` arguments of rand/randn into a shape.

    Args:
        shape (tuple): The positional arguments received by the caller.

    Returns:
        ``(1,)`` when `shape` is empty; `shape` itself when it holds a single int or several ints;
        the single element converted to a tuple when it is a list; the single element itself when it is a tuple.

    Raises:
        TypeError: If the single element is not an int, list or tuple, or if any of several elements is not an int.
    """
    if not shape:
        return (1,)
    if len(shape) > 1:
        for dim in shape:
            if not isinstance(dim, int):
                raise TypeError(f"If the length of the argument 'shape' is > 1, the type of the argument 'shape' must "
                                f"all be int, but got {dim}.")
        return shape
    # Exactly one positional argument: either an int dimension or a whole shape sequence.
    first = shape[0]
    if isinstance(first, int):
        return shape
    if isinstance(first, list):
        return tuple(first)
    if isinstance(first, tuple):
        return first
    raise TypeError(f"If the length of the argument 'shape' is 1, the type of the argument 'shape' must be "
                    f"one of ['int', 'list', 'tuple'], but got {first}.")
954
955
@_function_forbid_reuse
def rand(*size, dtype=None, seed=None):
    r"""
    Creates a tensor of the given shape whose values are sampled uniformly from the interval :math:`[0, 1)`,
    then cast to the requested dtype.

    Args:
        size (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g. :math:`(2, 3)` or :math:`2`.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If ``None``,
            `mindspore.float32` is used. Default: ``None`` .
        seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and ``0`` will be used.

    Returns:
        Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
        the interval :math:`[0, 1)`.

    Raises:
        TypeError: `seed` is not a non-negative integer.
        ValueError: If `dtype` is not a `mstype.float_type` type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> print(ops.rand((2,3)))
        [[4.1702199e-01 9.9718481e-01 7.2032452e-01]
         [9.3255734e-01 1.1438108e-04 1.2812445e-01]]
    """
    # float32 is itself a float type, so the check below only ever rejects explicit non-float dtypes.
    out_dtype = mstype.float32 if dtype is None else dtype
    if out_dtype not in mstype.float_type:
        raise ValueError(
            f"For 'rand', the 'dtype' must be a float type, but got {out_dtype}.")
    out_shape = _generate_shapes(size)
    seed1, seed2 = _get_seed(seed, 'rand')
    uniform_op = _set_prim_op_user_data(P.UniformReal(seed1, seed2), "random_cache", False)
    samples = uniform_op(out_shape)
    return cast_(samples, out_dtype)
998
999
@_function_forbid_reuse
def rand_like(input, seed=None, *, dtype=None):
    r"""
    Creates a tensor shaped like `input` whose values are sampled uniformly from the interval :math:`[0, 1)`.

    Args:
        input (Tensor): Input Tensor that provides the output shape and, when `dtype` is ``None``, the output dtype.
        seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and ``0`` will be used.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If ``None``,
            the dtype of `input` is used. Default: ``None`` .

    Returns:
        Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
        the interval :math:`[0, 1)`.

    Raises:
        TypeError: If `seed` is not a non-negative integer.
        ValueError: If `dtype` is not a `mstype.float_type` type.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor, ops
        >>> a = Tensor([[2, 3, 4], [1, 2, 3]])
        >>> print(ops.rand_like(a, dtype=ms.float32))
        [[4.1702199e-01 9.9718481e-01 7.2032452e-01]
         [9.3255734e-01 1.1438108e-04 1.2812445e-01]]
    """
    if not isinstance(input, Tensor):
        raise TypeError(
            f"For 'rand_like', the 'input' must be a Tensor, but got {type(input)}")
    out_dtype = input.dtype if dtype is None else dtype
    if out_dtype not in mstype.float_type:
        raise ValueError(
            f"For 'rand_like', the 'dtype' must be a float type, but got {out_dtype}.")
    seed1, seed2 = _get_seed(seed, 'rand_like')
    uniform_op = _set_prim_op_user_data(P.UniformReal(seed1, seed2), "random_cache", False)
    return cast_(uniform_op(input.shape), out_dtype)
1047
1048
@_function_forbid_reuse
def rand_ext(*size, generator=None, dtype=None):
    r"""
    Returns a new tensor that fills numbers from the uniform distribution over an interval :math:`[0, 1)`
    based on the given shape and dtype.

    Args:
        size (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g. :math:`(2, 3)` or :math:`2`.

    Keyword Args:
        generator (:class:`mindspore.Generator`, optional): a pseudorandom number generator.
            Default: ``None``, uses the default pseudorandom number generator.
        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If None,
            `mindspore.float32` will be applied. Default: ``None`` .

    Returns:
        Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
        the interval :math:`[0, 1)`.

    Raises:
        ValueError: If `dtype` is not a `mstype.float_type` type.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore.ops as ops
        >>> print(ops.function.random_func.rand_ext(2, 3).shape)
        (2, 3)
    """
    # Compare against None explicitly (as normal_ext does) so a user-supplied generator is never
    # replaced by the default one because of its truthiness.
    if generator is None:
        generator = default_generator
    # Advance the generator state and obtain the (seed, offset) pair consumed by the kernel.
    seed, offset = generator._step(generator_step_)  # pylint: disable=protected-access
    return rand_ext_(size, seed, offset, dtype)
1083
1084
@_function_forbid_reuse
def rand_like_ext(input, *, dtype=None):
    r"""
    Creates a tensor shaped like `input` whose values are sampled uniformly from the interval :math:`[0, 1)`,
    using the default pseudorandom number generator.

    Args:
        input (Tensor): Input Tensor to specify the output shape and its default dtype.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If ``None``,
            the same dtype of `input` will be applied. Default: ``None`` .

    Returns:
        Tensor, with the designated shape and dtype, filled with random numbers from the uniform distribution on
        the interval :math:`[0, 1)`.

    Raises:
        ValueError: If `dtype` is not a `mstype.float_type` type.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor, ops
        >>> a = Tensor([[2, 3, 4], [1, 2, 3]])
        >>> print(ops.function.random_func.rand_like_ext(a, dtype=ms.float32).shape)
        (2, 3)
    """
    # Each call advances the default generator so successive calls draw different numbers.
    step_seed, step_offset = default_generator._step(generator_step_)  # pylint: disable=protected-access
    return rand_like_ext_(input, step_seed, step_offset, dtype)
1117
1118
@_function_forbid_reuse
def randn(*size, dtype=None, seed=None):
    r"""
    Creates a tensor of the given shape filled with samples from the standard normal distribution,
    cast to the requested dtype.

    Args:
        size (Union[int, tuple(int), list(int)]): Shape of the new tensor, e.g., :math:`(2, 3)` or :math:`2`.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If ``None``,
            `mindspore.float32` will be used. Default: ``None`` .
        seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and 0 will be used.

    Returns:
        Tensor, with the designated shape and dtype, filled with a sample (or samples) from the
        "standard normal" distribution.

    Raises:
        TypeError: `seed` is not a non-negative integer.
        ValueError: If `dtype` is not a `mstype.float_type`.
        ValueError: If `size` contains invalid number.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> print(ops.randn((2, 2)))
        [[ 0.30639967 -0.42438635]
         [-0.4287376   1.3054721 ]]
    """
    # float32 is itself a float type, so the check below only ever rejects explicit non-float dtypes.
    out_dtype = mstype.float32 if dtype is None else dtype
    if out_dtype not in mstype.float_type:
        raise ValueError(
            f"For 'randn', the 'dtype' must be a float type, but got {out_dtype}.")
    out_shape = _generate_shapes(size)
    seed1, seed2 = _get_seed(seed, 'randn')
    normal_op = _set_prim_op_user_data(P.StandardNormal(seed1, seed2), "random_cache", False)
    samples = normal_op(out_shape)
    return cast_(samples, out_dtype)
1162
1163
@_function_forbid_reuse
def randn_like(input, seed=None, *, dtype=None):
    r"""
    Creates a tensor shaped like `input` filled with samples from the standard normal distribution.

    Args:
        input (Tensor): Input Tensor whose shape determines the output shape. Its dtype is not used as the
            default output dtype.
        seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and 0 will be used.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be float type. If ``None``,
            `mindspore.float32` will be used. Default: ``None`` .

    Returns:
        Tensor, with the designated shape and dtype, filled with a sample (or samples) from the
        "standard normal" distribution.

    Raises:
        TypeError: `seed` is not a non-negative integer.
        ValueError: If `dtype` is not a `mstype.float_type`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore as ms
        >>> from mindspore import Tensor, ops
        >>> a = Tensor([[1, 2, 3], [4, 5, 6]])
        >>> print(ops.randn_like(a, dtype=ms.float32))
        [[ 0.30639967 -0.42438635 -0.20454668]
         [-0.4287376   1.3054721   0.64747655]]
    """
    if not isinstance(input, Tensor):
        raise TypeError(
            f"For 'randn_like', the 'input' must be a Tensor, but got {type(input)}")
    out_dtype = mstype.float32 if dtype is None else dtype
    if out_dtype not in mstype.float_type:
        raise ValueError(
            f"For 'randn_like', the 'dtype' must be a float type, but got {out_dtype}.")
    seed1, seed2 = _get_seed(seed, 'randn_like')
    normal_op = _set_prim_op_user_data(P.StandardNormal(seed1, seed2), "random_cache", False)
    return cast_(normal_op(input.shape), out_dtype)
1211
1212
@_function_forbid_reuse
def randint(low, high, size, seed=None, *, dtype=None):
    r"""
    Returns a Tensor whose elements are random integers in the range of [ `low` , `high` ) .

    Args:
        low (int): Start value of interval.
        high (int): End value of interval.
        size (tuple): Shape of the new tensor.
        seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and ``0`` will be used.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be int type. If ``None`` ,
            `mindspore.int64` will be used. Default: ``None`` .

    Returns:
        Tensor, with the designated shape and dtype, filled with random integers from low (inclusive)
        to high (exclusive).

    Raises:
        TypeError: If `seed` is not a non-negative integer.
        TypeError: If `low` or `high` is not an integer.
        ValueError: If `size` is not a tuple.
        ValueError: If `dtype` is not a `mstype.int_type`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ops
        >>> print(ops.randint(1, 10, (2,3)))
        [[4 9 7]
         [9 1 2]]
    """
    if dtype is None:
        dtype = mstype.int64
    elif dtype not in mstype.int_type:
        raise ValueError(
            f"For 'randint', the 'dtype' must be an int type, but got {dtype}.")
    if not isinstance(size, tuple):
        # Kept as ValueError (not TypeError) so existing callers that catch ValueError keep working.
        raise ValueError(
            f"For 'randint', the input 'size' must be a tuple, but got {size}.")
    # bool is a subclass of int, so it has to be rejected explicitly.
    if not isinstance(low, int) or isinstance(low, bool):
        raise TypeError(
            f"For 'randint', 'low' must be an int, but got {type(low)}.")
    if not isinstance(high, int) or isinstance(high, bool):
        raise TypeError(
            f"For 'randint', 'high' must be an int, but got {type(high)}.")
    seed1, seed2 = _get_seed(seed, 'randint')
    rand_op = P.UniformInt(seed1, seed2)
    rand_op = _set_prim_op_user_data(rand_op, "random_cache", False)
    # The interval bounds are passed to UniformInt as int32 tensors; the result is cast to `dtype` afterwards.
    low_ = Tensor(low, mstype.int32)
    high_ = Tensor(high, mstype.int32)
    output = rand_op(size, low_, high_)
    return cast_(output, dtype)
1269
1270
@_function_forbid_reuse
def randint_like(input, low, high, seed=None, *, dtype=None):
    r"""
    Creates a tensor shaped like `input` whose elements are random integers in the range of [ `low` , `high` ) .

    Args:
        input (Tensor): Input Tensor that provides the output shape and, when `dtype` is ``None``, the output dtype.
        low (int): Start value of interval.
        high (int): End value of interval.
        seed (int, optional): Random seed, must be greater or equal to 0. Default: ``None`` , and 0 will be used.

    Keyword Args:
        dtype (:class:`mindspore.dtype`, optional): Designated tensor dtype, it must be int type. If ``None`` ,
            the same dtype of `input` will be applied. Default: ``None`` .

    Returns:
        Tensor, with the designated shape and dtype, filled with random integers from low (inclusive)
        to high (exclusive).

    Raises:
        TypeError: `seed` is not a non-negative integer.
        TypeError: `low` or `high` is not an integer.
        ValueError: If `dtype` is not a `mstype.int_type`.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import Tensor, ops
        >>> a = Tensor([[1, 2, 3], [3, 2, 1]])
        >>> print(ops.randint_like(a, 1, 10))
        [[4 9 7]
         [9 1 2]]
    """
    if not isinstance(input, Tensor):
        raise TypeError(
            f"For 'randint_like', the 'input' must be a Tensor, but got {type(input)}")
    out_dtype = input.dtype if dtype is None else dtype
    if out_dtype not in mstype.int_type:
        raise ValueError(
            f"For 'randint_like', the 'dtype' must be an int type, but got {out_dtype}.")
    # bool is a subclass of int, so it has to be rejected explicitly.
    if isinstance(low, bool) or not isinstance(low, int):
        raise TypeError(
            f"For 'randint_like', 'low' must be an int, but got {type(low)}.")
    if isinstance(high, bool) or not isinstance(high, int):
        raise TypeError(
            f"For 'randint_like', 'high' must be an int, but got {type(high)}.")
    seed1, seed2 = _get_seed(seed, 'randint_like')
    uniform_int_op = _set_prim_op_user_data(P.UniformInt(seed1, seed2), "random_cache", False)
    lower = Tensor(low, mstype.int32)
    upper = Tensor(high, mstype.int32)
    # The target shape is handed to UniformInt as an int32 tensor built from `input.shape`.
    shape_tensor = Tensor(input.shape, mstype.int32)
    return cast_(uniform_int_op(shape_tensor, lower, upper), out_dtype)
1329
1330
@_function_forbid_reuse
def poisson(shape, mean, seed=None):
    r"""
    The ops.poisson is deprecated, please use :func:`mindspore.ops.random_poisson`.
    Draws samples from the Poisson distribution, whose probability mass function is:

    .. math::

        \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!}

    Args:
        shape (tuple): The shape of random tensor to be generated.
          The format is :math:`(N,*)`, where :math:`*` means any number of additional dimensions.
        mean (Tensor): The mean μ distribution parameter. It should be greater than 0 with float32 data type.
        seed (int): Seed is used as entropy source for the random number engines to generate pseudo-random numbers
          and must be non-negative. Default: ``None`` , which will be treated as 0.

    Returns:
        Tensor. The shape should be equal to the broadcasted shape between the input `shape` and shapes of `mean`.
        The dtype is float32.

    Raises:
        TypeError: If `shape` is not a tuple.
        TypeError: If `mean` is not a Tensor or its dtype is not float32.
        TypeError: If `seed` is not an int.

    Supported Platforms:
        deprecated

    Examples:
        >>> import numpy as np
        >>> from mindspore import Tensor, ops
        >>> import mindspore
        >>> # case 1: It can be broadcast.
        >>> shape = (4, 1)
        >>> mean = Tensor(np.array([5.0, 10.0]), mindspore.float32)
        >>> output = ops.poisson(shape, mean, seed=5)
        >>> result = output.shape
        >>> print(result)
        (4, 2)
        >>> # case 2: It can not be broadcast. It is recommended to use the same shape.
        >>> shape = (2, 2)
        >>> mean = Tensor(np.array([[5.0, 10.0], [5.0, 1.0]]), mindspore.float32)
        >>> output = ops.poisson(shape, mean, seed=5)
        >>> result = output.shape
        >>> print(result)
        (2, 2)
    """
    seed1, seed2 = _get_seed(seed, "poisson")
    poisson_op = _set_prim_op_user_data(P.Poisson(seed1, seed2), "random_cache", False)
    return poisson_op(shape, mean)
1384
1385
1386@_function_forbid_reuse
1387def multinomial(input, num_samples, replacement=True, seed=None):
1388    r"""
1389    Returns a tensor sampled from the multinomial probability distribution located in the corresponding
1390    row of the input tensor.
1391
1392    The polynomial distribution is a probability distribution that generalizes the binomial distribution formula to
1393    multiple states. In the polynomial distribution, each event has a fixed probability, and the sum of these
1394    probabilities is 1. The purpose of the `mindspore.ops.multinomial` interface is to perform `num_samples` sampling
1395    on the input `input`, and the output tensor is the index of the input tensor for each sampling.
1396    The values in `input` represent the probability of selecting the corresponding index for each sampling.
1397
1398    Here is an extreme example for better understanding. Suppose we have an input probability tensor with
1399    values `Tensor([90 / 100, 10 / 100, 0], mindspore.float32)`, which means we can sample three indices,
1400    namely index 0, index 1, and index 2, with probabilities of 90%, 10%, and 0%, respectively. We perform n samplings,
1401    and the resulting sequence is the calculation result of the polynomial distribution, with a length equal to the
1402    number of samplings.
1403
1404    In case 1 of the sample code, we perform two non-replacement samplings (`replacement` is `False`).
1405    The calculation result is most likely `[0, 1]`, and less likely `[1, 0]`. Since the probability of selecting
1406    index 0 is 90% for each sampling, the first result is most likely to be index 0. Since the probability of selecting
1407    index 2 is 0, index 2 cannot appear in the sampling result. Therefore, the second result must be index 1,
1408    and the resulting sequence is `[0, 1]`.
1409
1410    In case 2 of the sample code, we perform 10 replacement samplings (`replacement` is `True`).
1411    As expected, about 90% of the sampling results are index 0.
1412
1413    In case 3 of the sample code, we extend the input to 2 dimensions, and the sampling results
1414    in each dimension also match our sampling expectations.
1415
1416    Note:
1417        The rows of input do not need to sum to one (in which case we use the values as weights),
1418        but must be non-negative, finite and have a non-zero sum. When using values as weights, it can be understood as
1419        normalizing the input along the last dimension.
1420
1421    Args:
1422        input (Tensor): The input tensor containing probabilities, must be 1 or 2 dimensions, with
1423          float32 data type.
1424        num_samples (int): Number of samples to draw.
1425        replacement (bool, optional): Whether to draw with replacement or not. Default: ``True`` .
1426        seed (int, optional): Seed is used as entropy source for the random number engines to generate
1427          pseudo-random numbers, must be non-negative. Default: ``None`` .
1428
1429    Returns:
1430        Tensor, has the same rows with input. The number of sampled indices of each row is `num_samples`.
1431        The dtype is float32.
1432
1433    Raises:
1434        TypeError: If `input` is not a Tensor whose dtype is not float32.
1435        TypeError: If `num_samples` is not an int.
1436        TypeError: If `seed` is neither an int nor None.
1437
1438    Supported Platforms:
1439        ``Ascend`` ``GPU`` ``CPU``
1440
1441    Examples:
1442        >>> import mindspore
1443        >>> from mindspore import Tensor, ops
1444        >>> from mindspore import dtype as mstype
1445        >>> # case 1: The output is random, and the length of the output is the same as num_sample.
1446        >>> # replacement is False.
1447        >>> input1 = Tensor([90 / 100, 10 / 100, 0], mindspore.float32)
1448        >>> input2 = Tensor([90, 10, 0], mindspore.float32)
1449        >>> # input1 and input2 have the same meaning.
1450        >>> output1 = ops.multinomial(input1, 2, replacement=False)
1451        >>> output2 = ops.multinomial(input2, 2, replacement=False)
1452        >>> # print(output1)
1453        >>> # [0 1]
1454        >>> # print(output2)
1455        >>> # [0 1]
1456        >>> print(len(output1))
1457        2
1458        >>> print(len(output2))
1459        2
1460        >>> # case 2: The output is random, and the length of the output is the same as num_sample.
1461        >>> # replacement is True.
1462        >>> output3 = ops.multinomial(input1, 10)
1463        >>> # print(output3)
1464        >>> # [0 0 1 0 0 0 0 0 0 0]
1465        >>> print(len(output3))
1466        10
1467        >>> # case 3: The output is random, and the length of the output is the same as num_sample.
1468        >>> # replacement is True.
1469        >>> # rank is 2
1470        >>> input4 = Tensor([[90, 10, 0], [10, 90, 0]], mstype.float32)
1471        >>> output4 = ops.multinomial(input4, 10)
1472        >>> # print(output4)
1473        >>> # [[0 0 0 0 0 0 0 0 1 0]
1474        >>> #  [1 1 1 1 1 0 1 1 1 1]]
1475    """
1476    def _check_valid_dim(dim, name):
1477        if dim not in (1, 2):
1478            raise ValueError(
1479                f"For '{name}', the dimension of inputs must be 1d or 2d, but got {dim}.")
1480
1481    _check_valid_dim(len(shape_(input)), "multinomial")
1482    seed1, seed2 = _get_seed(seed, "multinomial")
1483    if not replacement:
1484        if shape_(input)[-1] < num_samples:
1485            const_utils.raise_value_error(f"For 'multinomial', the 'num_samples' must be less than "
1486                                          f"the last dimension of input without 'replacement', "
1487                                          f"but got 'num_samples': {num_samples} and "
1488                                          f"'replacement': {replacement}")
1489        n_dist = 1
1490        if len(shape_(input)) > 1:
1491            n_dist = shape_(input)[-2]
1492        random_uniform_real = P.UniformReal(seed1, seed2)
1493        random_cache_op = _set_prim_op_user_data(
1494            random_uniform_real, "random_cache", False)
1495        random_uniform = random_cache_op((n_dist * shape_(input)[-1],))
1496        if n_dist != 1:
1497            random_uniform = reshape_(
1498                random_uniform, (n_dist, shape_(input)[-1]))
1499
1500        vals = real_div_(log_(random_uniform), input + 1e-6)
1501        _, indices = top_k_(vals, num_samples)
1502        return indices
1503    random_nomial = P.Multinomial(seed1, seed2)
1504    random_nomial = _set_prim_op_user_data(
1505        random_nomial, "random_cache", False)
1506    return random_nomial(input, num_samples)
1507
1508
1509def _check_shape(input_shape):
1510    """Check 'shape' value."""
1511    if not isinstance(input_shape, tuple):
1512        const_utils.raise_type_error(
1513            f"Type of 'shape' must be tuple, but got: {type(input_shape)}")
1514    for item in input_shape:
1515        if not isinstance(item, int):
1516            const_utils.raise_type_error(
1517                f"Elements of 'shape' must be int, but got: {type(item)}")
1518        if item < 1:
1519            const_utils.raise_value_error(
1520                f"Elements of 'shape' must be positive int, but got: {item}")
1521    return True
1522
1523
def _check_param(op_name, param_name, param_value):
    """Verify that a distribution parameter has an accepted type.

    Accepted types are ``Tensor``, ``int``, ``float`` and ``numpy.ndarray``.
    Note that the error message names only Tensor, int and float.

    Args:
        op_name (str): Name of the calling operator, used in the error message.
        param_name (str): Name of the parameter being checked.
        param_value: The value whose type is checked.

    Returns:
        bool: ``True`` after the check has been performed.
    """
    accepted_types = (Tensor, int, float, np.ndarray)
    if isinstance(param_value, accepted_types):
        return True
    message = (f"For '{op_name}', the type of '{param_name}' must be Tensor, int, or float, "
               f"but got: {type(param_value)}")
    const_utils.raise_type_error(message)
    return True
1530
1531
# Public API of this module, exposed in alphabetical order.
__all__ = sorted([
    'standard_laplace', 'random_categorical', 'uniform', 'standard_normal', 'random_gamma',
    'uniform_candidate_sampler', 'random_poisson', 'log_uniform_candidate_sampler', 'shuffle', 'choice_with_mask',
    'normal', 'laplace', 'gamma', 'poisson', 'multinomial', 'rand', 'rand_like',
    'randn', 'randn_like',
    'randint', 'randint_like', 'multinomial_with_replacement', 'randperm',
])
1540