• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations for generating random numbers."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import enum  # pylint: disable=g-bad-import-order
22
23import numpy as np
24import six
25
26from tensorflow.python.compat import compat
27from tensorflow.python.distribute import distribution_strategy_context as ds_context
28from tensorflow.python.distribute import values_util
29from tensorflow.python.eager import context
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import gen_stateful_random_ops
34from tensorflow.python.ops import gen_stateless_random_ops_v2
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops import variables
37from tensorflow.python.training.tracking import tracking
38from tensorflow.python.util.tf_export import tf_export
39
40
# A seed for random ops (stateful and stateless) is always 1024 bits wide, and
# all of those bits are sent to the C++ code. The C++ implementation of some
# algorithms may only consume the lower part of the bits.

MAX_INT64 = (1 << 63) - 1
MIN_INT64 = -(1 << 63)
UINT64_SPAN = 1 << 64
# 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained in
# b/111604096 and cl/171681867), so a signed int is used here. int64 is chosen
# over int32 because `VarHandleOp` doesn't support int32 on GPU.
SEED_TYPE = "int64"
SEED_MIN = MIN_INT64
SEED_MAX = MAX_INT64
SEED_UINT_SPAN = UINT64_SPAN
SEED_TYPE_BITS = 64
SEED_BIT_MASK = UINT64_SPAN - 1  # 0xFFFFFFFFFFFFFFFF
SEED_SIZE = 16  # in units of SEED_TYPE


STATE_TYPE = SEED_TYPE
ALGORITHM_TYPE = STATE_TYPE
# Number of STATE_TYPE elements in the state of each algorithm.
PHILOX_STATE_SIZE = 3
THREEFRY_STATE_SIZE = 2
64
65
@tf_export("random.Algorithm", "random.experimental.Algorithm")
class Algorithm(enum.Enum):
  """Identifiers of the supported RNG algorithms.

  The integer `value` of each member is the algorithm id handed to the RNG ops.
  """
  PHILOX = 1
  THREEFRY = 2
70
71
# Integer algorithm ids, as consumed by the RNG ops in this file.
RNG_ALG_PHILOX, RNG_ALG_THREEFRY = (
    Algorithm.PHILOX.value, Algorithm.THREEFRY.value)
DEFAULT_ALGORITHM = RNG_ALG_PHILOX
75
76
def non_deterministic_ints(shape, dtype=dtypes.int64):
  """Generates integers from a non-deterministic source.

  The underlying op may draw on an OS-provided source of non-determinism (for
  instance an RNG), so each execution gives different results.

  Args:
    shape: the shape of the result.
    dtype: (optional) the dtype of the result.

  Returns:
    a tensor whose element values are non-deterministically chosen.
  """
  op_kwargs = {"shape": shape, "dtype": dtype}
  return gen_stateful_random_ops.non_deterministic_ints(**op_kwargs)
92
93
def _uint_to_int(n):
  """Reinterprets an unsigned SEED_TYPE-width value as a signed one.

  Values above SEED_MAX are shifted down by SEED_UINT_SPAN (two's-complement
  wrap-around); every other value is returned unchanged.
  """
  return n - SEED_UINT_SPAN if n > SEED_MAX else n
98
99
def _make_1d_state(state_size, seed):
  """Makes a 1-D RNG state.

  A Python integer seed is chopped into SEED_TYPE_BITS-wide chunks, lowest
  chunk first, so a small integer seed fills index 0 and leaves the remaining
  elements zero (e.g. 1234 -> [1234, 0, 0] for a state of size 3). A vector
  seed is truncated to `state_size` elements, or zero-padded on the left if it
  is shorter.

  Args:
    state_size: an integer.
    seed: an integer (Python or NumPy integer scalar) or a 1-D sequence, array
      or tensor of integers.

  Returns:
    a 1-D numpy array of shape [state_size] and dtype STATE_TYPE.

  Raises:
    ValueError: if `seed` does not convert to a one-dimensional array.
  """
  if isinstance(seed, np.integer):
    # NumPy integer scalars are not `six.integer_types`; convert them so they
    # take the chopping path below instead of failing in `map`.
    seed = int(seed)
  if isinstance(seed, six.integer_types):
    # chop the Python integer (infinite precision) into chunks of SEED_TYPE
    chunks = []
    for _ in range(state_size):
      chunks.append(seed & SEED_BIT_MASK)
      seed >>= SEED_TYPE_BITS
    seed = chunks
  # to avoid overflow error from np.asarray
  seed = list(map(_uint_to_int, seed))
  seed = np.asarray(seed, dtype=STATE_TYPE)
  if len(seed.shape) != 1:
    # Wrap the shape in a 1-tuple: `"%s" % shape` itself raises TypeError
    # whenever the shape tuple does not have exactly one element.
    raise ValueError(
        "seed should only have one dimension; got shape: %s" % (seed.shape,))
  seed = seed[0:state_size]
  # Padding with zeros on the *left* if too short. Padding on the right would
  # cause a small seed to be used as the "counter" while the "key" is always
  # zero (for counter-based RNG algorithms), because in the current memory
  # layout counter is stored before key. In such a situation two RNGs with
  # two different small seeds may generate overlapping outputs.
  seed_size = seed.shape[0]
  if seed_size < state_size:
    seed = np.pad(
        seed, [(state_size - seed_size, 0)],
        mode="constant",
        constant_values=0)
  assert seed.shape == (state_size,), "Wrong seed.shape: %s" % (seed.shape,)
  return seed
