• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Stateless random ops which take seed as a tensor input."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.compat import compat
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import gen_stateless_random_ops
29from tensorflow.python.ops import gen_stateless_random_ops_v2
30from tensorflow.python.ops import math_ops
31from tensorflow.python.util import deprecation
32from tensorflow.python.util import dispatch
33from tensorflow.python.util.tf_export import tf_export
34
35
# Mark every stateless random kernel as not differentiable. The first group
# are the V1 kernels, which consume the `seed` tensor directly; the second
# group are the V2 kernels, which consume a (key, counter, alg) triple.
for _op_type in (
    # V1 kernels.
    "StatelessMultinomial",
    "StatelessRandomBinomial",
    "StatelessRandomNormal",
    "StatelessRandomPoisson",
    "StatelessRandomUniform",
    "StatelessRandomUniformInt",
    "StatelessRandomUniformFullInt",
    "StatelessTruncatedNormal",
    # V2 kernels.
    "StatelessRandomNormalV2",
    "StatelessRandomUniformV2",
    "StatelessRandomUniformIntV2",
    "StatelessRandomUniformFullIntV2",
    "StatelessTruncatedNormalV2",
):
  ops.NotDifferentiable(_op_type)
# Do not leave the loop variable behind as a module attribute.
del _op_type
51
52
@tf_export("random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2):
  """Splits an RNG seed into `num` new seeds by adding a leading axis.

  Example:

  >>> seed = [1, 2]
  >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3)
  >>> print(new_seeds)
  tf.Tensor(
  [[1105988140 1738052849]
   [-335576002  370444179]
   [  10670227 -246211131]], shape=(3, 2), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :])
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.59835213, -0.9578608 ,
  0.9002807 ], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    num: optional, a positive integer or scalar tensor indicating the number of
      seeds to produce (default 2).

  Returns:
    A tensor with shape [num, 2] representing `num` new seeds. It will have the
    same dtype as `seed` (if `seed` doesn't have an explicit dtype, the dtype
    will be determined by `tf.convert_to_tensor`).
  """
  seed_tensor = ops.convert_to_tensor(seed)
  # Each new seed is a row of full-range random integers (minval/maxval both
  # None) drawn with the input seed, in the input seed's own dtype.
  new_seeds_shape = [num, 2]
  return stateless_random_uniform(
      shape=new_seeds_shape,
      seed=seed_tensor,
      minval=None,
      maxval=None,
      dtype=seed_tensor.dtype)
85
86
@tf_export("random.experimental.stateless_fold_in")
@dispatch.add_dispatch_support
def fold_in(seed, data):
  """Folds in data to an RNG seed to form a new RNG seed.

  For example, in a distributed-training setting, suppose we have a master seed
  and a replica ID. We want to fold the replica ID into the master seed to
  form a "replica seed" to be used by that replica later on, so that different
  replicas will generate different random numbers but the reproducibility of the
  whole system can still be controlled by the master seed:

  >>> master_seed = [1, 2]
  >>> replica_id = 3
  >>> replica_seed = tf.random.experimental.stateless_fold_in(
  ...   master_seed, replica_id)
  >>> print(replica_seed)
  tf.Tensor([1105988140          3], shape=(2,), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=replica_seed)
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.03197195, 0.8979765 ,
  0.13253039], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    data: an `int32` or `int64` scalar representing data to be folded in to the
      seed.

  Returns:
    A new RNG seed that is a deterministic function of the inputs and is
    statistically safe for producing a stream of new pseudo-random values. It
    will have the same dtype as `data` (if `data` doesn't have an explicit
    dtype, the dtype will be determined by `tf.convert_to_tensor`).
  """
  data_tensor = ops.convert_to_tensor(data)
  # The new seed is [<full-range scalar drawn from `seed`>, data], with the
  # random half generated in `data`'s dtype so the two halves can be stacked.
  random_half = stateless_random_uniform(
      shape=[],
      seed=seed,
      minval=None,
      maxval=None,
      dtype=data_tensor.dtype)
  return array_ops.stack([random_half, data_tensor])
124
125
def _get_key_counter_alg(seed):
  """Converts `seed` into the `(key, counter, alg)` inputs of the V2 kernels.

  Args:
    seed: the RNG seed passed to the public stateless samplers (documented
      there as a shape [2] `int32`/`int64` tensor).

  Returns:
    A `(key, counter, alg)` triple suitable for the
    `gen_stateless_random_ops_v2` sampling ops.
  """
  if not compat.forward_compatible(2021, 3, 1):
    # Until the forward-compatibility window has passed, keep emitting the
    # single fused op that yields all three values.
    return gen_stateless_random_ops_v2.stateless_random_get_key_counter_alg(
        seed)
  key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
      seed)
  alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
  return key, counter, alg
