• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Operators for random."""
16
17from ..._checkparam import Validator, Rel
18from ...common import dtype as mstype
19from ..primitive import PrimitiveWithInfer, prim_attr_register
20from .._utils import get_broadcast_shape
21
22
class StandardNormal(PrimitiveWithInfer):
    r"""
    Samples a tensor from the standard Normal (Gaussian) distribution.

    Every element of the returned tensor is drawn from a normal distribution whose mean is 0 and
    whose standard deviation is 1, i.e. with the probability density function:

    .. math::
        f(x)=\frac{1}{\sqrt{2 \pi}} e^{\left(-\frac{x^{2}}{2}\right)}

    Args:
        seed (int): Random seed, must be non-negative. Default: 0.
        seed2 (int): Random seed2, must be non-negative. Default: 0.

    Inputs:
        - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
          Every dimension must be a positive int.

    Outputs:
        Tensor. The shape is the same as the input `shape`. The dtype is float32.

    Raises:
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `shape` is not a tuple.
        ValueError: If `shape` is not a constant value.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> shape = (3, 4)
        >>> stdnormal = ops.StandardNormal(seed=2)
        >>> output = stdnormal(shape)
        >>> print(output)
        [[-1.3031056   0.64198005 -0.65207404 -1.767485  ]
         [-0.91792876  0.6508565  -0.9098478  -0.14092612]
         [ 0.7806437   1.1585592   1.9676613  -0.00440959]]
    """

    @prim_attr_register
    def __init__(self, seed=0, seed2=0):
        """Initialize StandardNormal"""
        self.init_prim_io_names(inputs=['shape'], outputs=['output'])
        self.add_prim_attr("_random_effect", True)
        # Both seeds obey the same constraint; `seed` is validated before `seed2`.
        for seed_name, seed_value in (("seed", seed), ("seed2", seed2)):
            Validator.check_non_negative_int(seed_value, seed_name, self.name)

    def __infer__(self, shape):
        """Validate the constant `shape` input and describe the float32 output tensor."""
        dims = shape["value"]
        if dims is None:
            raise ValueError(f"For '{self.name}', the 'shape' cannot be None.")
        Validator.check_value_type("shape", dims, [tuple], self.name)
        for axis, dim in enumerate(dims):
            Validator.check_positive_int(dim, f'shape[{axis}]', self.name)
        # The output value is random, so only shape and dtype can be inferred.
        return {'shape': dims, 'dtype': mstype.float32, 'value': None}
81
82
class StandardLaplace(PrimitiveWithInfer):
    r"""
    Samples a tensor from the standard Laplace distribution (mean=0, lambda=1).

    The probability density function of the distribution is:

    .. math::
        \text{f}(x;0,1) = \frac{1}{2}\exp(-|x|),

    Args:
        seed (int): Random seed. Default: 0.
        seed2 (int): Random seed2. Default: 0.

    Inputs:
        - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
          Every dimension must be a positive int.

    Outputs:
        Tensor. The shape that the input 'shape' denotes. The dtype is float32.

    Raises:
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `shape` is not a tuple.
        ValueError: If `shape` is not a constant value.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> shape = (4, 16)
        >>> stdlaplace = ops.StandardLaplace(seed=2)
        >>> output = stdlaplace(shape)
        >>> result = output.shape
        >>> print(result)
        (4, 16)
    """

    @prim_attr_register
    def __init__(self, seed=0, seed2=0):
        """Initialize StandardLaplace"""
        self.init_prim_io_names(inputs=['shape'], outputs=['output'])
        self.add_prim_attr("_random_effect", True)
        # Only the type of the seeds is checked here; `seed` is validated before `seed2`.
        for seed_name, seed_value in (('seed', seed), ('seed2', seed2)):
            Validator.check_value_type(seed_name, seed_value, [int], self.name)

    def __infer__(self, shape):
        """Validate the constant `shape` input and describe the float32 output tensor."""
        dims = shape["value"]
        if dims is None:
            raise ValueError(f"For '{self.name}', the 'shape' cannot be None.")
        Validator.check_value_type("shape", dims, [tuple], self.name)
        for axis, dim in enumerate(dims):
            Validator.check_positive_int(dim, f'shape[{axis}]', self.name)
        # The output value is random, so only shape and dtype can be inferred.
        return {'shape': dims, 'dtype': mstype.float32, 'value': None}
138
139
class Gamma(PrimitiveWithInfer):
    r"""
    Samples positive floating-point values x from the Gamma distribution, whose probability density function is:

    .. math::
        \text{P}(x|α,β) = \frac{\exp(-x/β)}{{β^α}\cdot{\Gamma(α)}}\cdot{x^{α-1}}

    Args:
        seed (int): Random seed, must be non-negative. Default: 0.
        seed2 (int): Random seed2, must be non-negative. Default: 0.

    Inputs:
        - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
        - **alpha** (Tensor) - The α distribution parameter. It must be greater than 0.
          It is also known as the shape parameter with float32 data type.
        - **beta** (Tensor) - The β distribution parameter. It must be greater than 0.
          It is also known as the scale parameter with float32 data type.

    Outputs:
        Tensor. The shape must be the broadcasted shape of Input "shape" and shapes of alpha and beta.
        The dtype is float32.

    Raises:
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `alpha` or `beta` is not a Tensor.
        ValueError: If `shape` is not a constant value.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> shape = (3, 1, 2)
        >>> alpha = Tensor(np.array([[3, 4], [5, 6]]), mstype.float32)
        >>> beta = Tensor(np.array([1.0]), mstype.float32)
        >>> gamma = ops.Gamma(seed=3)
        >>> output = gamma(shape, alpha, beta)
        >>> result = output.shape
        >>> print(result)
        (3, 2, 2)
    """

    @prim_attr_register
    def __init__(self, seed=0, seed2=0):
        """Initialize Gamma"""
        self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
        self.add_prim_attr("_random_effect", True)
        # Both seeds obey the same constraint; `seed` is validated before `seed2`.
        for seed_name, seed_value in (("seed", seed), ("seed2", seed2)):
            Validator.check_non_negative_int(seed_value, seed_name, self.name)

    def __infer__(self, shape, alpha, beta):
        """Validate the inputs and infer the broadcast output shape of `alpha`, `beta` and `shape`."""
        dims = shape["value"]
        if dims is None:
            raise ValueError(f"For '{self.name}', the 'shape' cannot be None.")
        Validator.check_value_type("shape", dims, [tuple], self.name)
        for axis, dim in enumerate(dims):
            Validator.check_positive_int(dim, f'shape[{axis}]', self.name)
        # Both distribution parameters must be float32 tensors; `alpha` is checked first.
        for param_name, param in (("alpha", alpha), ("beta", beta)):
            Validator.check_tensor_dtype_valid(param_name, param["dtype"], [mstype.float32], self.name)
        # Broadcast the two parameters together first, then against the requested shape.
        param_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
        output_shape = get_broadcast_shape(param_shape, dims, self.name)
        return {'shape': output_shape, 'dtype': mstype.float32, 'value': None}
205
206
class Poisson(PrimitiveWithInfer):
    r"""
    Samples non-negative integer values i from the Poisson distribution, whose discrete probability function is:

    .. math::
        \text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!},

    Args:
        seed (int): Random seed, must be non-negative. Default: 0.
        seed2 (int): Random seed2, must be non-negative. Default: 0.

    Inputs:
        - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
        - **mean** (Tensor) - μ parameter the distribution was constructed with. The parameter defines mean number
          of occurrences of the event. It must be greater than 0. With float32 data type.

    Outputs:
        Tensor. Its shape must be the broadcasted shape of `shape` and the shape of `mean`.
        The dtype is int32.

    Raises:
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `shape` is not a tuple.
        TypeError: If `mean` is not a Tensor whose dtype is float32.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> shape = (4, 1)
        >>> mean = Tensor(np.array([5.0, 10.0]), mstype.float32)
        >>> poisson = ops.Poisson(seed=5)
        >>> output = poisson(shape, mean)
        >>> result = output.shape
        >>> print(result)
        (4, 2)
    """

    @prim_attr_register
    def __init__(self, seed=0, seed2=0):
        """Initialize Poisson"""
        self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
        self.add_prim_attr("_random_effect", True)
        # Both seeds obey the same constraint; `seed` is validated before `seed2`.
        for seed_name, seed_value in (("seed", seed), ("seed2", seed2)):
            Validator.check_non_negative_int(seed_value, seed_name, self.name)

    def __infer__(self, shape, mean):
        """Validate the inputs and infer the int32 output broadcast from `mean` and `shape`."""
        dims = shape["value"]
        if dims is None:
            raise ValueError(f"For '{self.name}', the 'shape' cannot be None.")
        Validator.check_value_type("shape", dims, [tuple], self.name)
        for axis, dim in enumerate(dims):
            Validator.check_positive_int(dim, f'shape[{axis}]', self.name)
        Validator.check_tensor_dtype_valid("mean", mean["dtype"], [mstype.float32], self.name)
        output_shape = get_broadcast_shape(mean['shape'], dims, self.name)
        # Samples are event counts, hence the integer output dtype.
        return {'shape': output_shape, 'dtype': mstype.int32, 'value': None}
267
268
class UniformInt(PrimitiveWithInfer):
    r"""
    Produces random integer values i, uniformly distributed on the half-open interval [minval, maxval), that is,
    distributed according to the discrete probability function:

    .. math::
        \text{P}(i|a,b) = \frac{1}{b-a},

    where the :math:`a` indicates the min distribution parameter,
    the :math:`b` indicates the max distribution parameter.

    Note:
        The number in tensor minval must be strictly less than maxval at any position after broadcasting.

    Args:
        seed (int): Random seed, must be non-negative. Default: 0.
        seed2 (int): Random seed2, must be non-negative. Default: 0.

    Inputs:
        - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
        - **minval** (Tensor) - The distribution parameter, a.
          It defines the minimum possibly generated value, with int32 data type. Only one number is supported.
        - **maxval** (Tensor) - The distribution parameter, b.
          It defines the exclusive upper bound of the generated values, with int32 data type.
          Only one number is supported.

    Outputs:
        Tensor. The shape is the same as the input 'shape', and the data type is int32.

    Raises:
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `shape` is not a tuple.
        TypeError: If `minval` or `maxval` is not a Tensor.
        ValueError: If `shape` is not a constant value.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> shape = (2, 4)
        >>> minval = Tensor(1, mstype.int32)
        >>> maxval = Tensor(5, mstype.int32)
        >>> uniform_int = ops.UniformInt(seed=10)
        >>> output = uniform_int(shape, minval, maxval)
        >>> result = output.shape
        >>> print(result)
        (2, 4)
    """

    @prim_attr_register
    def __init__(self, seed=0, seed2=0):
        """Initialize UniformInt"""
        self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output'])
        self.add_prim_attr("_random_effect", True)
        Validator.check_non_negative_int(seed, "seed", self.name)
        Validator.check_non_negative_int(seed2, "seed2", self.name)

    def __infer__(self, shape, minval, maxval):
        """Validate the inputs and describe the int32 output tensor of shape `shape`."""
        shape_v = shape["value"]
        if shape_v is None:
            raise ValueError(f"For '{self.name}', the 'shape' cannot be None.")
        Validator.check_value_type("shape", shape_v, [tuple], self.name)
        for i, shape_i in enumerate(shape_v):
            Validator.check_positive_int(shape_i, f'shape[{i}]', self.name)
        Validator.check_tensor_dtype_valid("minval", minval["dtype"], [mstype.int32], self.name)
        Validator.check_tensor_dtype_valid("maxval", maxval["dtype"], [mstype.int32], self.name)
        minval_shape = minval['shape']
        maxval_shape = maxval['shape']
        # Both bounds must be 0-dimensional (scalar) tensors, so they never affect the output shape.
        Validator.check("dim of minval", len(minval_shape), '0(scalar)', 0, Rel.EQ, self.name)
        Validator.check("dim of maxval", len(maxval_shape), '0(scalar)', 0, Rel.EQ, self.name)
        out = {
            'shape': shape_v,
            'dtype': mstype.int32,
            'value': None}
        return out
343
344
class UniformReal(StandardNormal):
    r"""
    Produces random floating-point values, uniformly distributed on the half-open interval [0, 1).

    This class defines no methods of its own: construction and shape/dtype inference are inherited
    unchanged from :class:`StandardNormal`.

    Args:
        seed (int): Random seed, must be non-negative. Default: 0.
        seed2 (int): Random seed2, must be non-negative. Default: 0.

    Inputs:
        - **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.

    Outputs:
        Tensor. The shape that the input 'shape' denotes. The dtype is float32.

    Raises:
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `shape` is not a tuple.
        ValueError: If `shape` is not a constant value.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> shape = (2, 2)
        >>> uniformreal = ops.UniformReal(seed=2)
        >>> output = uniformreal(shape)
        >>> result = output.shape
        >>> print(result)
        (2, 2)
    """
375
376
class RandomChoiceWithMask(PrimitiveWithInfer):
    """
    Randomly samples indices from a boolean tensor and returns them together with a validity mask.

    The input must be a tensor of rank not less than 1. If its rank is greater than or equal to 2,
    the first dimension specifies the number of samples.
    Both outputs have fixed shapes. The index tensor holds the indices of the sampled nonzero elements,
    while the mask tensor marks which entries of the index tensor are valid.

    Args:
        count (int): Number of items expected to get and the number must be greater than 0. Default: 256.
        seed (int): Random seed. Default: 0.
        seed2 (int): Random seed2. Default: 0.

    Inputs:
        - **input_x** (Tensor[bool]) - The input tensor.
          The input tensor rank must be greater than or equal to 1 and less than or equal to 5.

    Outputs:
        Two tensors, the first one is the index tensor and the other one is the mask tensor.

        - **index** (Tensor) - A 2-D int32 tensor of shape (count, rank of `input_x`).
        - **mask** (Tensor) - A 1-D bool tensor of shape (count,).

    Raises:
        TypeError: If `count` is not an int.
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `input_x` is not a Tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> rnd_choice_mask = ops.RandomChoiceWithMask()
        >>> input_x = Tensor(np.ones(shape=[240000, 4]).astype(np.bool))
        >>> output_y, output_mask = rnd_choice_mask(input_x)
        >>> result = output_y.shape
        >>> print(result)
        (256, 2)
        >>> result = output_mask.shape
        >>> print(result)
        (256,)
    """

    @prim_attr_register
    def __init__(self, count=256, seed=0, seed2=0):
        """Initialize RandomChoiceWithMask"""
        Validator.check_value_type("count", count, [int], self.name)
        Validator.check_positive_int(count, "count", self.name)
        # Only the type of the seeds is checked; `seed` is validated before `seed2`.
        for seed_name, seed_value in (('seed', seed), ('seed2', seed2)):
            Validator.check_value_type(seed_name, seed_value, [int], self.name)
        self.add_prim_attr("_random_effect", True)

    def infer_shape(self, x_shape):
        """Return the shapes of the index and mask outputs for an input of shape `x_shape`."""
        rank = len(x_shape)
        Validator.check_int(rank, 1, Rel.GE, "input_x rank", self.name)
        Validator.check_int(rank, 5, Rel.LE, "input_x rank", self.name)
        # One row per requested sample, one column per input dimension.
        index_shape = [self.count, rank]
        mask_shape = [self.count]
        return index_shape, mask_shape

    def infer_dtype(self, x_dtype):
        """Require a bool input and return the (index, mask) dtypes."""
        Validator.check_tensor_dtype_valid('x', x_dtype, [mstype.bool_], self.name)
        index_dtype, mask_dtype = mstype.int32, mstype.bool_
        return index_dtype, mask_dtype
438
439
class RandomCategorical(PrimitiveWithInfer):
    """
    Draws random samples from the categorical distribution described by a logits tensor.

    Args:
        dtype (mindspore.dtype): The type of output. Its value must be one of mindspore.int16,
            mindspore.int32 and mindspore.int64. Default: mindspore.int64.

    Inputs:
        - **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes].
        - **num_sample** (int) - Number of sample to be drawn. Only constant values is allowed.
        - **seed** (int) - Random seed. Default: 0. Only constant values is allowed.

    Outputs:
        - **output** (Tensor) - The output Tensor with shape [batch_size, num_samples].

    Raises:
        TypeError: If `dtype` is not one of the following: mindspore.int16, mindspore.int32, mindspore.int64.
        TypeError: If `logits` is not a Tensor.
        TypeError: If `num_sample` or `seed` is not an int.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> class Net(nn.Cell):
        ...   def __init__(self, num_sample):
        ...     super(Net, self).__init__()
        ...     self.random_categorical = ops.RandomCategorical(mindspore.int64)
        ...     self.num_sample = num_sample
        ...   def construct(self, logits, seed=0):
        ...     return self.random_categorical(logits, self.num_sample, seed)
        ...
        >>> x = np.random.random((10, 5)).astype(np.float32)
        >>> net = Net(8)
        >>> output = net(Tensor(x))
        >>> result = output.shape
        >>> print(result)
        (10, 8)
    """

    @prim_attr_register
    def __init__(self, dtype=mstype.int64):
        """Initialize RandomCategorical"""
        self.dtype = dtype

        supported_dtypes = (mstype.int32, mstype.int16, mstype.int64)
        Validator.check_type_name("dtype", dtype, supported_dtypes, self.name)
        self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
                                outputs=['output'])
        self.add_prim_attr("_random_effect", True)

    def __infer__(self, logits, num_samples, seed):
        """Validate the inputs, record `num_samples`/`seed` as primitive attributes and infer the output."""
        float_types = (mstype.float32, mstype.float16, mstype.float64)
        Validator.check_tensor_dtype_valid('logits', logits['dtype'], float_types, self.name)
        sample_count = num_samples['value']
        seed_value = seed['value']
        Validator.check_value_type('num_samples', sample_count, (int,), self.name)
        Validator.check_value_type('seed', seed_value, (int,), self.name)
        Validator.check_positive_int(sample_count, "num_samples", self.name)
        output_shape = list(logits['shape'])
        rank = len(output_shape)
        if rank != 2:
            raise ValueError(f"For '{self.name}', the shape of 'logits' should be 2-dimension, "
                             f"but got {rank}.")
        # Keep the leading dimension of `logits`; the last dimension becomes the number of samples.
        output_shape[-1] = sample_count
        self.add_prim_attr('num_samples', sample_count)
        self.add_prim_attr('seed', seed_value)
        return {'shape': output_shape, 'dtype': self.dtype, 'value': None}
512
513
class Multinomial(PrimitiveWithInfer):
    r"""
    Returns a tensor sampled from the multinomial probability distribution located in the corresponding
    row of tensor input.

    Note:
        The rows of input do not need to sum to one (in which case we use the values as weights),
        but must be non-negative, finite and have a non-zero sum.

    Args:
        seed (int): Random seed, must be non-negative. Default: 0.
        seed2 (int): Random seed2, must be non-negative. Default: 0.

    Inputs:
        - **x** (Tensor[float32]) - the input tensor containing the cumsum of probabilities, must be 1 or 2
          dimensions.
        - **num_samples** (int) - number of samples to draw, must be a positive constant.

    Outputs:
        Tensor of dtype int32 with the same rows as `x`, each row has num_samples sampled indices.
        The shape is (num_samples,) for a 1-D `x` and (x.shape[0], num_samples) for a 2-D `x`.

    Raises:
        TypeError: If `seed` or `seed2` is not an int.
        TypeError: If `x` is not a Tensor whose dtype is float32.
        TypeError: If `num_samples` is not an int.
        ValueError: If `x` is neither 1-dimensional nor 2-dimensional.
        ValueError: If `num_samples` is not a constant value.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> x = Tensor([0., 9., 4., 0.], mstype.float32)
        >>> multinomial = ops.Multinomial(seed=10)
        >>> output = multinomial(x, 2)
        >>> print(output)
        [2 1]
    """

    @prim_attr_register
    def __init__(self, seed=0, seed2=0):
        """Initialize Multinomial."""
        Validator.check_non_negative_int(seed, "seed", self.name)
        Validator.check_non_negative_int(seed2, "seed2", self.name)
        self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output'])
        self.add_prim_attr("_random_effect", True)

    def __infer__(self, inputs, num_samples):
        """Validate `inputs` and `num_samples` and infer the int32 output of sampled indices."""
        input_shape = inputs["shape"]
        rank = len(input_shape)
        if rank not in (1, 2):
            raise ValueError(f"For '{self.name}', the dimension of 'inputs' must be 1 or 2, "
                             f"but got {rank}.")
        Validator.check_tensor_dtype_valid('inputs', inputs['dtype'], [mstype.float32], self.name)
        num_samples_value = num_samples["value"]
        if num_samples_value is None:
            raise ValueError(f"For '{self.name}', the 'num_samples' cannot be None.")
        Validator.check_value_type("num_samples", num_samples_value, (int,), self.name)
        # Pass the primitive name so the error message identifies this operator,
        # as every other validation in this class does.
        Validator.check_positive_int(num_samples_value, "num_samples", self.name)
        # A 1-D input is a single distribution; a 2-D input holds one distribution per row.
        if rank == 2:
            y_shape = (input_shape[0], num_samples_value)
        else:
            y_shape = (num_samples_value,)
        return {"shape": y_shape, "dtype": mstype.int32, "value": None}
578
579
class UniformCandidateSampler(PrimitiveWithInfer):
    r"""
    Uniform candidate sampler.

    This function samples a set of classes(sampled_candidates) from [0, range_max-1] based on uniform distribution.
    If unique=True, candidates are drawn without replacement, else unique=False with replacement.

    Args:
        num_true (int): The number of target classes in each training example.
        num_sampled (int): The number of classes to randomly sample, must be greater than 0.
            The sampled_candidates will have a shape of num_sampled.
            If unique=True, num_sampled must be less than or equal to range_max.
        unique (bool): Whether all sampled classes in a batch are unique.
        range_max (int): The number of possible classes, must be greater than 0.
        seed (int): Used for random number generation, must be non-negative. If seed has a value of 0,
            seed will be replaced with a randomly generated value. Default: 0.
        remove_accidental_hits (bool): Whether accidental hit is removed. Default: False.

    Inputs:
        - **true_classes** (Tensor) - A 2-D Tensor of int32 or int64. The target classes with a Tensor shape of
          (batch_size, num_true).

    Outputs:
        - **sampled_candidates** (Tensor) - The sampled_candidates is independent of the true classes.
          Shape: (num_sampled, ). Same data type as `true_classes`.
        - **true_expected_count** (Tensor) - The expected counts under the sampling distribution of each
          of true_classes. Shape: (batch_size, num_true). Data type: float32.
        - **sampled_expected_count** (Tensor) - The expected counts under the sampling distribution of
          each of sampled_candidates. Shape: (num_sampled, ). Data type: float32.

    Raises:
        TypeError: If `num_true` or `num_sampled` is not an int.
        TypeError: If `unique` or `remove_accidental_hits` is not a bool.
        TypeError: If `range_max` or `seed` is not an int.
        TypeError: If `true_classes` is not a Tensor.

    Supported Platforms:
        ``GPU``

    Examples:
        >>> sampler = ops.UniformCandidateSampler(1, 3, False, 4)
        >>> output1, output2, output3 = sampler(Tensor(np.array([[1], [3], [4], [6], [3]], dtype=np.int32)))
        >>> print(output1, output2, output3)
        [1, 1, 3], [[0.75], [0.75], [0.75], [0.75], [0.75]], [0.75, 0.75, 0.75]
    """

    @prim_attr_register
    def __init__(self, num_true, num_sampled, unique, range_max, seed=0, remove_accidental_hits=False):
        """Initialize UniformCandidateSampler"""
        # Type checks first, then value-range checks.
        Validator.check_value_type("num_true", num_true, [int], self.name)
        Validator.check_value_type("num_sampled", num_sampled, [int], self.name)
        Validator.check_value_type("unique", unique, [bool], self.name)
        Validator.check_value_type("range_max", range_max, [int], self.name)
        Validator.check_value_type("seed", seed, [int], self.name)
        Validator.check_value_type("remove_accidental_hits", remove_accidental_hits, [bool], self.name)
        Validator.check("value of num_sampled", num_sampled, '', 0, Rel.GT, self.name)
        Validator.check("value of range_max", range_max, '', 0, Rel.GT, self.name)
        self.num_true = num_true
        if unique:
            # Sampling without replacement cannot yield more candidates than there are classes.
            Validator.check('value of num_sampled', num_sampled, "value of range_max", range_max, Rel.LE, self.name)
        Validator.check("value of seed", seed, '', 0, Rel.GE, self.name)
        self.num_sampled = num_sampled

    def infer_dtype(self, true_classes_type):
        """Require an int32/int64 tensor and return the dtypes of the three outputs."""
        Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
        Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type,
                                           (mstype.int32, mstype.int64), self.name)
        return true_classes_type, mstype.float32, mstype.float32

    def infer_shape(self, true_classes_shape):
        """Check that `true_classes` is (batch_size, num_true) and return the shapes of the three outputs."""
        # Verify the rank before indexing dimension 1, so a malformed input gets a validation
        # error instead of a bare IndexError (mirrors LogUniformCandidateSampler.infer_shape).
        Validator.check_int(len(true_classes_shape), 2, Rel.EQ, "dim of true_classes", self.name)
        Validator.check("true_class.shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
        return [self.num_sampled], true_classes_shape, [self.num_sampled]
650
651
class LogUniformCandidateSampler(PrimitiveWithInfer):
    """
    Samples candidate classes (sampled_candidates) according to a log-uniform distribution.

    The sampled classes are drawn at random from the range of integers [0, range_max).

    Args:
        num_true (int): The number of target classes per training example. Default: 1.
        num_sampled (int): The number of classes to randomly sample. Default: 5.
        unique (bool): Determines whether sample with rejection. If `unique` is True,
          all sampled classes in a batch are unique. Default: True.
        range_max (int): The number of possible classes. When `unique` is True,
          `range_max` must be greater than or equal to `num_sampled`. Default: 5.
        seed (int): Random seed, must be non-negative. Default: 0.

    Inputs:
        - **true_classes** (Tensor) - The target classes. With data type of int64 and shape [batch_size, num_true].

    Outputs:
        Tuple of 3 Tensors.

        - **sampled_candidates** (Tensor) - A Tensor with shape (num_sampled,) and the same type as `true_classes`.
        - **true_expected_count** (Tensor) - A Tensor with the same shape as `true_classes` and type float32.
        - **sampled_expected_count** (Tensor) - A Tensor with the same shape as `sampled_candidates` and type float32.

    Raises:
        TypeError: If `num_true` or `num_sampled` is not an int.
        TypeError: If `unique` is not a bool.
        TypeError: If `range_max` or `seed` is not an int.
        TypeError: If `true_classes` is not a Tensor.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> sampler = ops.LogUniformCandidateSampler(2, 5, True, 5)
        >>> output1, output2, output3 = sampler(Tensor(np.array([[1, 7], [0, 4], [3, 3]])))
        >>> print(output1, output2, output3)
        [3 2 0 4 1]
        [[0.92312991 0.49336370]
         [0.99248987 0.65806371]
         [0.73553443 0.73553443]]
        [0.73553443 0.82625800 0.99248987 0.65806371 0.92312991]

    """

    @prim_attr_register
    def __init__(self, num_true=1, num_sampled=5, unique=True, range_max=5, seed=0):
        """Initialize LogUniformCandidateSampler"""
        self.init_prim_io_names(inputs=['true_classes'],
                                outputs=['sampled_candidates', 'true_expected_count', 'sampled_expected_count'])
        # Type checks, performed in argument order.
        expected_types = (
            ("num_true", num_true, int),
            ("num_sampled", num_sampled, int),
            ("unique", unique, bool),
            ("range_max", range_max, int),
            ("seed", seed, int),
        )
        for arg_name, arg_value, arg_type in expected_types:
            Validator.check_value_type(arg_name, arg_value, [arg_type], self.name)
        # Value-range checks; the checked values are kept on the instance.
        self.num_true = Validator.check_number("num_true", num_true, 1, Rel.GE, self.name)
        self.num_sampled = Validator.check_number("num_sampled", num_sampled, 1, Rel.GE, self.name)
        Validator.check_number("range_max", range_max, 1, Rel.GE, self.name)
        if unique:
            # Unique sampling needs at least as many classes as requested samples.
            Validator.check("range_max", range_max, "num_sampled", num_sampled, Rel.GE, self.name)
        self.range_max = range_max
        self.unique = unique
        self.seed = Validator.check_number("seed", seed, 0, Rel.GE, self.name)

    def infer_shape(self, true_classes_shape):
        """Check that `true_classes` is (batch_size, num_true) and return the shapes of the three outputs."""
        rank = len(true_classes_shape)
        Validator.check_int(rank, 2, Rel.EQ, "dim of true_classes", self.name)
        Validator.check("true_classes_shape[1]", true_classes_shape[1], "num_true", self.num_true, Rel.EQ, self.name)
        sampled_shape = (self.num_sampled,)
        return sampled_shape, true_classes_shape, sampled_shape

    def infer_dtype(self, true_classes_type):
        """Require an int64 tensor and return the dtypes of the three outputs."""
        Validator.check_subclass("true_classes_type", true_classes_type, mstype.tensor, self.name)
        Validator.check_tensor_dtype_valid("true_classes_type", true_classes_type, (mstype.int64,), self.name)
        count_dtype = mstype.float32
        return true_classes_type, count_dtype, count_dtype
728