137
138
def _get_counter_size(alg):
  """Returns how many leading state elements form the counter for `alg`.

  Raises:
    ValueError: if `alg` is not a supported algorithm id.
  """
  if alg == RNG_ALG_PHILOX:
    return 2
  if alg == RNG_ALG_THREEFRY:
    return 1
  raise ValueError("Unsupported algorithm id: %s" % alg)
146
147
def _get_state_size(alg):
  """Returns the number of STATE_TYPE elements in the state of `alg`.

  Raises:
    ValueError: if `alg` is not a supported algorithm id.
  """
  if alg == RNG_ALG_PHILOX:
    return PHILOX_STATE_SIZE
  if alg == RNG_ALG_THREEFRY:
    return THREEFRY_STATE_SIZE
  raise ValueError("Unsupported algorithm id: %s" % alg)
155
156
def _check_state_shape(shape, alg):
  """Statically checks that `shape` matches the state size of `alg`.

  The check is skipped when `alg` is a Tensor and we are not executing
  eagerly, since its value cannot be turned into a Python int then.
  """
  alg_is_symbolic = (isinstance(alg, ops.Tensor) and
                     not context.executing_eagerly())
  if not alg_is_symbolic:
    shape.assert_is_compatible_with([_get_state_size(int(alg))])
161
162
def _make_state_from_seed(seed, alg):
  """Builds a 1-D state sized for algorithm id `alg` from `seed`."""
  state_size = _get_state_size(alg)
  return _make_1d_state(state_size, seed)
165
166
def _convert_alg_to_int(alg):
  """Normalizes an algorithm designator to its integer id.

  Args:
    alg: an integer, an `Algorithm`, a Tensor, or one of the strings
      "philox" and "threefry".

  Returns:
    An integer id; a Tensor input is passed through unchanged.

  Raises:
    ValueError: if `alg` is a string naming an unknown algorithm.
    TypeError: if `alg` has none of the accepted types.
  """
  if isinstance(alg, six.integer_types + (ops.Tensor,)):
    return alg
  if isinstance(alg, Algorithm):
    return alg.value
  if not isinstance(alg, str):
    raise TypeError("Can't convert algorithm %s of type %s to int" %
                    (alg, type(alg)))
  ids_by_name = {"philox": RNG_ALG_PHILOX, "threefry": RNG_ALG_THREEFRY}
  if alg not in ids_by_name:
    raise ValueError("Unknown algorithm name: %s" % alg)
  return ids_by_name[alg]
193
194
@tf_export("random.create_rng_state", "random.experimental.create_rng_state")
def create_rng_state(seed, alg):
  """Creates an RNG state from an integer or a vector.

  Example:

  >>> tf.random.create_rng_state(
  ...     1234, "philox")
  array([1234,    0,    0])
  >>> tf.random.create_rng_state(
  ...     [12, 34], "threefry")
  array([12, 34])

  Args:
    seed: an integer or 1-D numpy array.
    alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.

  Returns:
    a 1-D numpy array whose size depends on the algorithm.
  """
  return _make_state_from_seed(seed, _convert_alg_to_int(alg))
217
218
def _shape_tensor(shape):
  """Converts `shape` to an integer tensor named "shape".

  An empty Python tuple/list is forced to int64 (there are no elements to infer
  a dtype from); anything else keeps the dtype `convert_to_tensor` infers.
  """
  is_empty_sequence = isinstance(shape, (tuple, list)) and not shape
  dtype = dtypes.int64 if is_empty_sequence else None
  return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
226
227
def _convert_to_state_tensor(t):
  """Converts `t` to a tensor of dtype STATE_TYPE.

  Python ints inside a list or tuple may be unsigned 64-bit values (e.g. chunks
  of a large seed); they are first mapped into the signed range with
  `_uint_to_int` so that `convert_to_tensor` does not raise an out-of-range
  error. Other elements (such as tensors) are left untouched, so their
  truthiness is never evaluated.

  Args:
    t: a list/tuple of integers or tensors, an array, or a tensor.

  Returns:
    A tensor of dtype STATE_TYPE.
  """
  if isinstance(t, (list, tuple)):
    # to avoid out-of-range error from ops.convert_to_tensor
    t = [_uint_to_int(x) if isinstance(x, six.integer_types) else x
         for x in t]
  return ops.convert_to_tensor(t, dtype=STATE_TYPE)