135
136
@tf_export("random.stateless_uniform")
@dispatch.add_dispatch_support
def stateless_random_uniform(shape,
                             seed,
                             minval=0,
                             maxval=None,
                             dtype=dtypes.float32,
                             name=None):
  """Outputs deterministic pseudorandom values from a uniform distribution.

  This is a stateless version of `tf.random.uniform`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  The generated values follow a uniform distribution in the range
  `[minval, maxval)`. The lower bound `minval` is included in the range, while
  the upper bound `maxval` is excluded.

  For floats, the default range is `[0, 1)`.  For ints, at least `maxval` must
  be specified explicitly.

  In the integer case, the random integers are slightly biased unless
  `maxval - minval` is an exact power of two.  The bias is small for values of
  `maxval - minval` significantly smaller than the range of the output (either
  `2**32` or `2**64`).

  For full-range (i.e. inclusive of both max and min) random integers, pass
  `minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
  either both `minval` and `maxval` must be `None` or neither may be `None`. For
  example:
  ```python
  ints = tf.random.stateless_uniform(
      [10], seed=(2, 3), minval=None, maxval=None, dtype=tf.int32)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    minval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The lower bound on the range of random values to
      generate. Pass `None` for full-range integers (integer dtypes only).
      Defaults to 0.
    maxval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The upper bound on the range of random values to generate.
      Defaults to 1 if `dtype` is floating point. Pass `None` for full-range
      integers.
    dtype: The type of the output: `float16`, `bfloat16`, `float32`, `float64`,
      `int32`, or `int64`. For unbounded uniform ints (`minval`, `maxval` both
      `None`), `uint32` and `uint64` may be used.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random uniform values.

  Raises:
    ValueError: If `dtype` is not one of the supported types; if `dtype` is
      integral and only one of `minval` or `maxval` is specified; if `dtype`
      is `uint32`/`uint64` and bounds are given; or if `dtype` is floating
      point and `minval` is `None`.
  """
  dtype = dtypes.as_dtype(dtype)
  if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
                   dtypes.float64, dtypes.int32, dtypes.int64, dtypes.uint32,
                   dtypes.uint64):
    raise ValueError("Invalid dtype %r" % dtype)
  if dtype.is_integer:
    if (minval is None) != (maxval is None):
      raise ValueError("For integer dtype {}, minval and maxval must be both "
                       "`None` or both non-`None`.".format(dtype))
    if minval is not None and dtype in (dtypes.uint32, dtypes.uint64):
      raise ValueError("Invalid dtype for bounded uniform integers: %r" % dtype)
  else:
    # Full-range sampling exists only for integers. Without this check a
    # floating-point `minval=None` would only fail later, when `minval` is
    # converted to a tensor, with an unhelpful conversion error.
    if minval is None:
      raise ValueError("For floating-point dtype {}, minval must not be "
                       "`None`; `minval=None` (full-range sampling) is only "
                       "supported for integer dtypes.".format(dtype))
    if maxval is None:
      maxval = 1
  with ops.name_scope(name, "stateless_random_uniform",
                      [shape, seed, minval, maxval]) as name:
    shape = tensor_util.shape_tensor(shape)
    if dtype.is_integer and minval is None:
      # Full-range integers: no bounds are passed to the kernel.
      if compat.forward_compatible(2020, 10, 25):
        key, counter, alg = _get_key_counter_alg(seed)
        result = (gen_stateless_random_ops_v2
                  .stateless_random_uniform_full_int_v2(
                      shape, key=key, counter=counter, dtype=dtype, alg=alg,
                      name=name))
      else:
        result = gen_stateless_random_ops.stateless_random_uniform_full_int(
            shape, seed=seed, dtype=dtype, name=name)
    else:
      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
      if dtype.is_integer:
        # Bounded integers: the kernel takes the bounds directly.
        if compat.forward_compatible(2020, 10, 25):
          key, counter, alg = _get_key_counter_alg(seed)
          result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
              shape, key=key, counter=counter, minval=minval, maxval=maxval,
              alg=alg, name=name)
        else:
          result = gen_stateless_random_ops.stateless_random_uniform_int(
              shape, seed=seed, minval=minval, maxval=maxval, name=name)
      else:
        # Floats: draw from the kernel's base range and map it onto
        # [minval, maxval) with a scale and a shift.
        if compat.forward_compatible(2020, 10, 25):
          key, counter, alg = _get_key_counter_alg(seed)
          rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
              shape, key=key, counter=counter, dtype=dtype, alg=alg)
        else:
          rnd = gen_stateless_random_ops.stateless_random_uniform(
              shape, seed=seed, dtype=dtype)
        result = math_ops.add(rnd * (maxval - minval), minval, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result
248
249
@tf_export("random.stateless_binomial")
@dispatch.add_dispatch_support
def stateless_random_binomial(shape,
                              seed,
                              counts,
                              probs,
                              output_dtype=dtypes.int32,
                              name=None):
  """Outputs deterministic pseudorandom values from a binomial distribution.

  The generated values follow a binomial distribution with specified count and
  probability of success parameters.

  This is a stateless version of `tf.random.Generator.binomial`: if run twice
  with the same seeds and shapes, it will produce the same pseudorandom numbers.
  The output is consistent across multiple runs on the same hardware (and
  between CPU and GPU), but may change between versions of TensorFlow or on
  non-CPU/GPU hardware.

  Example:

  ```python
  counts = [10., 20.]
  # Probability of success.
  probs = [0.8]

  binomial_samples = tf.random.stateless_binomial(
      shape=[2], seed=[123, 456], counts=counts, probs=probs)

  counts = ... # Shape [3, 1, 2]
  probs = ...  # Shape [1, 4, 2]
  shape = [3, 4, 3, 4, 2]
  # Sample shape will be [3, 4, 3, 4, 2]
  binomial_samples = tf.random.stateless_binomial(
      shape=shape, seed=[123, 456], counts=counts, probs=probs)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    counts: Tensor. The counts of the binomial distribution. Must be
      broadcastable with `probs`, and broadcastable with the rightmost
      dimensions of `shape`.
    probs: Tensor. The probability of success for the binomial distribution.
      Must be broadcastable with `counts` and broadcastable with the rightmost
      dimensions of `shape`.
    output_dtype: The type of the output. Default: tf.int32
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random binomial
      values.  For each i, each `samples[..., i]` is an independent draw from
      the binomial distribution on `counts[i]` trials with probability of
      success `probs[i]`.
  """
  with ops.name_scope(name, "stateless_random_binomial",
                      [shape, seed, counts, probs]):
    shape_t = tensor_util.shape_tensor(shape)
    # `probs` falls back to float32 when it carries no dtype of its own, and
    # `counts` then follows `probs` so both kernel inputs share a dtype.
    probs_t = ops.convert_to_tensor(
        probs, dtype_hint=dtypes.float32, name="probs")
    counts_t = ops.convert_to_tensor(
        counts, dtype_hint=probs_t.dtype, name="counts")
    samples = gen_stateless_random_ops.stateless_random_binomial(
        shape=shape_t,
        seed=seed,
        counts=counts_t,
        probs=probs_t,
        dtype=output_dtype)
    tensor_util.maybe_set_static_shape(samples, shape_t)
    return samples
318
319
@tf_export("random.stateless_gamma")
@dispatch.add_dispatch_support
def stateless_random_gamma(shape,
                           seed,
                           alpha,
                           beta=None,
                           dtype=dtypes.float32,
                           name=None):
  """Outputs deterministic pseudorandom values from a gamma distribution.

  The generated values follow a gamma distribution with specified concentration
  (`alpha`) and inverse scale (`beta`) parameters.

  This is a stateless version of `tf.random.gamma`: if run twice with the same
  seeds and shapes, it will produce the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware (and between CPU and
  GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_gamma` and `gamma`: in `gamma`, the `shape` is always
  prepended to the shape of the broadcast of `alpha` with `beta`; whereas in
  `stateless_gamma` the `shape` parameter must always encompass the shapes of
  each of `alpha` and `beta` (which must broadcast together to match the
  trailing dimensions of `shape`).

  Note: Because internal calculations are done using `float64` and casting has
  `floor` semantics, we must manually map zero outcomes to the smallest
  possible positive floating-point value, i.e., `np.finfo(dtype).tiny`.  This
  means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
  should.  This bias can only happen for small values of `alpha`, i.e.,
  `alpha << 1` or large values of `beta`, i.e., `beta >> 1`.

  The samples are differentiable w.r.t. alpha and beta.
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  Example:

  ```python
  samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  alpha = tf.constant([[1.], [3.], [5.]])
  beta = tf.constant([[3., 4.]])
  samples = tf.random.stateless_gamma(
      [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)
  # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.

  with tf.GradientTape() as tape:
    tape.watch([alpha, beta])
    loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma(
        [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)))
  dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta])
  # unbiased stochastic derivatives of the loss function
  alpha.shape == dloss_dalpha.shape  # True
  beta.shape == dloss_dbeta.shape  # True
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    alpha: Tensor. The concentration parameter of the gamma distribution. Must
      be broadcastable with `beta`, and broadcastable with the rightmost
      dimensions of `shape`.
    beta: Tensor. The inverse scale parameter of the gamma distribution. Must be
      broadcastable with `alpha` and broadcastable with the rightmost dimensions
      of `shape`. Defaults to 1 when `None`.
    dtype: Floating point dtype of `alpha`, `beta`, and the output.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random gamma values.
      For each i, each `samples[..., i]` is an independent draw from the gamma
      distribution with concentration `alpha[i]` and inverse scale `beta[i]`.
  """
  with ops.name_scope(name, "stateless_random_gamma",
                      [shape, seed, alpha, beta]):
    shape_t = tensor_util.shape_tensor(shape)
    alpha_t = ops.convert_to_tensor(alpha, dtype=dtype, name="alpha")
    beta_value = 1 if beta is None else beta
    beta_t = ops.convert_to_tensor(beta_value, name="beta", dtype=dtype)
    # The kernel receives only `alpha`, so give it `alpha` broadcast against
    # `beta`; the division by `beta` below broadcasts on its own.
    alpha_shape = array_ops.shape(alpha_t)
    beta_shape = array_ops.shape(beta_t)
    joint_shape = array_ops.broadcast_dynamic_shape(alpha_shape, beta_shape)
    alpha_broadcast = array_ops.broadcast_to(alpha_t, joint_shape)
    # Zero outcomes are clamped up to the smallest positive value (see the
    # note in the docstring).
    smallest_positive = np.finfo(alpha_t.dtype.as_numpy_dtype).tiny
    samples = math_ops.maximum(
        smallest_positive,
        gen_stateless_random_ops.stateless_random_gamma_v2(
            shape_t, seed=seed, alpha=alpha_broadcast) / beta_t)
    tensor_util.maybe_set_static_shape(samples, shape_t)
    return samples
418
419
@tf_export("random.stateless_poisson")
@dispatch.add_dispatch_support
def stateless_random_poisson(shape,
                             seed,
                             lam,
                             dtype=dtypes.int32,
                             name=None):
  """Outputs deterministic pseudorandom values from a Poisson distribution.

  The generated values follow a Poisson distribution with specified rate
  parameter.

  This is a stateless version of `tf.random.poisson`: if run twice with the same
  seeds and shapes, it will produce the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware, but may change between
  versions of TensorFlow or on non-CPU/GPU hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_poisson` and `poisson`: in `poisson`, the `shape` is always
  prepended to the shape of `lam`; whereas in `stateless_poisson` the shape of
  `lam` must match the trailing dimensions of `shape`.

  Example:

  ```python
  samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_poisson([7, 5, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  rate = tf.constant([[1.], [3.], [5.]])
  samples = tf.random.stateless_poisson([30, 3, 1], seed=[12, 34], lam=rate)
  # samples has shape [30, 3, 1], with 30 samples each of 3x1 distributions.
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Shape
      must match the rightmost dimensions of `shape`.
    dtype: Dtype of the samples (int or float dtypes are permissible, as samples
      are discrete). Default: int32.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random Poisson values.
      For each i, each `samples[..., i]` is an independent draw from the Poisson
      distribution with rate `lam[i]`.
  """
  with ops.name_scope(name, "stateless_random_poisson", [shape, seed, lam]):
    shape_t = tensor_util.shape_tensor(shape)
    # `lam` is handed to the kernel as given; no explicit conversion here.
    samples = gen_stateless_random_ops.stateless_random_poisson(
        shape_t, seed=seed, lam=lam, dtype=dtype)
    tensor_util.maybe_set_static_shape(samples, shape_t)
  return samples
481
482
@tf_export("random.stateless_normal")
@dispatch.add_dispatch_support
def stateless_random_normal(shape,
                            seed,
                            mean=0.0,
                            stddev=1.0,
                            dtype=dtypes.float32,
                            name=None):
  """Outputs deterministic pseudorandom values from a normal distribution.

  This is a stateless version of `tf.random.normal`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
      distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution.
    dtype: The type of the output.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random normal values.
  """
  with ops.name_scope(name, "stateless_random_normal",
                      [shape, seed, mean, stddev]) as scope_name:
    shape_t = tensor_util.shape_tensor(shape)
    mean_t = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev_t = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    if not compat.forward_compatible(2021, 3, 1):
      # Legacy kernel that consumes the seed directly.
      standard = gen_stateless_random_ops.stateless_random_normal(
          shape_t, seed, dtype)
    else:
      key, counter, alg = _get_key_counter_alg(seed)
      standard = gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape_t, key=key, counter=counter, dtype=dtype, alg=alg)
    # Scale and shift the kernel's draws; the final op takes the scope's name.
    samples = math_ops.add(standard * stddev_t, mean_t, name=scope_name)
    tensor_util.maybe_set_static_shape(samples, shape_t)
    return samples
