• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1mindspore.ops.RandomCategorical
2===============================
3
4.. py:class:: mindspore.ops.RandomCategorical(dtype=mstype.int64)
5
6    从分类分布中抽取样本。
7
8    参数:
9        - **dtype** (mindspore.dtype) - 输出的类型。它的值必须是 mstype.int16mstype.int32mstype.int64 之一。默认值: ``mstype.int64`` 。
10
11    输入:
12        - **logits** (Tensor) - 输入Tensor。Shape为 :math:`(batch\_size, num\_classes)` 的二维Tensor。
13        - **num_sample** (int) - 要抽取的样本数。只允许使用常量值。
14        - **seed** (int) - 随机种子。只允许使用常量值。默认值: ``0`` 。
15
16    输出:
17        - **output** (Tensor) - Shape为 :math:`(batch\_size, num\_samples)` 的输出Tensor。
18
19    异常:
20        - **TypeError** - 如果 `dtype` 不是以下之一:mstype.int16mstype.int32mstype.int6421        - **TypeError** - 如果 `logits` 不是Tensor。
22        - **TypeError** - 如果 `num_sample` 或者 `seed` 不是 int。
23