233
234
def get_replica_id():
  """Returns the current replica's id in its sync group, or None.

  None is returned when `ds_context.get_replica_context()` yields no replica
  context.
  """
  replica_ctx = ds_context.get_replica_context()
  return None if replica_ctx is None else replica_ctx.replica_id_in_sync_group
240
241
242@tf_export("random.Generator", "random.experimental.Generator")
243class Generator(tracking.AutoTrackable):
244  """Random-number generator.
245
246  Example:
247
248  Creating a generator from a seed:
249
250  >>> g = tf.random.Generator.from_seed(1234)
251  >>> g.normal(shape=(2, 3))
252  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
253  array([[ 0.9356609 ,  1.0854305 , -0.93788373],
254         [-0.5061547 ,  1.3169702 ,  0.7137579 ]], dtype=float32)>
255
256  Creating a generator from a non-deterministic state:
257
258  >>> g = tf.random.Generator.from_non_deterministic_state()
259  >>> g.normal(shape=(2, 3))
260  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
261
262  All the constructors allow explicitly choosing an Random-Number-Generation
263  (RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
264  example:
265
266  >>> g = tf.random.Generator.from_seed(123, alg="philox")
267  >>> g.normal(shape=(2, 3))
268  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
269  array([[ 0.8673864 , -0.29899067, -0.9310337 ],
270         [-1.5828488 ,  1.2481191 , -0.6770643 ]], dtype=float32)>
271
272  CPU, GPU and TPU with the same algorithm and seed will generate the same
273  integer random numbers. Float-point results (such as the output of `normal`)
274  may have small numerical discrepancies between different devices.
275
276  This class uses a `tf.Variable` to manage its internal state. Every time
277  random numbers are generated, the state of the generator will change. For
278  example:
279
280  >>> g = tf.random.Generator.from_seed(1234)
281  >>> g.state
282  <tf.Variable ... numpy=array([1234,    0,    0])>
283  >>> g.normal(shape=(2, 3))
284  <...>
285  >>> g.state
286  <tf.Variable ... numpy=array([2770,    0,    0])>
287
288  The shape of the state is algorithm-specific.
289
290  There is also a global generator:
291
292  >>> g = tf.random.get_global_generator()
293  >>> g.normal(shape=(2, 3))
294  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>
295
296  When creating a generator inside a `tf.distribute.Strategy` scope, each
297  replica will get a different stream of random numbers.
298
299  Note: `tf.distribute.experimental.CentralStorageStrategy` and
300  `tf.distribute.experimental.ParameterServerStrategy` are not supported yet.
301
302  For example, in this code:
303
304  ```
305  strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
306  with strat.scope():
307    g = tf.random.Generator.from_seed(1)
308    def f():
309      return g.normal([])
310    results = strat.run(f).values
311  ```
312
313  `results[0]` and `results[1]` will have different values.
314
315  If the generator is seeded (e.g. created via `Generator.from_seed`), the
316  random numbers will be determined by the seed, even though different replicas
317  get different numbers.  One can think of a random number generated on a
318  replica as a hash of the replica ID and a "master" random number that may be
319  common to all replicas. Hence, the whole system is still deterministic.
320
321  (Note that the random numbers on different replicas are not correlated, even
322  if they are deterministically determined by the same seed. They are not
323  correlated in the sense that no matter what statistics one calculates on them,
324  there won't be any discernable correlation.)
325
326  Generators can be freely saved and restored using `tf.train.Checkpoint`. The
327  checkpoint can be restored in a distribution strategy with a different number
328  of replicas than the original strategy. If a replica ID is present in both the
329  original and the new distribution strategy, its state will be properly
330  restored (i.e. the random-number stream from the restored point will be the
331  same as that from the saving point) unless the replicas have already diverged
332  in their RNG call traces before saving (e.g. one replica has made one RNG call
333  while another has made two RNG calls). We don't have such guarantee if the
334  generator is saved in a strategy scope and restored outside of any strategy
335  scope, or vice versa.
336  """
337
338  @classmethod
339  def from_state(cls, state, alg):
340    """Creates a generator from a state.
341
342    See `__init__` for description of `state` and `alg`.
343
344    Args:
345      state: the new state.
346      alg: the RNG algorithm.
347
348    Returns:
349      The new generator.
350    """
351    return cls(alg=alg, state=state)
352
353  @classmethod
354  def from_seed(cls, seed, alg=None):
355    """Creates a generator from a seed.
356
357    A seed is a 1024-bit unsigned integer represented either as a Python
358    integer or a vector of integers. Seeds shorter than 1024-bit will be
359    padded. The padding, the internal structure of a seed and the way a seed
360    is converted to a state are all opaque (unspecified). The only semantics
361    specification of seeds is that two different seeds are likely to produce
362    two independent generators (but no guarantee).
363
364    Args:
365      seed: the seed for the RNG.
366      alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
367        `__init__` for its possible values.
368
369    Returns:
370      The new generator.
371    """
372    if alg is None:
373      # TODO(b/170668986): more sophisticated algorithm selection
374      alg = DEFAULT_ALGORITHM
375    alg = _convert_alg_to_int(alg)
376    state = create_rng_state(seed, alg)
377    return cls(state=state, alg=alg)
378
379  @classmethod
380  def from_non_deterministic_state(cls, alg=None):
381    """Creates a generator by non-deterministically initializing its state.
382
383    The source of the non-determinism will be platform- and time-dependent.
384
385    Args:
386      alg: (optional) the RNG algorithm. If None, it will be auto-selected. See
387        `__init__` for its possible values.
388
389    Returns:
390      The new generator.
391    """
392    if alg is None:
393      # TODO(b/170668986): more sophisticated algorithm selection
394      alg = DEFAULT_ALGORITHM
395    alg = _convert_alg_to_int(alg)
396    state = non_deterministic_ints(shape=[_get_state_size(alg)],
397                                   dtype=SEED_TYPE)
398    return cls(state=state, alg=alg)
399
400  @classmethod
401  def from_key_counter(cls, key, counter, alg):
402    """Creates a generator from a key and a counter.
403
404    This constructor only applies if the algorithm is a counter-based algorithm.
405    See method `key` for the meaning of "key" and "counter".
406
407    Args:
408      key: the key for the RNG, a scalar of type STATE_TYPE.
409      counter: a vector of dtype STATE_TYPE representing the initial counter for
410        the RNG, whose length is algorithm-specific.,
411      alg: the RNG algorithm. If None, it will be auto-selected. See
412        `__init__` for its possible values.
413
414    Returns:
415      The new generator.
416    """
417    counter = _convert_to_state_tensor(counter)
418    key = _convert_to_state_tensor(key)
419    alg = _convert_alg_to_int(alg)
420    counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
421    key.shape.assert_is_compatible_with([])
422    key = array_ops.reshape(key, [1])
423    state = array_ops.concat([counter, key], 0)
424    return cls(state=state, alg=alg)
425
426  def __init__(self, copy_from=None, state=None, alg=None):
427    """Creates a generator.
428
429    The new generator will be initialized by one of the following ways, with
430    decreasing precedence:
431    (1) If `copy_from` is not None, the new generator is initialized by copying
432        information from another generator.
433    (2) If `state` and `alg` are not None (they must be set together), the new
434        generator is initialized by a state.
435
436    Args:
437      copy_from: a generator to be copied from.
438      state: a vector of dtype STATE_TYPE representing the initial state of the
439        RNG, whose length and semantics are algorithm-specific. If it's a
440        variable, the generator will reuse it instead of creating a new
441        variable.
442      alg: the RNG algorithm. Possible values are
443        `tf.random.Algorithm.PHILOX` for the Philox algorithm and
444        `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm
445        (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
446        [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
447        The string names `"philox"` and `"threefry"` can also be used.
448        Note `PHILOX` guarantees the same numbers are produced (given
449        the same random state) across all architectures (CPU, GPU, XLA etc).
450    """
451    # TODO(b/175072242): Remove distribution-strategy dependencies in this file.
452    if ds_context.has_strategy():
453      self._distribution_strategy = ds_context.get_strategy()
454    else:
455      self._distribution_strategy = None
456    if copy_from is not None:
457      # All other arguments should be None
458      assert (alg or state) is None
459      self._state_var = self._create_variable(copy_from.state, dtype=STATE_TYPE,
460                                              trainable=False)
461      self._alg = copy_from.algorithm
462    else:
463      assert alg is not None and state is not None
464      if ds_context.has_strategy():
465        strat_name = type(ds_context.get_strategy()).__name__
466        # TODO(b/174610856): Support CentralStorageStrategy and
467        #   ParameterServerStrategy.
468        if "CentralStorage" in strat_name or "ParameterServer" in strat_name:
469          raise ValueError("%s is not supported yet" % strat_name)
470      alg = _convert_alg_to_int(alg)
471      if isinstance(state, variables.Variable):
472        _check_state_shape(state.shape, alg)
473        self._state_var = state
474      else:
475        state = _convert_to_state_tensor(state)
476        _check_state_shape(state.shape, alg)
477        self._state_var = self._create_variable(state, dtype=STATE_TYPE,
478                                                trainable=False)
479      self._alg = alg
480
481  def _create_variable(self, *args, **kwargs):
482    """Creates a variable.
483
484    Args:
485      *args: positional arguments passed along to `variables.Variable.
486      **kwargs: keyword arguments passed along to `variables.Variable.
487
488    Returns:
489      The created variable.
490    """
491    return variables.Variable(*args, **kwargs)
492
493  def reset(self, state):
494    """Resets the generator by a new state.
495
496    See `__init__` for the meaning of "state".
497
498    Args:
499      state: the new state.
500    """
501    state = _convert_to_state_tensor(state)
502    state.shape.assert_is_compatible_with([_get_state_size(self.algorithm)])
503    self._state_var.assign(state)
504
505  def reset_from_seed(self, seed):
506    """Resets the generator by a new seed.
507
508    See `from_seed` for the meaning of "seed".
509
510    Args:
511      seed: the new seed.
512    """
513    state = create_rng_state(seed, self.algorithm)
514    self._state_var.assign(state)
515
516  def reset_from_key_counter(self, key, counter):
517    """Resets the generator by a new key-counter pair.
518
519    See `from_key_counter` for the meaning of "key" and "counter".
520
521    Args:
522      key: the new key.
523      counter: the new counter.
524    """
525    counter = _convert_to_state_tensor(counter)
526    key = _convert_to_state_tensor(key)
527    counter.shape.assert_is_compatible_with(
528        [_get_state_size(self.algorithm) - 1])
529    key.shape.assert_is_compatible_with([])
530    key = array_ops.reshape(key, [1])
531    state = array_ops.concat([counter, key], 0)
532    self._state_var.assign(state)
533
534  @property
535  def state(self):
536    """The internal state of the RNG."""
537    return self._state_var
538
539  @property
540  def algorithm(self):
541    """The RNG algorithm id (a Python integer or scalar integer Tensor)."""
542    return self._alg
543
544  def _standard_normal(self, shape, dtype):
545    if compat.forward_compatible(2020, 10, 25):
546      key, counter = self._prepare_key_counter(shape)
547      return gen_stateless_random_ops_v2.stateless_random_normal_v2(
548          shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
549    return gen_stateful_random_ops.stateful_standard_normal_v2(
550        self.state.handle, self.algorithm, shape, dtype=dtype)
551
552  @property
553  def key(self):
554    """The 'key' part of the state of a counter-based RNG.
555
556    For a counter-base RNG algorithm such as Philox and ThreeFry (as
557    described in paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
558    [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]),
559    the RNG state consists of two parts: counter and key. The output is
560    generated via the formula: output=hash(key, counter), i.e. a hashing of
561    the counter parametrized by the key. Two RNGs with two different keys can
562    be thought as generating two independent random-number streams (a stream
563    is formed by increasing the counter).
564
565    Returns:
566      A scalar which is the 'key' part of the state, if the RNG algorithm is
567        counter-based; otherwise it raises a ValueError.
568    """
569    alg = self.algorithm
570    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
571      return self._state_var[-1]
572    else:
573      raise ValueError("Unsupported algorithm id: %s" % alg)
574
575  # TODO(wangpeng): Add "Returns" section to docstring once new version kicks in
576  # pylint: disable=g-doc-return-or-yield
577  def skip(self, delta):
578    """Advance the counter of a counter-based RNG.
579
580    Args:
581      delta: the amount of advancement. The state of the RNG after
582        `skip(n)` will be the same as that after `normal([n])`
583        (or any other distribution). The actual increment added to the
584        counter is an unspecified implementation detail.
585    """
586    if compat.forward_compatible(2020, 10, 25):
587      return self._skip(delta)
588    gen_stateful_random_ops.rng_skip(
589        self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
590        math_ops.cast(delta, dtypes.int64))
591  # pylint: enable=g-doc-return-or-yield
592
593  def _skip_single_var(self, var, delta):
594    # TODO(wangpeng): Cache the cast algorithm instead of casting everytime.
595    return gen_stateful_random_ops.rng_read_and_skip(
596        var.handle, alg=math_ops.cast(self.algorithm, dtypes.int32),
597        delta=math_ops.cast(delta, dtypes.uint64))
598
  def _skip(self, delta):
    """Skips the state by `delta`, dispatching on the distribution context.

    Returns the output of `_skip_single_var` (i.e. of `rng_read_and_skip`), or
    of `strategy.extended.update` in cross-replica context. `_prepare_key_counter`
    slices this result into counter and key.
    """
    def update_fn(v):
      # Per-variable update, usable directly or via strategy.extended.update.
      return self._skip_single_var(v, delta)
    # TODO(b/170515001): Always call strategy.extended.update after calling it
    #   from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
      # Assumes replica context with replica_id=0, since we only save the first
      # replica.
      return update_fn(self.state)
    if self._distribution_strategy is not None:
      with ds_context.enter_or_assert_strategy(self._distribution_strategy):
        if ds_context.in_cross_replica_context():
          # Code that operates on all replicas of a variable cannot be saved
          # without retracing.
          values_util.mark_as_unsaveable()
          # In cross-replica context we need to use strategy.extended.update.
          return ds_context.get_strategy().extended.update(
              self.state, update_fn)
    # No strategy, or replica context: update the state variable directly.
    return update_fn(self.state)
618
619  def _preprocess_key(self, key):
620    if self._distribution_strategy is None:
621      return key
622    with ds_context.enter_or_assert_strategy(self._distribution_strategy):
623      replica_id = get_replica_id()
624      if replica_id is not None:
625        replica_id = array_ops.stack([replica_id, 0], axis=0)
626        replica_id = math_ops.cast(replica_id, dtypes.uint64)
627        # Conceptually: key = hash(key, replica_id)
628        key = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
629            shape=[1], key=key, counter=replica_id, dtype=dtypes.uint64,
630            alg=self.algorithm)
631      return key
632
633  def _prepare_key_counter(self, shape):
634    delta = math_ops.reduce_prod(shape)
635    counter_key = self.skip(delta)
636    counter_size = _get_counter_size(self.algorithm)
637    counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
638    key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
639                            dtypes.uint64)
640    key = self._preprocess_key(key)
641    return key, counter
642
643  # The following functions return a tensor and as a side effect update
644  # self._state_var.
645  def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
646             name=None):
647    """Outputs random values from a normal distribution.
648
649    Args:
650      shape: A 1-D integer Tensor or Python array. The shape of the output
651        tensor.
652      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
653        distribution.
654      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
655        deviation of the normal distribution.
656      dtype: The type of the output.
657      name: A name for the operation (optional).
658
659    Returns:
660      A tensor of the specified shape filled with random normal values.
661    """
662    with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
663      shape = _shape_tensor(shape)
664      mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
665      stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
666      rnd = self._standard_normal(shape, dtype=dtype)
667      return math_ops.add(rnd * stddev, mean, name=name)
668
669  def _truncated_normal(self, shape, dtype):
670    if compat.forward_compatible(2020, 10, 25):
671      key, counter = self._prepare_key_counter(shape)
672      return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
673          shape=shape,
674          key=key,
675          counter=counter,
676          dtype=dtype,
677          alg=self.algorithm)
678    return gen_stateful_random_ops.stateful_truncated_normal(
679        self.state.handle, self.algorithm, shape, dtype=dtype)
680
681  def truncated_normal(self, shape,
682                       mean=0.0,
683                       stddev=1.0,
684                       dtype=dtypes.float32,
685                       name=None):
686    """Outputs random values from a truncated normal distribution.
687
688    The generated values follow a normal distribution with specified mean and
689    standard deviation, except that values whose magnitude is more than
690    2 standard deviations from the mean are dropped and re-picked.
691
692    Args:
693      shape: A 1-D integer Tensor or Python array. The shape of the output
694        tensor.
695      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
696        truncated normal distribution.
697      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
698        deviation of the normal distribution, before truncation.
699      dtype: The type of the output.
700      name: A name for the operation (optional).
701
702    Returns:
703      A tensor of the specified shape filled with random truncated normal
704        values.
705    """
706    with ops.name_scope(
707        name, "truncated_normal", [shape, mean, stddev]) as name:
708      shape_tensor = _shape_tensor(shape)
709      mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
710      stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
711      rnd = self._truncated_normal(shape_tensor, dtype=dtype)
712      mul = rnd * stddev_tensor
713      return math_ops.add(mul, mean_tensor, name=name)
714
715  def _uniform(self, shape, dtype):
716    if compat.forward_compatible(2020, 10, 25):
717      key, counter = self._prepare_key_counter(shape)
718      return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
719          shape=shape,
720          key=key,
721          counter=counter,
722          dtype=dtype,
723          alg=self.algorithm)
724    return gen_stateful_random_ops.stateful_uniform(
725        self.state.handle, self.algorithm, shape=shape, dtype=dtype)
726
727  def _uniform_full_int(self, shape, dtype, name=None):
728    if compat.forward_compatible(2020, 10, 25):
729      key, counter = self._prepare_key_counter(shape)
730      return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
731          shape=shape,
732          key=key,
733          counter=counter,
734          dtype=dtype,
735          alg=self.algorithm,
736          name=name)
737    return gen_stateful_random_ops.stateful_uniform_full_int(
738        self.state.handle, self.algorithm, shape=shape,
739        dtype=dtype, name=name)
740
741  def uniform(self, shape, minval=0, maxval=None,
742              dtype=dtypes.float32, name=None):
743    """Outputs random values from a uniform distribution.
744
745    The generated values follow a uniform distribution in the range
746    `[minval, maxval)`. The lower bound `minval` is included in the range, while
747    the upper bound `maxval` is excluded. (For float numbers especially
748    low-precision types like bfloat16, because of
749    rounding, the result may sometimes include `maxval`.)
750
751    For floats, the default range is `[0, 1)`.  For ints, at least `maxval` must
752    be specified explicitly.
753
754    In the integer case, the random integers are slightly biased unless
755    `maxval - minval` is an exact power of two.  The bias is small for values of
756    `maxval - minval` significantly smaller than the range of the output (either
757    `2**32` or `2**64`).
758
759    For full-range random integers, pass `minval=None` and `maxval=None` with an
760    integer `dtype` (for integer dtypes, `minval` and `maxval` must be both
761    `None` or both not `None`).
762
763    Args:
764      shape: A 1-D integer Tensor or Python array. The shape of the output
765        tensor.
766      minval: A Tensor or Python value of type `dtype`, broadcastable with
767        `shape` (for integer types, broadcasting is not supported, so it needs
768        to be a scalar). The lower bound (included) on the range of random
769        values to generate. Pass `None` for full-range integers. Defaults to 0.
770      maxval: A Tensor or Python value of type `dtype`, broadcastable with
771        `shape` (for integer types, broadcasting is not supported, so it needs
772        to be a scalar). The upper bound (excluded) on the range of random
773        values to generate. Pass `None` for full-range integers. Defaults to 1
774        if `dtype` is floating point.
775      dtype: The type of the output.
776      name: A name for the operation (optional).
777
778    Returns:
779      A tensor of the specified shape filled with random uniform values.
780
781    Raises:
782      ValueError: If `dtype` is integral and `maxval` is not specified.
783    """
784    dtype = dtypes.as_dtype(dtype)
785    if dtype.is_integer:
786      if (minval is None) != (maxval is None):
787        raise ValueError("For integer dtype {}, minval and maxval must be both "
788                         "`None` or both non-`None`; got minval={} and "
789                         "maxval={}".format(dtype, minval, maxval))
790    elif maxval is None:
791      maxval = 1
792    with ops.name_scope(name, "stateful_uniform",
793                        [shape, minval, maxval]) as name:
794      shape = _shape_tensor(shape)
795      if dtype.is_integer and minval is None:
796        return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
797      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
798      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
799      if dtype.is_integer:
800        if compat.forward_compatible(2020, 10, 25):
801          key, counter = self._prepare_key_counter(shape)
802          return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
803              shape=shape,
804              key=key,
805              counter=counter,
806              minval=minval,
807              maxval=maxval,
808              alg=self.algorithm,
809              name=name)
810        return gen_stateful_random_ops.stateful_uniform_int(
811            self.state.handle, self.algorithm, shape=shape,
812            minval=minval, maxval=maxval, name=name)
813      else:
814        rnd = self._uniform(shape=shape, dtype=dtype)
815        return math_ops.add(rnd * (maxval - minval), minval, name=name)
816
817  def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
818    """Uniform distribution on an integer type's entire range.
819
820    This method is the same as setting `minval` and `maxval` to `None` in the
821    `uniform` method.
822
823    Args:
824      shape: the shape of the output.
825      dtype: (optional) the integer type, default to uint64.
826      name: (optional) the name of the node.
827
828    Returns:
829      A tensor of random numbers of the required shape.
830    """
831    dtype = dtypes.as_dtype(dtype)
832    with ops.name_scope(name, "stateful_uniform_full_int",
833                        [shape]) as name:
834      shape = _shape_tensor(shape)
835      return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
836
837  def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
838    """Outputs random values from a binomial distribution.
839
840    The generated values follow a binomial distribution with specified count and
841    probability of success parameters.
842
843    Example:
844
845    ```python
846    counts = [10., 20.]
847    # Probability of success.
848    probs = [0.8]
849
850    rng = tf.random.Generator.from_seed(seed=234)
851    binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)
852
853
854    counts = ... # Shape [3, 1, 2]
855    probs = ...  # Shape [1, 4, 2]
856    shape = [3, 4, 3, 4, 2]
857    rng = tf.random.Generator.from_seed(seed=1717)
858    # Sample shape will be [3, 4, 3, 4, 2]
859    binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs)
860    ```
861
862
863    Args:
864      shape: A 1-D integer Tensor or Python array. The shape of the output
865        tensor.
866      counts: Tensor. The counts of the binomial distribution. Must be
867        broadcastable with `probs`, and broadcastable with the rightmost
868        dimensions of `shape`.
869      probs: Tensor. The probability of success for the
870        binomial distribution. Must be broadcastable with `counts` and
871        broadcastable with the rightmost dimensions of `shape`.
872      dtype: The type of the output. Default: tf.int32
873      name: A name for the operation (optional).
874
875    Returns:
876      samples: A Tensor of the specified shape filled with random binomial
877        values.  For each i, each samples[i, ...] is an independent draw from
878        the binomial distribution on counts[i] trials with probability of
879        success probs[i].
880    """
881    dtype = dtypes.as_dtype(dtype)
882    with ops.name_scope(name, "binomial", [shape, counts, probs]) as name:
883      counts = ops.convert_to_tensor(counts, name="counts")
884      probs = ops.convert_to_tensor(probs, name="probs")
885      shape_tensor = _shape_tensor(shape)
886      return gen_stateful_random_ops.stateful_random_binomial(
887          self.state.handle,
888          self.algorithm,
889          shape=shape_tensor,
890          counts=counts,
891          probs=probs,
892          dtype=dtype,
893          name=name)
894
895  # TODO(wangpeng): implement other distributions
896
897  def _make_int64_keys(self, shape=()):
898    # New independent keys are generated via
899    # `new_key[i] = hash(old_key, counter+i)`, which is exactly what
900    # `uniform_full_int(dtype=int64)` does for PhiloxRandom_64_128_128 and
901    # ThreeFry_64_64_64.
902    return self.uniform_full_int(shape=shape, dtype=dtypes.int64)
903
904  def make_seeds(self, count=1):
905    """Generates seeds for stateless random ops.
906
907    For example:
908
909    ```python
910    seeds = get_global_generator().make_seeds(count=10)
911    for i in range(10):
912      seed = seeds[:, i]
913      numbers = stateless_random_normal(shape=[2, 3], seed=seed)
914      ...
915    ```
916
917    Args:
918      count: the number of seed pairs (note that stateless random ops need a
919        pair of seeds to invoke).
920
921    Returns:
922      A tensor of shape [2, count] and dtype int64.
923    """
924    alg = self.algorithm
925    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
926      keys = self._make_int64_keys(shape=[count])
927      # The two seeds for stateless random ops don't have individual semantics
928      # and are scrambled together, so setting one to zero is fine.
929      zeros = array_ops.zeros_like(keys)
930      return array_ops.stack([keys, zeros])
931    else:
932      raise ValueError("Unsupported algorithm id: %s" % alg)
933
934  def split(self, count=1):
935    """Returns a list of independent `Generator` objects.
936
937    Two generators are independent of each other in the sense that the
938    random-number streams they generate don't have statistically detectable
939    correlations. The new generators are also independent of the old one.
940    The old generator's state will be changed (like other random-number
941    generating methods), so two calls of `split` will return different
942    new generators.
943
944    For example:
945
946    ```python
947    gens = get_global_generator().split(count=10)
948    for gen in gens:
949      numbers = gen.normal(shape=[2, 3])
950      # ...
951    gens2 = get_global_generator().split(count=10)
952    # gens2 will be different from gens
953    ```
954
955    The new generators will be put on the current device (possible different
956    from the old generator's), for example:
957
958    ```python
959    with tf.device("/device:CPU:0"):
960      gen = Generator(seed=1234)  # gen is on CPU
961    with tf.device("/device:GPU:0"):
962      gens = gen.split(count=10)  # gens are on GPU
963    ```
964
965    Args:
966      count: the number of generators to return.
967
968    Returns:
969      A list (length `count`) of `Generator` objects independent of each other.
970      The new generators have the same RNG algorithm as the old one.
971    """
972    def _key_to_state(alg, key):
973      # Padding with zeros on the left. The zeros will be the counter.
974      return [0] * (_get_state_size(alg) - 1) + [key]
975
976    alg = self.algorithm
977    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
978      keys = self._make_int64_keys(shape=[count])
979      return [Generator(state=_key_to_state(alg, key), alg=alg)
980              for key in keys.numpy()]
981    else:
982      raise ValueError("Unsupported algorithm id: %s" % alg)
983
984
# It's not safe to create TF ops before `init_google` is called, so this is
# initialized to None and gets a value the first time `get_global_generator`
# is called (it can also be replaced via `set_global_generator`).
global_generator = None
989
990
@tf_export("random.get_global_generator",
           "random.experimental.get_global_generator")
