• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Stateless random ops which take seed as a tensor input."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import enum
22import numpy as np
23import six
24
25from tensorflow.python.compat import compat
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import bitwise_ops
32from tensorflow.python.ops import gen_stateless_random_ops
33from tensorflow.python.ops import gen_stateless_random_ops_v2
34from tensorflow.python.ops import math_ops
35from tensorflow.python.util import deprecation
36from tensorflow.python.util import dispatch
37from tensorflow.python.util.tf_export import tf_export
38
39
# Neither the original (seed-as-tensor) stateless RNG ops nor their V2
# (key/counter) counterparts have gradients; register them as such.
for _stateless_op_type in (
    # V1 ops.
    "StatelessMultinomial",
    "StatelessRandomBinomial",
    "StatelessRandomNormal",
    "StatelessRandomPoisson",
    "StatelessRandomUniform",
    "StatelessRandomUniformInt",
    "StatelessRandomUniformFullInt",
    "StatelessTruncatedNormal",
    # V2 ops.
    "StatelessRandomNormalV2",
    "StatelessRandomUniformV2",
    "StatelessRandomUniformIntV2",
    "StatelessRandomUniformFullIntV2",
    "StatelessTruncatedNormalV2",
):
  ops.NotDifferentiable(_stateless_op_type)
# Don't leak the loop variable into the module namespace.
del _stateless_op_type
55
56
@tf_export("random.Algorithm", "random.experimental.Algorithm")
class Algorithm(enum.Enum):
  """Counter-based RNG algorithms understood by the stateless random ops.

  The integer values are handed to the C++ kernels and therefore must stay in
  sync with framework/rng_alg.h.
  """
  PHILOX = 1  # The Philox algorithm.
  THREEFRY = 2  # The ThreeFry algorithm.
  AUTO_SELECT = 3  # Let the system pick an algorithm based on device type.
63
64
def convert_alg_to_int(alg):
  """Converts an RNG algorithm specification to its integer value.

  Args:
    alg: The algorithm, given as one of:
      * a Python integer (returned as-is) or a NumPy integer scalar
        (converted to a Python `int`);
      * an `Algorithm` enum member;
      * a `Tensor` (returned unchanged);
      * a string: "philox"; "threefry" (also "three-fry", "three_fry"); or
        "auto_select" (also "autoselect", "auto-select").

  Returns:
    An integer, unless the input is a Tensor in which case a Tensor is returned.

  Raises:
    ValueError: If `alg` is a string that does not name a supported algorithm.
    TypeError: If `alg` is of any other unsupported type.
  """
  if isinstance(alg, six.integer_types):
    return alg
  if isinstance(alg, np.integer):
    # NumPy integer scalars (e.g. np.int64) are not Python ints; normalize them
    # instead of rejecting them with a TypeError below.
    return int(alg)
  if isinstance(alg, Algorithm):
    return alg.value
  if isinstance(alg, ops.Tensor):
    return alg
  if isinstance(alg, str):
    if alg == "philox":
      return Algorithm.PHILOX.value
    elif alg in ("threefry", "three-fry", "three_fry"):
      return Algorithm.THREEFRY.value
    elif alg in ("autoselect", "auto-select", "auto_select"):
      return Algorithm.AUTO_SELECT.value
    else:
      raise ValueError(
          f"Argument `alg` got unsupported string value {alg}. Supported "
          f"string values are 'philox' for the Philox algorithm, 'threefry' "
          f"for the ThreeFry algorithm, and 'auto_select' for auto-selection.")
  else:
    raise TypeError(
        f"Can't convert argument `alg` (of value {alg} and type {type(alg)}) "
        f"to int.")
97
98
def _resolve_alg(alg):
  """Replaces `AUTO_SELECT` with the algorithm chosen by the runtime.

  Args:
    alg: An algorithm value as produced by `convert_alg_to_int`.

  Returns:
    The result of `stateless_random_get_alg()` when `alg` equals
    `Algorithm.AUTO_SELECT.value`; otherwise `alg` itself.
  """
  is_auto_select = alg == Algorithm.AUTO_SELECT.value
  return (gen_stateless_random_ops_v2.stateless_random_get_alg()
          if is_auto_select else alg)
