# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stateless random ops which take seed as a tensor input."""
import enum
import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_random_index_shuffle_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


ops.NotDifferentiable("StatelessMultinomial")
ops.NotDifferentiable("StatelessRandomBinomial")
ops.NotDifferentiable("StatelessRandomNormal")
ops.NotDifferentiable("StatelessRandomPoisson")
ops.NotDifferentiable("StatelessRandomUniform")
ops.NotDifferentiable("StatelessRandomUniformInt")
ops.NotDifferentiable("StatelessRandomUniformFullInt")
ops.NotDifferentiable("StatelessTruncatedNormal")
ops.NotDifferentiable("StatelessRandomNormalV2")
ops.NotDifferentiable("StatelessRandomUniformV2")
ops.NotDifferentiable("StatelessRandomUniformIntV2")
ops.NotDifferentiable("StatelessRandomUniformFullIntV2")
ops.NotDifferentiable("StatelessTruncatedNormalV2")
ops.NotDifferentiable("StatelessRandomShuffle")
ops.NotDifferentiable("RandomIndexShuffle")


@tf_export("random.Algorithm", "random.experimental.Algorithm")
class Algorithm(enum.Enum):
  # The numbers here must match framework/rng_alg.h
  PHILOX = 1
  THREEFRY = 2
  AUTO_SELECT = 3


def convert_alg_to_int(alg):
  """Converts algorithm to an integer.
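
  For example (the integer values follow directly from the `Algorithm` enum
  above):

  ```python
  convert_alg_to_int("philox")            # -> 1 (Algorithm.PHILOX.value)
  convert_alg_to_int(Algorithm.THREEFRY)  # -> 2
  ```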

  Args:
    alg: can be one of these types: integer, Algorithm, Tensor, string. Allowed
      strings are "philox", "threefry", and "auto_select".

  Returns:
    An integer, unless the input is a Tensor in which case a Tensor is returned.
  """
  if isinstance(alg, int):
    return alg
  if isinstance(alg, Algorithm):
    return alg.value
  if isinstance(alg, ops.Tensor):
    return alg
  if isinstance(alg, str):
    if alg == "philox":
      return Algorithm.PHILOX.value
    elif alg in ("threefry", "three-fry", "three_fry"):
      return Algorithm.THREEFRY.value
    elif alg in ("autoselect", "auto-select", "auto_select"):
      return Algorithm.AUTO_SELECT.value
    else:
      raise ValueError(
          f"Argument `alg` got unsupported string value {alg}. Supported "
          f"string values are 'philox' for the Philox algorithm, 'threefry' "
          f"for the ThreeFry algorithm, and 'auto_select' for auto-selection.")
  else:
    raise TypeError(
        f"Can't convert argument `alg` (of value {alg} and type {type(alg)}) "
        f"to int.")


def _resolve_alg(alg):
  if alg == Algorithm.AUTO_SELECT.value:
    return gen_stateless_random_ops_v2.stateless_random_get_alg()
  return alg


def _get_key_counter(seed, alg):
  """Calculates the key and counter to pass to raw RNG ops.

  This function calculates the key and counter that will be passed to
  the raw RNG ops like `StatelessRandomUniformV2`. Depending on the
  input `alg`, the key and counter may be scrambled or copied from
  `seed`. If `alg` is `"auto_select"`, the key and counter will be
  determined at runtime based on device type.

  Args:
    seed: An integer tensor of shape [2]. The seed to calculate the
      key and counter from.
    alg: The RNG algorithm. See `tf.random.stateless_uniform` for an
      explanation.

  Returns:
    A pair (key, counter) suitable for V2 stateless RNG ops like
    `StatelessRandomUniformV2`.
  """
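  # Branch summary (derived from the code below): Philox yields a key of shape
  # [1] and a counter of shape [2]; ThreeFry yields a shape-[1] key and
  # counter; auto-selection defers both to a runtime op.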
  if alg == Algorithm.AUTO_SELECT.value:
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed)
  elif alg == Algorithm.PHILOX.value:
    key, counter = _philox_scramble_seed(seed)
  elif alg == Algorithm.THREEFRY.value:
    key = array_ops.reshape(
        uint32s_to_uint64(math_ops.cast(seed, dtypes.uint32)), [1])
    counter = array_ops.zeros([1], dtypes.uint64)
  else:
    raise ValueError(
        f"Argument `alg` got unsupported value {alg}. Supported values are "
        f"{Algorithm.PHILOX.value} for the Philox algorithm, "
        f"{Algorithm.THREEFRY.value} for the ThreeFry algorithm, and "
        f"{Algorithm.AUTO_SELECT.value} for auto-selection.")
  return key, counter


def _get_key_counter_alg(seed, alg):
  if alg is None:
    alg = Algorithm.AUTO_SELECT.value
  alg = convert_alg_to_int(alg)
  key, counter = _get_key_counter(seed, alg)
  if compat.forward_compatible(2021, 8, 11):
    return key, counter, alg
  else:
    return key, counter, _resolve_alg(alg)


def _philox_scramble_seed(seed):
  # The same scrambling procedure as core/kernels/stateless_random_ops.cc.
  key = constant_op.constant([0x02461e293ec8f720], dtypes.uint64)
  counter = math_ops.cast(seed, dtypes.uint64)
  mix = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
      [4], key=key, counter=counter, dtype=dtypes.uint32,
      alg=Algorithm.PHILOX.value)
  key = array_ops.reshape(uint32s_to_uint64(mix[:2]), [1])
  counter = array_ops.stack([0, uint32s_to_uint64(mix[2:])], axis=0)
  return key, counter


def uint32s_to_uint64(x):
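  """Packs two uint32s into a uint64; x[0] is the low word, x[1] the high."""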
  return bitwise_ops.bitwise_or(
      math_ops.cast(x[0], dtypes.uint64),
      bitwise_ops.left_shift(math_ops.cast(x[1], dtypes.uint64),
                             constant_op.constant(32, dtypes.uint64)))


@tf_export("random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2, alg="auto_select"):
  """Splits an RNG seed into `num` new seeds by adding a leading axis.

  Example:

  >>> seed = [1, 2]
  >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3)
  >>> print(new_seeds)
  tf.Tensor(
  [[1105988140 1738052849]
   [-335576002  370444179]
   [  10670227 -246211131]], shape=(3, 2), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :])
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.59835213, -0.9578608 ,
  0.9002807 ], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    num: optional, a positive integer or scalar tensor indicating the number of
      seeds to produce (default 2).
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A tensor with shape [num, 2] representing `num` new seeds. It will have the
    same dtype as `seed` (if `seed` doesn't have an explicit dtype, the dtype
    will be determined by `tf.convert_to_tensor`).
  """
  seed = ops.convert_to_tensor(seed)
  return stateless_random_uniform(shape=[num, 2], seed=seed, dtype=seed.dtype,
                                  minval=None, maxval=None, alg=alg)


@tf_export("random.experimental.stateless_fold_in")
@dispatch.add_dispatch_support
def fold_in(seed, data, alg="auto_select"):
  """Folds in data to an RNG seed to form a new RNG seed.

  For example, in a distributed-training setting, suppose we have a master seed
  and a replica ID. We want to fold the replica ID into the master seed to
  form a "replica seed" to be used by that replica later on, so that different
  replicas will generate different random numbers but the reproducibility of the
  whole system can still be controlled by the master seed:

  >>> master_seed = [1, 2]
  >>> replica_id = 3
  >>> replica_seed = tf.random.experimental.stateless_fold_in(
  ...   master_seed, replica_id)
  >>> print(replica_seed)
  tf.Tensor([1105988140          3], shape=(2,), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=replica_seed)
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.03197195, 0.8979765 ,
  0.13253039], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    data: an `int32` or `int64` scalar representing data to be folded in to the
      seed.
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A new RNG seed that is a deterministic function of the inputs and is
    statistically safe for producing a stream of new pseudo-random values. It
    will have the same dtype as `data` (if `data` doesn't have an explicit dtype,
    the dtype will be determined by `tf.convert_to_tensor`).
  """
  data = ops.convert_to_tensor(data)
  seed1 = stateless_random_uniform(shape=[], seed=seed, dtype=data.dtype,
                                   minval=None, maxval=None, alg=alg)
  return array_ops.stack([seed1, data])


@tf_export("random.experimental.index_shuffle")
@dispatch.add_dispatch_support
def index_shuffle(index, seed, max_index):
  """Outputs the position of `index` in a permutation of [0, ..., max_index].

  For each possible `seed` and `max_index` there is one pseudorandom permutation
  of the sequence S=[0, ..., max_index]. Instead of materializing the full array
  we can compute the new position of any single element in S. This can be useful
  for very large values of `max_index`.

  The input `index` and output can be used as indices to shuffle a vector.
  For example:

  >>> vector = tf.constant(['e0', 'e1', 'e2', 'e3'])
  >>> indices = tf.random.experimental.index_shuffle(tf.range(4), [5, 9], 3)
  >>> shuffled_vector = tf.gather(vector, indices)
  >>> print(shuffled_vector)
  tf.Tensor([b'e2' b'e0' b'e1' b'e3'], shape=(4,), dtype=string)

  More usefully, it can be used in a streaming (aka online) scenario such as
  `tf.data`, where each element of `vector` is processed individually and the
  whole `vector` is never materialized in memory.

  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.map(
  ...  lambda idx: tf.random.experimental.index_shuffle(idx, [5, 8], 9))
  >>> print(list(dataset.as_numpy_iterator()))
  [3, 8, 0, 1, 2, 7, 6, 9, 4, 5]

  This operation is stateless (like other `tf.random.stateless_*` functions),
  meaning the output is fully determined by the `seed` (other inputs being
  equal). Each `seed` choice corresponds to one permutation, so when calling
  this function multiple times for the same shuffling, please make sure to use
  the same `seed`. For example:

  >>> seed = [5, 9]
  >>> idx0 = tf.random.experimental.index_shuffle(0, seed, 3)
  >>> idx1 = tf.random.experimental.index_shuffle(1, seed, 3)
  >>> idx2 = tf.random.experimental.index_shuffle(2, seed, 3)
  >>> idx3 = tf.random.experimental.index_shuffle(3, seed, 3)
  >>> shuffled_vector = tf.gather(vector, [idx0, idx1, idx2, idx3])
  >>> print(shuffled_vector)
  tf.Tensor([b'e2' b'e0' b'e1' b'e3'], shape=(4,), dtype=string)

  Args:
    index: An integer scalar tensor or vector with values in [0, `max_index`].
      It can be seen as either a value `v` in the sequence `S`=[0, ...,
      `max_index`] to be permuted, or as an index of an element `e` in a
      shuffled vector.
    seed: A tensor of shape [2] or [n, 2] with dtype int32/uint32/int64/uint64.
      The RNG seed. If the rank is unknown during graph building it must be 1 at
      runtime.
    max_index: A non-negative tensor with the same shape and dtype as `index`.
      The upper bound (inclusive).

  Returns:
    If all inputs were scalar (shape [2] for `seed`) the output will be a scalar
    with the same dtype as `index`. The output can be seen as the new position
    of `v` in `S`, or as the index of `e` in the vector before shuffling.
    If one or multiple inputs were vectors (shape [n, 2] for `seed`) then the
    output will be a vector of the same size, with each element shuffled
    independently. Scalar values are broadcasted in this case.
  """
  # We expect users to pass a seed with shape [2] to be consistent with other
  # stateless_* ops, but the raw op expects shape [3].
  seed = ops.convert_to_tensor(seed)
  # Pad the first dimension with an arbitrary number since our raw op expects
  # shape [3].
  if seed.shape.rank is None:
    paddings = [[1, 0]]
  else:
    paddings = [[1, 0]] + (seed.shape.rank - 1) * [[0, 0]]
  seed = array_ops.pad(seed, paddings, constant_values=498247692)
  return gen_random_index_shuffle_ops.random_index_shuffle(
      index, seed=seed, max_index=max_index)


@tf_export("random.experimental.stateless_shuffle")
@dispatch.add_dispatch_support
def stateless_shuffle(value, seed, alg="auto_select", name=None):
  """Randomly and deterministically shuffles a tensor along its first dimension.

  The tensor is shuffled along dimension 0, such that each `value[j]` is mapped
  to one and only one `output[i]`. For example, a mapping that might occur for a
  3x2 tensor is:

  ```python
  [[1, 2],       [[5, 6],
   [3, 4],  ==>   [1, 2],
   [5, 6]]        [3, 4]]
  ```

  >>> v = tf.constant([[1, 2], [3, 4], [5, 6]])
  >>> shuffled = tf.random.experimental.stateless_shuffle(v, seed=[8, 9])
  >>> print(shuffled)
  tf.Tensor(
  [[5 6]
   [1 2]
   [3 4]], shape=(3, 2), dtype=int32)

  This is a stateless version of `tf.random.shuffle`: if run twice with the
  same `value` and `seed`, it will produce the same result.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Args:
    value: A Tensor to be shuffled.
    seed: A shape [2] Tensor. The seed to the random number generator. Must have
      dtype `int32` or `int64`.
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.
    name: A name for the operation.

  Returns:
    A tensor of same shape and type as `value`, shuffled along its first
    dimension.
  """
  with ops.name_scope(name, "stateless_shuffle", [value, seed]) as name:
    key, counter, alg = _get_key_counter_alg(seed, alg)
    return gen_stateless_random_ops_v2.stateless_shuffle(
        value, key=key, counter=counter, alg=alg)


@tf_export("random.stateless_uniform")
@dispatch.add_dispatch_support
def stateless_random_uniform(shape,
                             seed,
                             minval=0,
                             maxval=None,
                             dtype=dtypes.float32,
                             name=None,
                             alg="auto_select"):
  """Outputs deterministic pseudorandom values from a uniform distribution.

  This is a stateless version of `tf.random.uniform`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  The generated values follow a uniform distribution in the range
  `[minval, maxval)`. The lower bound `minval` is included in the range, while
  the upper bound `maxval` is excluded.

  For floats, the default range is `[0, 1)`.  For ints, at least `maxval` must
  be specified explicitly.

  In the integer case, the random integers are slightly biased unless
  `maxval - minval` is an exact power of two.  The bias is small for values of
  `maxval - minval` significantly smaller than the range of the output (either
  `2**32` or `2**64`).

  For full-range (i.e. inclusive of both max and min) random integers, pass
  `minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
  either both `minval` and `maxval` must be `None` or neither may be `None`. For
  example:
  ```python
  ints = tf.random.stateless_uniform(
      [10], seed=(2, 3), minval=None, maxval=None, dtype=tf.int32)
  ```
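
  For floating-point dtypes the bounds keep their defaults (`minval=0`,
  `maxval=1`), so the simplest call is just (a sketch; the actual values depend
  on the selected algorithm and device):
  ```python
  floats = tf.random.stateless_uniform([10], seed=(2, 3))  # float32 in [0, 1)
  ```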

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    minval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The lower bound on the range of random values to
      generate. Pass `None` for full-range integers.  Defaults to 0.
    maxval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The upper bound on the range of random values to generate.
      Defaults to 1 if `dtype` is floating point. Pass `None` for full-range
      integers.
    dtype: The type of the output: `float16`, `bfloat16`, `float32`, `float64`,
      `int32`, or `int64`. For unbounded uniform ints (`minval`, `maxval` both
      `None`), `uint32` and `uint64` may be used. Defaults to `float32`.
    name: A name for the operation (optional).
    alg: The RNG algorithm used to generate the random numbers. Valid
      choices are `"philox"` for [the Philox
      algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf),
      `"threefry"` for [the ThreeFry
      algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf),
      and `"auto_select"` (default) for the system to automatically
      select an algorithm based on the device type. Values of
      `tf.random.Algorithm` can also be used. Note that with
      `"auto_select"`, the outputs of this function may change when
      it is running on a different device.

  Returns:
    A tensor of the specified shape filled with random uniform values.

  Raises:
    ValueError: If `dtype` is integral and only one of `minval` or `maxval` is
      specified.
  """
  dtype = dtypes.as_dtype(dtype)
  accepted_dtypes = (dtypes.float16, dtypes.bfloat16, dtypes.float32,
                     dtypes.float64, dtypes.int32, dtypes.int64, dtypes.uint32,
                     dtypes.uint64)
  if dtype not in accepted_dtypes:
    raise ValueError(
        f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are "
        f"{accepted_dtypes}.")
  if dtype.is_integer:
    if (minval is None) != (maxval is None):
      raise ValueError(
          f"For integer `dtype` argument {dtype}, argument `minval` and "
          f"`maxval` must be both None or not None. Got `minval`={minval} and "
          f"`maxval`={maxval}.")
    if minval is not None and dtype in (dtypes.uint32, dtypes.uint64):
      raise ValueError(
          f"Argument `dtype` got invalid value {dtype} when argument `minval` "
          f"is not None. Please don't use unsigned integers in this case.")
  elif maxval is None:
    maxval = 1
  with ops.name_scope(name, "stateless_random_uniform",
                      [shape, seed, minval, maxval]) as name:
    shape = tensor_util.shape_tensor(shape)
    if dtype.is_integer and minval is None:
      key, counter, alg = _get_key_counter_alg(seed, alg)
      result = (
          gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
              shape, key=key, counter=counter, dtype=dtype, alg=alg, name=name))
    else:
      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
      if dtype.is_integer:
        key, counter, alg = _get_key_counter_alg(seed, alg)
        result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
            shape,
            key=key,
            counter=counter,
            minval=minval,
            maxval=maxval,
            alg=alg,
            name=name)
      else:
        key, counter, alg = _get_key_counter_alg(seed, alg)
        rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
            shape, key=key, counter=counter, dtype=dtype, alg=alg)
        result = math_ops.add(rnd * (maxval - minval), minval, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_binomial")
@dispatch.add_dispatch_support
def stateless_random_binomial(shape,
                              seed,
                              counts,
                              probs,
                              output_dtype=dtypes.int32,
                              name=None):
  """Outputs deterministic pseudorandom values from a binomial distribution.

  The generated values follow a binomial distribution with specified count and
  probability of success parameters.

  This is a stateless version of `tf.random.Generator.binomial`: if run twice
  with the same seeds and shapes, it will produce the same pseudorandom numbers.
  The output is consistent across multiple runs on the same hardware (and
  between CPU and GPU), but may change between versions of TensorFlow or on
  non-CPU/GPU hardware.

  Example:

  ```python
  counts = [10., 20.]
  # Probability of success.
  probs = [0.8]

  binomial_samples = tf.random.stateless_binomial(
      shape=[2], seed=[123, 456], counts=counts, probs=probs)

  counts = ... # Shape [3, 1, 2]
  probs = ...  # Shape [1, 4, 2]
  shape = [3, 4, 3, 4, 2]
  # Sample shape will be [3, 4, 3, 4, 2]
  binomial_samples = tf.random.stateless_binomial(
      shape=shape, seed=[123, 456], counts=counts, probs=probs)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    counts: Tensor. The counts of the binomial distribution. Must be
      broadcastable with `probs`, and broadcastable with the rightmost
      dimensions of `shape`.
    probs: Tensor. The probability of success for the binomial distribution.
      Must be broadcastable with `counts` and broadcastable with the rightmost
      dimensions of `shape`.
    output_dtype: The type of the output. Default: tf.int32
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random binomial
      values.  For each i, each samples[..., i] is an independent draw from
      the binomial distribution on counts[i] trials with probability of
      success probs[i].

  """
  with ops.name_scope(name, "stateless_random_binomial",
                      [shape, seed, counts, probs]) as name:
    shape = tensor_util.shape_tensor(shape)
    probs = ops.convert_to_tensor(
        probs, dtype_hint=dtypes.float32, name="probs")
    counts = ops.convert_to_tensor(
        counts, dtype_hint=probs.dtype, name="counts")
    result = gen_stateless_random_ops.stateless_random_binomial(
        shape=shape, seed=seed, counts=counts, probs=probs, dtype=output_dtype)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_gamma")
@dispatch.add_dispatch_support
def stateless_random_gamma(shape,
                           seed,
                           alpha,
                           beta=None,
                           dtype=dtypes.float32,
                           name=None):
  """Outputs deterministic pseudorandom values from a gamma distribution.

  The generated values follow a gamma distribution with specified concentration
  (`alpha`) and inverse scale (`beta`) parameters.

  This is a stateless version of `tf.random.gamma`: if run twice with the same
  seeds and shapes, it will produce the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware (and between CPU and
  GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_gamma` and `gamma`: in `gamma`, the `shape` is always
  prepended to the shape of the broadcast of `alpha` with `beta`; whereas in
  `stateless_gamma` the `shape` parameter must always encompass the shapes of
  each of `alpha` and `beta` (which must broadcast together to match the
  trailing dimensions of `shape`).

  Note: Because internal calculations are done using `float64` and casting has
  `floor` semantics, we must manually map zero outcomes to the smallest
  possible positive floating-point value, i.e., `np.finfo(dtype).tiny`.  This
  means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
  should.  This bias can only happen for small values of `alpha`, i.e.,
  `alpha << 1` or large values of `beta`, i.e., `beta >> 1`.

  The samples are differentiable w.r.t. alpha and beta.
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  Example:

  ```python
  samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  alpha = tf.constant([[1.], [3.], [5.]])
  beta = tf.constant([[3., 4.]])
  samples = tf.random.stateless_gamma(
      [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)
  # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.

  with tf.GradientTape() as tape:
    tape.watch([alpha, beta])
    loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma(
        [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)))
  dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta])
  # unbiased stochastic derivatives of the loss function
  alpha.shape == dloss_dalpha.shape  # True
  beta.shape == dloss_dbeta.shape  # True
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    alpha: Tensor. The concentration parameter of the gamma distribution. Must
      be broadcastable with `beta`, and broadcastable with the rightmost
      dimensions of `shape`.
    beta: Tensor. The inverse scale parameter of the gamma distribution. Must be
      broadcastable with `alpha` and broadcastable with the rightmost dimensions
      of `shape`.
    dtype: Floating point dtype of `alpha`, `beta`, and the output.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random gamma values.
      For each i, each `samples[..., i]` is an independent draw from the gamma
      distribution with concentration alpha[i] and scale beta[i].

  """
  with ops.name_scope(name, "stateless_random_gamma",
                      [shape, seed, alpha, beta]) as name:
    shape = tensor_util.shape_tensor(shape)
    alpha = ops.convert_to_tensor(alpha, dtype=dtype, name="alpha")
    beta = ops.convert_to_tensor(
        beta if beta is not None else 1, name="beta", dtype=dtype)
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(alpha), array_ops.shape(beta))
    alpha_broadcast = array_ops.broadcast_to(alpha, broadcast_shape)
    result = math_ops.maximum(
        np.finfo(alpha.dtype.as_numpy_dtype).tiny,
        gen_stateless_random_ops.stateless_random_gamma_v2(
            shape, seed=seed, alpha=alpha_broadcast) / beta)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_poisson")
@dispatch.add_dispatch_support
def stateless_random_poisson(shape,
                             seed,
                             lam,
                             dtype=dtypes.int32,
                             name=None):
  """Outputs deterministic pseudorandom values from a Poisson distribution.

  The generated values follow a Poisson distribution with specified rate
  parameter.

  This is a stateless version of `tf.random.poisson`: if run twice with the same
  seeds and shapes, it will produce the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware, but may change between
  versions of TensorFlow or on non-CPU/GPU hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_poisson` and `poisson`: in `poisson`, the `shape` is always
  prepended to the shape of `lam`; whereas in `stateless_poisson` the shape of
  `lam` must match the trailing dimensions of `shape`.

  Example:

  ```python
  samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_poisson([7, 5, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  rate = tf.constant([[1.], [3.], [5.]])
  samples = tf.random.stateless_poisson([30, 3, 1], seed=[12, 34], lam=rate)
  # samples has shape [30, 3, 1], with 30 samples each of 3x1 distributions.
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Shape
      must match the rightmost dimensions of `shape`.
    dtype: Dtype of the samples (int or float dtypes are permissible, as samples
      are discrete). Default: int32.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random Poisson values.
      For each i, each `samples[..., i]` is an independent draw from the Poisson
      distribution with rate `lam[i]`.

  """
  with ops.name_scope(name, "stateless_random_poisson",
                      [shape, seed, lam]) as name:
    shape = tensor_util.shape_tensor(shape)
    result = gen_stateless_random_ops.stateless_random_poisson(
        shape, seed=seed, lam=lam, dtype=dtype)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_normal")
@dispatch.add_dispatch_support
def stateless_random_normal(shape,
                            seed,
                            mean=0.0,
                            stddev=1.0,
                            dtype=dtypes.float32,
                            name=None,
                            alg="auto_select"):
  """Outputs deterministic pseudorandom values from a normal distribution.

  This is a stateless version of `tf.random.normal`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.
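
  Example (a sketch; the values produced depend on the selected RNG algorithm
  and device, but are reproducible for a fixed seed on the same setup):

  ```python
  samples = tf.random.stateless_normal(shape=[2, 3], seed=[1, 2])
  # The same call with the same seed and shape returns the same samples.
  samples_again = tf.random.stateless_normal(shape=[2, 3], seed=[1, 2])
  ```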

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
      distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution.
    dtype: The float type of the output: `float16`, `bfloat16`, `float32`,
      `float64`. Defaults to `float32`.
    name: A name for the operation (optional).
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A tensor of the specified shape filled with random normal values.
  """
  with ops.name_scope(name, "stateless_random_normal",
                      [shape, seed, mean, stddev]) as name:
    shape = tensor_util.shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    key, counter, alg = _get_key_counter_alg(seed, alg)
    rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
        shape, key=key, counter=counter, dtype=dtype, alg=alg)
    result = math_ops.add(rnd * stddev, mean, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_truncated_normal")
@dispatch.add_dispatch_support
def stateless_truncated_normal(shape,
                               seed,
                               mean=0.0,
                               stddev=1.0,
                               dtype=dtypes.float32,
                               name=None,
                               alg="auto_select"):
  """Outputs deterministic pseudorandom values, truncated normally distributed.

  This is a stateless version of `tf.random.truncated_normal`: if run twice with
  the same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  The generated values follow a normal distribution with specified mean and
  standard deviation, except that values whose magnitude is more than 2 standard
  deviations from the mean are dropped and re-picked.
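
  Example (a sketch; as above, the exact values depend on the selected RNG
  algorithm and device):

  ```python
  samples = tf.random.stateless_truncated_normal(
      shape=[2, 3], seed=[3, 4], mean=0.0, stddev=1.0)
  # All returned values lie within 2 standard deviations of the mean.
  ```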

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
      truncated normal distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution, before truncation.
    dtype: The type of the output.
    name: A name for the operation (optional).
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  with ops.name_scope(name, "stateless_truncated_normal",
                      [shape, seed, mean, stddev]) as name:
    shape = tensor_util.shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    key, counter, alg = _get_key_counter_alg(seed, alg)
    rnd = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
        shape, key=key, counter=counter, dtype=dtype, alg=alg)
    result = math_ops.add(rnd * stddev, mean, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export(v1=["random.stateless_multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
    date=None, instructions="Use `tf.random.stateless_categorical` instead.")
def stateless_multinomial(logits,
                          num_samples,
                          seed,
                          output_dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a multinomial distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Example:

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_categorical(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    output_dtype: The integer type of the output: `int32` or `int64`. Defaults
      to `int64`.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
    return stateless_multinomial_categorical_impl(logits, num_samples,
                                                  output_dtype, seed)


@tf_export("random.stateless_categorical")
@dispatch.add_dispatch_support
def stateless_categorical(logits,
                          num_samples,
                          seed,
                          dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a categorical distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Example:

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_categorical(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    dtype: The integer type of the output: `int32` or `int64`. Defaults to
      `int64`.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_categorical", [logits, seed]):
    return stateless_multinomial_categorical_impl(logits, num_samples, dtype,
                                                  seed)


def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
  """Implementation for stateless multinomial/categorical ops (v1/v2)."""
  logits = ops.convert_to_tensor(logits, name="logits")
  dtype = dtypes.as_dtype(dtype) if dtype else dtypes.int64
  accepted_dtypes = (dtypes.int32, dtypes.int64)
  if dtype not in accepted_dtypes:
    raise ValueError(
        f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are "
        f"{accepted_dtypes}.")
  return gen_stateless_random_ops.stateless_multinomial(
      logits, num_samples, seed, output_dtype=dtype)


@dispatch.add_dispatch_support
@tf_export("random.stateless_parameterized_truncated_normal")
def stateless_parameterized_truncated_normal(shape,
                                             seed,
                                             means=0.0,
                                             stddevs=1.0,
                                             minvals=-2.0,
                                             maxvals=2.0,
                                             name=None):
  """Outputs random values from a truncated normal distribution.

  The generated values follow a normal distribution with specified mean and
  standard deviation, except that values whose magnitude is more than 2 standard
  deviations from the mean are dropped and re-picked.

  Examples:

  Sample from a truncated normal, with differing shape parameters that
  broadcast.

  >>> means = 0.
  >>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
  >>> minvals = [-1., -2., -1000.]
  >>> maxvals = [[10000.], [1.]]
  >>> y = tf.random.stateless_parameterized_truncated_normal(
  ...   shape=[10, 2, 3], seed=[7, 17],
  ...   means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
  >>> y.shape
  TensorShape([10, 2, 3])

  Args:
    shape: A 1-D integer `Tensor` or Python array. The shape of the output
      tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    means: A `Tensor` or Python value of type `dtype`. The mean of the truncated
      normal distribution. This must broadcast with `stddevs`, `minvals` and
      `maxvals`, and the broadcasted shape must be dominated by `shape`.
    stddevs: A `Tensor` or Python value of type `dtype`. The standard deviation
      of the truncated normal distribution. This must broadcast with `means`,
      `minvals` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    minvals: A `Tensor` or Python value of type `dtype`. The minimum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    maxvals: A `Tensor` or Python value of type `dtype`. The maximum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `minvals`, and the broadcasted shape must be dominated by
      `shape`.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  with ops.name_scope(name, "stateless_parameterized_truncated_normal",
                      [shape, means, stddevs, minvals, maxvals]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    means_tensor = ops.convert_to_tensor(means, name="means")
    stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs")
    minvals_tensor = ops.convert_to_tensor(minvals, name="minvals")
    maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals")
    rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal(
        shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor,
        maxvals_tensor)
    tensor_util.maybe_set_static_shape(rnd, shape)
    return rnd