def get_global_generator():
  """Returns the global generator, creating it on first use.

  The first call builds the generator (from a non-deterministic state), and it
  is placed on whatever device is the default at that moment. Be deliberate
  about where that first call happens: a generator living on a less-suitable
  device will degrade performance.

  Returns:
    The global `tf.random.Generator` object.
  """
  global global_generator
  if global_generator is not None:
    return global_generator
  # First use: build the generator under `ops.init_scope()`.
  with ops.init_scope():
    global_generator = Generator.from_non_deterministic_state()
  return global_generator
1009
1010
@tf_export("random.set_global_generator",
           "random.experimental.set_global_generator")
def set_global_generator(generator):
  """Replaces the global generator with another `Generator` object.

  This function only rebinds the module-level global generator to
  `generator`; it does not create a `Generator` itself. Subsequent calls to
  `get_global_generator` return `generator`.

  Caveats when combined with `tf.function`:

  1. `tf.function` puts restrictions on Variable creation, so the `Generator`
     passed in (and the Variable object within it) can't be freely created
     inside a `tf.function`.
  2. Redirecting a global variable to a new object is problematic with
     `tf.function`, because the old object may have been captured by a
     'tf.function'ed function and still be used by it. A 'tf.function'ed
     function only keeps weak references to variables, so deleting a variable
     and then calling that function again may raise an error, as demonstrated
     by random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun .

  Args:
    generator: the new `Generator` object.
  """
  global global_generator
  global_generator = generator
1032