103
104
def _get_key_counter(seed, alg):
  """Derives the key and counter consumed by the V2 raw RNG ops.

  The V2 ops (e.g. `StatelessRandomUniformV2`) take a (key, counter) pair
  instead of a seed. How that pair is obtained from `seed` depends on `alg`:
  for Philox the seed is scrambled, for ThreeFry it is packed into the key with
  a zero counter, and for auto-selection the runtime computes the pair based on
  the device type.

  Args:
    seed: An integer tensor of shape [2] from which key and counter are derived.
    alg: The integer RNG algorithm. See `tf.random.stateless_uniform` for an
      explanation.

  Returns:
    A `(key, counter)` tuple for the V2 stateless RNG ops.

  Raises:
    ValueError: If `alg` is not one of the known algorithm values.
  """
  if alg == Algorithm.AUTO_SELECT.value:
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed)
    return key, counter
  if alg == Algorithm.PHILOX.value:
    return _philox_scramble_seed(seed)
  if alg == Algorithm.THREEFRY.value:
    seed_as_uint32 = math_ops.cast(seed, dtypes.uint32)
    key = array_ops.reshape(uint32s_to_uint64(seed_as_uint32), [1])
    counter = array_ops.zeros([1], dtypes.uint64)
    return key, counter
  raise ValueError(
      f"Argument `alg` got unsupported value {alg}. Supported values are "
      f"{Algorithm.PHILOX.value} for the Philox algorithm, "
      f"{Algorithm.THREEFRY.value} for the ThreeFry algorithm, and "
      f"{Algorithm.AUTO_SELECT.value} for auto-selection.")
140
141
def _get_key_counter_alg(seed, alg):
  """Computes the `(key, counter, alg)` triple passed to the V2 RNG ops.

  Args:
    seed: An integer tensor of shape [2].
    alg: Any algorithm specification accepted by `convert_alg_to_int`, or
      `None`, which means auto-selection.

  Returns:
    A `(key, counter, alg)` tuple. Inside the forward-compatibility window the
    integer algorithm is returned as-is (possibly `AUTO_SELECT`); before it,
    `AUTO_SELECT` is resolved through `_resolve_alg`.
  """
  alg_id = convert_alg_to_int(
      Algorithm.AUTO_SELECT.value if alg is None else alg)
  key, counter = _get_key_counter(seed, alg_id)
  if not compat.forward_compatible(2021, 8, 11):
    # Presumably older kernels can't take AUTO_SELECT directly, so resolve it
    # to a concrete algorithm here.
    alg_id = _resolve_alg(alg_id)
  return key, counter, alg_id
151
152
def _philox_scramble_seed(seed):
  """Scrambles `seed` into a Philox (key, counter) pair.

  Follows the same procedure as core/kernels/stateless_random_ops.cc: a single
  Philox draw of four uint32 words, using a fixed key and `seed` as the
  counter. The first two words form the new key; the last two form the second
  element of the new counter (whose first element is 0).

  Args:
    seed: An integer tensor of shape [2].

  Returns:
    A `(key, counter)` tuple of uint64 tensors.
  """
  fixed_key = constant_op.constant([0x02461e293ec8f720], dtypes.uint64)
  seed_counter = math_ops.cast(seed, dtypes.uint64)
  words = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
      [4], key=fixed_key, counter=seed_counter, dtype=dtypes.uint32,
      alg=Algorithm.PHILOX.value)
  new_key = array_ops.reshape(uint32s_to_uint64(words[:2]), [1])
  packed_tail = uint32s_to_uint64(words[2:])
  new_counter = array_ops.stack([0, packed_tail], axis=0)
  return new_key, new_counter
163
164
def uint32s_to_uint64(x):
  """Packs two 32-bit words into one uint64 value.

  Args:
    x: An indexable with at least two elements; `x[0]` supplies the low 32 bits
      and `x[1]` the high 32 bits. Both are cast to uint64 first.

  Returns:
    `x[0] | (x[1] << 32)` as a uint64 tensor.
  """
  low_bits = math_ops.cast(x[0], dtypes.uint64)
  high_bits = math_ops.cast(x[1], dtypes.uint64)
  shift_amount = constant_op.constant(32, dtypes.uint64)
  return bitwise_ops.bitwise_or(
      low_bits, bitwise_ops.left_shift(high_bits, shift_amount))