527
528
@tf_export("random.stateless_truncated_normal")
@dispatch.add_dispatch_support
def stateless_truncated_normal(shape,
                               seed,
                               mean=0.0,
                               stddev=1.0,
                               dtype=dtypes.float32,
                               name=None):
  """Outputs deterministic pseudorandom values, truncated normally distributed.

  This is a stateless version of `tf.random.truncated_normal`: if run twice with
  the same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  The generated values follow a normal distribution with specified mean and
  standard deviation, except that values whose magnitude is more than 2 standard
  deviations from the mean are dropped and re-picked.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
      truncated normal distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution, before truncation.
    dtype: The type of the output.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  with ops.name_scope(name, "stateless_truncated_normal",
                      [shape, seed, mean, stddev]) as scope_name:
    shape_t = tensor_util.shape_tensor(shape)
    mean_t = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev_t = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    if not compat.forward_compatible(2020, 10, 25):
      # Legacy kernel that consumes the seed directly.
      standard = gen_stateless_random_ops.stateless_truncated_normal(
          shape_t, seed, dtype)
    else:
      key, counter, alg = _get_key_counter_alg(seed)
      standard = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
          shape_t, key=key, counter=counter, dtype=dtype, alg=alg)
    # Scale and shift the kernel's draws; the final op takes the scope's name.
    samples = math_ops.add(standard * stddev_t, mean_t, name=scope_name)
    tensor_util.maybe_set_static_shape(samples, shape_t)
    return samples
578
579
@tf_export(v1=["random.stateless_multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
    date=None, instructions="Use `tf.random.stateless_categorical` instead.")
def stateless_multinomial(logits,
                          num_samples,
                          seed,
                          output_dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a multinomial distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Example (shown with the recommended replacement,
  `tf.random.stateless_categorical`):

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_categorical(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    output_dtype: integer type to use for the output. Defaults to int64.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
    # Shares its implementation with `stateless_categorical`; only the name of
    # the dtype argument differs between the two public entry points.
    return stateless_multinomial_categorical_impl(
        logits, num_samples, dtype=output_dtype, seed=seed)
621
622
@tf_export("random.stateless_categorical")
@dispatch.add_dispatch_support
def stateless_categorical(logits,
                          num_samples,
                          seed,
                          dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a categorical distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Example:

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_categorical(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    dtype: integer type to use for the output. Defaults to int64.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_categorical", [logits, seed]):
    return stateless_multinomial_categorical_impl(
        logits, num_samples, dtype=dtype, seed=seed)
663
664
def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
  """Implementation for stateless multinomial/categorical ops (v1/v2).

  Args:
    logits: the unnormalized log-probabilities; converted to a tensor here.
    num_samples: number of samples to draw per row of `logits`.
    dtype: output dtype of the drawn class indices.
    seed: the RNG seed, passed unchanged to the kernel.

  Returns:
    The output of the `StatelessMultinomial` kernel.
  """
  logits_tensor = ops.convert_to_tensor(logits, name="logits")
  samples = gen_stateless_random_ops.stateless_multinomial(
      logits_tensor, num_samples, seed, output_dtype=dtype)
  return samples
670
671
@tf_export("random.stateless_parameterized_truncated_normal")
@dispatch.add_dispatch_support
def stateless_parameterized_truncated_normal(shape,
                                             seed,
                                             means=0.0,
                                             stddevs=1.0,
                                             minvals=-2.0,
                                             maxvals=2.0,
                                             name=None):
  """Outputs random values from a truncated normal distribution.

  The generated values follow a normal distribution with the given `means` and
  `stddevs`, truncated to the range bounded below by `minvals` and above by
  `maxvals` (by default `[-2, 2]`).

  Examples:

  Sample from a Truncated normal, with differing shape parameters that
  broadcast.

  >>> means = 0.
  >>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
  >>> minvals = [-1., -2., -1000.]
  >>> maxvals = [[10000.], [1.]]
  >>> y = tf.random.stateless_parameterized_truncated_normal(
  ...   shape=[10, 2, 3], seed=[7, 17],
  ...   means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
  >>> y.shape
  TensorShape([10, 2, 3])

  Args:
    shape: A 1-D integer `Tensor` or Python array. The shape of the output
      tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    means: A `Tensor` or Python value. The mean of the truncated normal
      distribution. This must broadcast with `stddevs`, `minvals` and
      `maxvals`, and the broadcasted shape must be dominated by `shape`.
    stddevs: A `Tensor` or Python value. The standard deviation of the
      truncated normal distribution. This must broadcast with `means`,
      `minvals` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    minvals: A `Tensor` or Python value. The minimum value of the truncated
      normal distribution. This must broadcast with `means`, `stddevs` and
      `maxvals`, and the broadcasted shape must be dominated by `shape`.
    maxvals: A `Tensor` or Python value. The maximum value of the truncated
      normal distribution. This must broadcast with `means`, `stddevs` and
      `minvals`, and the broadcasted shape must be dominated by `shape`.
    name: A name for the operation (optional).

  Note: the four distribution parameters are each converted with
  `tf.convert_to_tensor` without an explicit dtype; they presumably need to
  end up with the same floating-point dtype for the kernel to accept them.

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  with ops.name_scope(name, "stateless_parameterized_truncated_normal",
                      [shape, seed, means, stddevs, minvals, maxvals]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    means_tensor = ops.convert_to_tensor(means, name="means")
    stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs")
    minvals_tensor = ops.convert_to_tensor(minvals, name="minvals")
    maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals")
    rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal(
        shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor,
        maxvals_tensor)
    # Pass the converted shape tensor, as every other sampler in this module
    # does; handing over the raw `shape` argument (e.g. a Python list that
    # mixes ints and tensors) can bypass the static-shape inference.
    tensor_util.maybe_set_static_shape(rnd, shape_tensor)
    return rnd
740