170
171
@tf_export("random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2, alg="auto_select"):
  """Splits an RNG seed into `num` new seeds by adding a leading axis.

  Example:

  >>> seed = [1, 2]
  >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3)
  >>> print(new_seeds)
  tf.Tensor(
  [[1105988140 1738052849]
   [-335576002  370444179]
   [  10670227 -246211131]], shape=(3, 2), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :])
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.59835213, -0.9578608 ,
  0.9002807 ], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    num: optional, a positive integer or scalar tensor giving how many seeds
      to produce (default 2).
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A tensor of shape [num, 2] holding `num` new seeds. Its dtype matches that
    of `seed` (if `seed` has no explicit dtype, it is the dtype chosen by
    `tf.convert_to_tensor`).
  """
  seed_tensor = ops.convert_to_tensor(seed)
  # Full-range integer draws (minval=maxval=None) in the seed's own dtype give
  # `num` fresh seeds of two words each.
  new_seeds = stateless_random_uniform(
      shape=[num, 2], seed=seed_tensor, dtype=seed_tensor.dtype,
      minval=None, maxval=None, alg=alg)
  return new_seeds
206
207
@tf_export("random.experimental.stateless_fold_in")
@dispatch.add_dispatch_support
def fold_in(seed, data, alg="auto_select"):
  """Folds in data to an RNG seed to form a new RNG seed.

  A typical use is distributed training: given a master seed and a replica ID,
  fold the replica ID into the master seed to obtain a "replica seed" for that
  replica. Different replicas then produce different random numbers, while the
  master seed still controls reproducibility of the whole system:

  >>> master_seed = [1, 2]
  >>> replica_id = 3
  >>> replica_seed = tf.random.experimental.stateless_fold_in(
  ...   master_seed, replica_id)
  >>> print(replica_seed)
  tf.Tensor([1105988140          3], shape=(2,), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=replica_seed)
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.03197195, 0.8979765 ,
  0.13253039], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    data: an `int32` or `int64` scalar holding the data to fold into the seed.
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A new RNG seed that is a deterministic function of the inputs and is
    statistically safe for producing a stream of new pseudo-random values. Its
    dtype matches that of `data` (if `data` has no explicit dtype, it is the
    dtype chosen by `tf.convert_to_tensor`).
  """
  data_tensor = ops.convert_to_tensor(data)
  # One full-range scalar drawn from `seed` in `data`'s dtype becomes the first
  # word of the new seed; `data` itself is the second.
  first_word = stateless_random_uniform(
      shape=[], seed=seed, dtype=data_tensor.dtype,
      minval=None, maxval=None, alg=alg)
  return array_ops.stack([first_word, data_tensor])
247
248
@tf_export("random.stateless_uniform")
@dispatch.add_dispatch_support
def stateless_random_uniform(shape,
                             seed,
                             minval=0,
                             maxval=None,
                             dtype=dtypes.float32,
                             name=None,
                             alg="auto_select"):
  """Outputs deterministic pseudorandom values from a uniform distribution.

  This is a stateless version of `tf.random.uniform`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  The generated values follow a uniform distribution in the range
  `[minval, maxval)`. The lower bound `minval` is included in the range, while
  the upper bound `maxval` is excluded.

  For floats, the default range is `[0, 1)`.  For ints, at least `maxval` must
  be specified explicitly.

  In the integer case, the random integers are slightly biased unless
  `maxval - minval` is an exact power of two.  The bias is small for values of
  `maxval - minval` significantly smaller than the range of the output (either
  `2**32` or `2**64`).

  For full-range (i.e. inclusive of both max and min) random integers, pass
  `minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
  either both `minval` and `maxval` must be `None` or neither may be `None`. For
  example:
  ```python
  ints = tf.random.stateless_uniform(
      [10], seed=(2, 3), minval=None, maxval=None, dtype=tf.int32)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    minval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The lower bound on the range of random values to
      generate. Pass `None` for full-range integers.  Defaults to 0.
    maxval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The upper bound on the range of random values to generate.
      Defaults to 1 if `dtype` is floating point. Pass `None` for full-range
      integers.
    dtype: The type of the output: `float16`, `bfloat16`, `float32`, `float64`,
      `int32`, or `int64`. For unbounded uniform ints (`minval`, `maxval` both
      `None`), `uint32` and `uint64` may be used.
    name: A name for the operation (optional).
    alg: The RNG algorithm used to generate the random numbers. Valid
      choices are `"philox"` for [the Philox
      algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf),
      `"threefry"` for [the ThreeFry
      algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf),
      and `"auto_select"` (default) for the system to automatically
      select an algorithm based on the device type. Values of
      `tf.random.Algorithm` can also be used. Note that with
      `"auto_select"`, the outputs of this function may change when
      it is running on a different device.

  Returns:
    A tensor of the specified shape filled with random uniform values.

  Raises:
    ValueError: If `dtype` is not one of the accepted dtypes; if `dtype` is
      integral and only one of `minval` or `maxval` is specified; or if `dtype`
      is unsigned (`uint32`/`uint64`) and `minval`/`maxval` are not `None`.
  """
  dtype = dtypes.as_dtype(dtype)
  accepted_dtypes = (dtypes.float16, dtypes.bfloat16, dtypes.float32,
                     dtypes.float64, dtypes.int32, dtypes.int64, dtypes.uint32,
                     dtypes.uint64)
  if dtype not in accepted_dtypes:
    raise ValueError(
        f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are "
        f"{accepted_dtypes}.")
  if dtype.is_integer:
    if (minval is None) != (maxval is None):
      raise ValueError(
          f"For integer `dtype` argument {dtype}, argument `minval` and "
          f"`maxval` must be both None or not None. Got `minval`={minval} and "
          f"`maxval`={maxval}.")
    if minval is not None and dtype in (dtypes.uint32, dtypes.uint64):
      raise ValueError(
          f"Argument `dtype` got invalid value {dtype} when argument `minval` "
          f"is not None. Please don't use unsigned integers in this case.")
  elif maxval is None:
    maxval = 1
  with ops.name_scope(name, "stateless_random_uniform",
                      [shape, seed, minval, maxval]) as name:
    shape = tensor_util.shape_tensor(shape)
    # Every branch below consumes the same key/counter/alg triple, so derive it
    # once instead of repeating the call in each branch.
    key, counter, alg = _get_key_counter_alg(seed, alg)
    if dtype.is_integer and minval is None:
      # Full-range integers: no bounds to apply.
      result = (
          gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
              shape, key=key, counter=counter, dtype=dtype, alg=alg, name=name))
    else:
      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
      if dtype.is_integer:
        result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
            shape,
            key=key,
            counter=counter,
            minval=minval,
            maxval=maxval,
            alg=alg,
            name=name)
      else:
        # Scale the raw float draws into [minval, maxval); the final op
        # carries the user-visible name.
        rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
            shape, key=key, counter=counter, dtype=dtype, alg=alg)
        result = math_ops.add(rnd * (maxval - minval), minval, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result
370
371
@tf_export("random.stateless_binomial")
@dispatch.add_dispatch_support
def stateless_random_binomial(shape,
                              seed,
                              counts,
                              probs,
                              output_dtype=dtypes.int32,
                              name=None):
  """Outputs deterministic pseudorandom values from a binomial distribution.

  Samples follow a binomial distribution parameterized by a count (number of
  trials) and a probability of success.

  This is a stateless version of `tf.random.Generator.binomial`: running it
  twice with the same seeds and shapes yields the same pseudorandom numbers.
  The output is consistent across multiple runs on the same hardware (and
  between CPU and GPU), but may change between versions of TensorFlow or on
  non-CPU/GPU hardware.

  Example:

  ```python
  counts = [10., 20.]
  # Probability of success.
  probs = [0.8]

  binomial_samples = tf.random.stateless_binomial(
      shape=[2], seed=[123, 456], counts=counts, probs=probs)

  counts = ... # Shape [3, 1, 2]
  probs = ...  # Shape [1, 4, 2]
  shape = [3, 4, 3, 4, 2]
  # Sample shape will be [3, 4, 3, 4, 2]
  binomial_samples = tf.random.stateless_binomial(
      shape=shape, seed=[123, 456], counts=counts, probs=probs)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    counts: Tensor. The counts of the binomial distribution. Must be
      broadcastable with `probs`, and broadcastable with the rightmost
      dimensions of `shape`.
    probs: Tensor. The probability of success for the binomial distribution.
      Must be broadcastable with `counts` and broadcastable with the rightmost
      dimensions of `shape`.
    output_dtype: The type of the output. Default: tf.int32
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random binomial
      values.  For each i, each samples[..., i] is an independent draw from
      the binomial distribution on counts[i] trials with probability of
      success probs[i].

  """
  with ops.name_scope(name, "stateless_random_binomial",
                      [shape, seed, counts, probs]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    # `probs` defaults to float32; `counts` then follows whatever dtype
    # `probs` ended up with.
    probs_tensor = ops.convert_to_tensor(
        probs, dtype_hint=dtypes.float32, name="probs")
    counts_tensor = ops.convert_to_tensor(
        counts, dtype_hint=probs_tensor.dtype, name="counts")
    samples = gen_stateless_random_ops.stateless_random_binomial(
        shape=shape_tensor,
        seed=seed,
        counts=counts_tensor,
        probs=probs_tensor,
        dtype=output_dtype)
    tensor_util.maybe_set_static_shape(samples, shape_tensor)
    return samples
440
441
@tf_export("random.stateless_gamma")
@dispatch.add_dispatch_support
def stateless_random_gamma(shape,
                           seed,
                           alpha,
                           beta=None,
                           dtype=dtypes.float32,
                           name=None):
  """Outputs deterministic pseudorandom values from a gamma distribution.

  The generated values follow a gamma distribution with specified concentration
  (`alpha`) and inverse scale (`beta`) parameters.

  This is a stateless version of `tf.random.gamma`: if run twice with the same
  seeds and shapes, it will produce the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware (and between CPU and
  GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_gamma` and `gamma`: in `gamma`, the `shape` is always
  prepended to the shape of the broadcast of `alpha` with `beta`; whereas in
  `stateless_gamma` the `shape` parameter must always encompass the shapes of
  each of `alpha` and `beta` (which must broadcast together to match the
  trailing dimensions of `shape`).

  Note: Because internal calculations are done using `float64` and casting has
  `floor` semantics, we must manually map zero outcomes to the smallest
  possible positive floating-point value, i.e., `np.finfo(dtype).tiny`.  This
  means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
  should.  This bias can only happen for small values of `alpha`, i.e.,
  `alpha << 1` or large values of `beta`, i.e., `beta >> 1`.

  The samples are differentiable w.r.t. alpha and beta.
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  Example:

  ```python
  samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  alpha = tf.constant([[1.], [3.], [5.]])
  beta = tf.constant([[3., 4.]])
  samples = tf.random.stateless_gamma(
      [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)
  # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.

  with tf.GradientTape() as tape:
    tape.watch([alpha, beta])
    loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma(
        [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)))
  dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta])
  # unbiased stochastic derivatives of the loss function
  alpha.shape == dloss_dalpha.shape  # True
  beta.shape == dloss_dbeta.shape  # True
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    alpha: Tensor. The concentration parameter of the gamma distribution. Must
      be broadcastable with `beta`, and broadcastable with the rightmost
      dimensions of `shape`.
    beta: Tensor. The inverse scale parameter of the gamma distribution. Must be
      broadcastable with `alpha` and broadcastable with the rightmost dimensions
      of `shape`. Defaults to 1 when `None`.
    dtype: Floating point dtype of `alpha`, `beta`, and the output.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random gamma values.
      For each i, each `samples[..., i]` is an independent draw from the gamma
      distribution with concentration `alpha[i]` and inverse scale (rate)
      `beta[i]`.

  """
  with ops.name_scope(name, "stateless_random_gamma",
                      [shape, seed, alpha, beta]) as name:
    shape = tensor_util.shape_tensor(shape)
    alpha = ops.convert_to_tensor(alpha, dtype=dtype, name="alpha")
    beta = ops.convert_to_tensor(
        beta if beta is not None else 1, name="beta", dtype=dtype)
    # The kernel samples with rate 1 and needs alpha at the broadcast shape of
    # (alpha, beta); dividing by beta afterwards applies the inverse scale.
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(alpha), array_ops.shape(beta))
    alpha_broadcast = array_ops.broadcast_to(alpha, broadcast_shape)
    # Clamp to the smallest positive value so no sample is exactly zero (see
    # the note in the docstring).
    result = math_ops.maximum(
        np.finfo(alpha.dtype.as_numpy_dtype).tiny,
        gen_stateless_random_ops.stateless_random_gamma_v2(
            shape, seed=seed, alpha=alpha_broadcast) / beta)
    tensor_util.maybe_set_static_shape(result, shape)
    return result
540
541
@tf_export("random.stateless_poisson")
@dispatch.add_dispatch_support
def stateless_random_poisson(shape,
                             seed,
                             lam,
                             dtype=dtypes.int32,
                             name=None):
  """Outputs deterministic pseudorandom values from a Poisson distribution.

  Samples follow a Poisson distribution with the given rate parameter.

  This is a stateless version of `tf.random.poisson`: running it twice with the
  same seeds and shapes yields the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware, but may change between
  versions of TensorFlow or on non-CPU/GPU hardware.

  Note that `shape` is interpreted differently than in `poisson`: there, `shape`
  is always prepended to the shape of `lam`, whereas here the shape of `lam`
  must match the trailing dimensions of `shape`.

  Example:

  ```python
  samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_poisson([7, 5, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  rate = tf.constant([[1.], [3.], [5.]])
  samples = tf.random.stateless_poisson([30, 3, 1], seed=[12, 34], lam=rate)
  # samples has shape [30, 3, 1], with 30 samples each of 3x1 distributions.
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Its
      shape must match the rightmost dimensions of `shape`.
    dtype: Dtype of the samples (int or float dtypes are permissible, as samples
      are discrete). Default: int32.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random Poisson values.
      For each i, each `samples[..., i]` is an independent draw from the Poisson
      distribution with rate `lam[i]`.

  """
  with ops.name_scope(name, "stateless_random_poisson",
                      [shape, seed, lam]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    samples = gen_stateless_random_ops.stateless_random_poisson(
        shape_tensor, seed=seed, lam=lam, dtype=dtype)
    tensor_util.maybe_set_static_shape(samples, shape_tensor)
    return samples
603
604
@tf_export("random.stateless_normal")
@dispatch.add_dispatch_support
def stateless_random_normal(shape,
                            seed,
                            mean=0.0,
                            stddev=1.0,
                            dtype=dtypes.float32,
                            name=None,
                            alg="auto_select"):
  """Outputs deterministic pseudorandom values from a normal distribution.

  This is a stateless version of `tf.random.normal`: running it twice with the
  same seeds and shapes yields the same pseudorandom numbers.  The output is
  consistent across multiple runs on the same hardware (and between CPU and
  GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
      distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution.
    dtype: The type of the output.
    name: A name for the operation (optional).
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A tensor of the specified shape filled with random normal values.
  """
  with ops.name_scope(name, "stateless_random_normal",
                      [shape, seed, mean, stddev]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    key, counter, alg_id = _get_key_counter_alg(seed, alg)
    raw = gen_stateless_random_ops_v2.stateless_random_normal_v2(
        shape_tensor, key=key, counter=counter, dtype=dtype, alg=alg_id)
    # Scale by `stddev`, then shift by `mean`; the final op gets the user name.
    samples = math_ops.add(raw * stddev_tensor, mean_tensor, name=name)
    tensor_util.maybe_set_static_shape(samples, shape_tensor)
    return samples
649
650
@tf_export("random.stateless_truncated_normal")
@dispatch.add_dispatch_support
def stateless_truncated_normal(shape,
                               seed,
                               mean=0.0,
                               stddev=1.0,
                               dtype=dtypes.float32,
                               name=None,
                               alg="auto_select"):
  """Outputs deterministic pseudorandom values, truncated normally distributed.

  This is a stateless version of `tf.random.truncated_normal`: running it twice
  with the same seeds and shapes yields the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Values follow a normal distribution with the given mean and standard
  deviation, except that any value lying more than 2 standard deviations from
  the mean is dropped and re-picked.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
      truncated normal distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution, before truncation.
    dtype: The type of the output.
    name: A name for the operation (optional).
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  with ops.name_scope(name, "stateless_truncated_normal",
                      [shape, seed, mean, stddev]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    key, counter, alg_id = _get_key_counter_alg(seed, alg)
    raw = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
        shape_tensor, key=key, counter=counter, dtype=dtype, alg=alg_id)
    # Scale by `stddev`, then shift by `mean`; the final op gets the user name.
    samples = math_ops.add(raw * stddev_tensor, mean_tensor, name=name)
    tensor_util.maybe_set_static_shape(samples, shape_tensor)
    return samples
699
700
@tf_export(v1=["random.stateless_multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
    date=None, instructions="Use `tf.random.stateless_categorical` instead.")
def stateless_multinomial(logits,
                          num_samples,
                          seed,
                          output_dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a multinomial distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  This function is deprecated; new code should call
  `tf.random.stateless_categorical`, which takes the same arguments except that
  `output_dtype` is named `dtype`.

  Example:

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_multinomial(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    output_dtype: integer type to use for the output. Defaults to int64.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
    # Shares its implementation with the v2 `stateless_categorical` API.
    return stateless_multinomial_categorical_impl(logits, num_samples,
                                                  output_dtype, seed)
742
743
@tf_export("random.stateless_categorical")
@dispatch.add_dispatch_support
def stateless_categorical(logits,
                          num_samples,
                          seed,
                          dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a categorical distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Example:

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_categorical(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    dtype: integer type to use for the output. Defaults to int64.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_categorical", [logits, seed]):
    # Shares its implementation with the deprecated v1 `stateless_multinomial`.
    return stateless_multinomial_categorical_impl(logits, num_samples, dtype,
                                                  seed)
784
785
def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
  """Shared body of the stateless multinomial (v1) and categorical (v2) ops.

  Both `stateless_multinomial` and `stateless_categorical` delegate here after
  opening their own name scope; this function adds no scope of its own.

  Args:
    logits: 2-D Tensor (or value convertible to one) with shape
      `[batch_size, num_classes]` holding unnormalized log-probabilities.
    num_samples: 0-D.  Number of independent samples to draw per row.
    dtype: Integer dtype of the returned sample indices.
    seed: A shape [2] Tensor used as the seed of the random number generator.

  Returns:
    The drawn samples, of shape `[batch_size, num_samples]`.
  """
  logits_tensor = ops.convert_to_tensor(logits, name="logits")
  samples = gen_stateless_random_ops.stateless_multinomial(
      logits_tensor,
      num_samples,
      seed,
      output_dtype=dtype)
  return samples
791
792
@dispatch.add_dispatch_support
@tf_export("random.stateless_parameterized_truncated_normal")
def stateless_parameterized_truncated_normal(shape,
                                             seed,
                                             means=0.0,
                                             stddevs=1.0,
                                             minvals=-2.0,
                                             maxvals=2.0,
                                             name=None):
  """Outputs deterministic random values from a truncated normal distribution.

  This is a stateless op: if run twice with the same seed, shape and
  parameters, it will produce the same pseudorandom numbers.

  The generated values follow a normal distribution with mean `means` and
  standard deviation `stddevs`, truncated to the interval `[minvals, maxvals]`.
  All four distribution parameters may vary per element: they are broadcast
  against each other, and the broadcasted shape must be dominated by `shape`.
  With the default arguments this is a standard normal distribution truncated
  to `[-2, 2]`.

  Examples:

  Sample from a Truncated normal, with deferring shape parameters that
  broadcast.

  >>> means = 0.
  >>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
  >>> minvals = [-1., -2., -1000.]
  >>> maxvals = [[10000.], [1.]]
  >>> y = tf.random.stateless_parameterized_truncated_normal(
  ...   shape=[10, 2, 3], seed=[7, 17],
  ...   means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
  >>> y.shape
  TensorShape([10, 2, 3])

  Args:
    shape: A 1-D integer `Tensor` or Python array. The shape of the output
      tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    means: A `Tensor` or Python value of type `dtype`. The mean of the truncated
      normal distribution. This must broadcast with `stddevs`, `minvals` and
      `maxvals`, and the broadcasted shape must be dominated by `shape`.
    stddevs: A `Tensor` or Python value of type `dtype`. The standard deviation
      of the truncated normal distribution. This must broadcast with `means`,
      `minvals` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    minvals: A `Tensor` or Python value of type `dtype`. The minimum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    maxvals: A `Tensor` or Python value of type `dtype`. The maximum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `minvals`, and the broadcasted shape must be dominated by
      `shape`.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  # `seed` is listed with the other inputs, as every other op in this module
  # does, so the scope is opened in the graph that owns all of them.
  with ops.name_scope(name, "stateless_parameterized_truncated_normal",
                      [shape, seed, means, stddevs, minvals, maxvals]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    means_tensor = ops.convert_to_tensor(means, name="means")
    stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs")
    minvals_tensor = ops.convert_to_tensor(minvals, name="minvals")
    maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals")
    # Name the output op after the enclosing scope (previously `name` was
    # bound by the `with` statement but never used).
    rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal(
        shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor,
        maxvals_tensor, name=name)
    # Propagate a statically known `shape` to the result's static shape.
    tensor_util.maybe_set_static_shape(rnd, shape)
    return rnd
861