# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for generating random numbers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum  # pylint: disable=g-bad-import-order

import numpy as np
import six

from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util.tf_export import tf_export


# A seed for random ops (stateful and stateless) will always be 1024
# bits, all of which will be sent to the C++ code. The actual C++
# implementation of some algorithms may only use a lower part of the bits.

MAX_INT64 = 2**63 - 1
MIN_INT64 = -(2**63)
UINT64_SPAN = 2**64
# 'Variable' doesn't support uint32 or uint64 yet (due to reasons explained in
# b/111604096 and cl/171681867), so we use signed int here. We choose int64
# instead of int32 because `VarHandleOp` doesn't support int32 on GPU.
SEED_TYPE = "int64"
SEED_MIN = MIN_INT64
SEED_MAX = MAX_INT64
SEED_UINT_SPAN = UINT64_SPAN
SEED_TYPE_BITS = 64
SEED_BIT_MASK = 0xFFFFFFFFFFFFFFFF
SEED_SIZE = 16  # in units of SEED_TYPE


STATE_TYPE = SEED_TYPE
ALGORITHM_TYPE = STATE_TYPE
PHILOX_STATE_SIZE = 3
THREEFRY_STATE_SIZE = 2


@tf_export("random.Algorithm", "random.experimental.Algorithm")
class Algorithm(enum.Enum):
  PHILOX = 1
  THREEFRY = 2


RNG_ALG_PHILOX = Algorithm.PHILOX.value
RNG_ALG_THREEFRY = Algorithm.THREEFRY.value
DEFAULT_ALGORITHM = RNG_ALG_PHILOX


def non_deterministic_ints(shape, dtype=dtypes.int64):
  """Non-deterministically generates some integers.

  This op may use some OS-provided source of non-determinism (e.g. an RNG), so
  each execution will give different results.

  Args:
    shape: the shape of the result.
    dtype: (optional) the dtype of the result.

  Returns:
    a tensor whose element values are non-deterministically chosen.
  """
  return gen_stateful_random_ops.non_deterministic_ints(
      shape=shape, dtype=dtype)


def _uint_to_int(n):
  if n > SEED_MAX:
    n = n - SEED_UINT_SPAN
  return n


def _make_1d_state(state_size, seed):
  """Makes a 1-D RNG state.
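
  For illustration (a sketch of the chopping/padding logic implemented
  below): an integer seed is chopped little-endian into 64-bit words, and a
  too-short vector seed is padded with zeros on the left:

  ```python
  _make_1d_state(3, 2**64 + 5)  # -> array([5, 1, 0])
  _make_1d_state(3, [12, 34])   # -> array([0, 12, 34])
  ```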

  Args:
    state_size: an integer.
    seed: an integer or 1-D tensor.

  Returns:
    a 1-D tensor of shape [state_size] and dtype STATE_TYPE.
  """
  if isinstance(seed, six.integer_types):
    # chop the Python integer (infinite precision) into chunks of SEED_TYPE
    ls = []
    for _ in range(state_size):
      ls.append(seed & SEED_BIT_MASK)
      seed >>= SEED_TYPE_BITS
    seed = ls
  # to avoid overflow error from np.asarray
  seed = list(map(_uint_to_int, seed))
  seed = np.asarray(seed, dtype=STATE_TYPE)
  if len(seed.shape) != 1:
    raise ValueError(
        "seed should only have one dimension; got shape: %s" % seed.shape)
  seed = seed[0:state_size]
  # Padding with zeros on the *left* if too short. Padding on the right would
  # cause a small seed to be used as the "counter" while the "key" is always
  # zero (for counter-based RNG algorithms), because in the current memory
  # layout the counter is stored before the key. In such a situation two RNGs
  # with two different small seeds may generate overlapping outputs.
  seed_size = seed.shape[0]
  if seed_size < state_size:
    seed = np.pad(
        seed, [(state_size - seed_size, 0)],
        mode="constant",
        constant_values=0)
  assert seed.shape == (state_size,), "Wrong seed.shape: %s" % seed.shape
  return seed


def _get_counter_size(alg):
  if alg == RNG_ALG_PHILOX:
    return 2
  elif alg == RNG_ALG_THREEFRY:
    return 1
  else:
    raise ValueError("Unsupported algorithm id: %s" % alg)


def _get_state_size(alg):
  if alg == RNG_ALG_PHILOX:
    return PHILOX_STATE_SIZE
  elif alg == RNG_ALG_THREEFRY:
    return THREEFRY_STATE_SIZE
  else:
    raise ValueError("Unsupported algorithm id: %s" % alg)


def _check_state_shape(shape, alg):
  if isinstance(alg, ops.Tensor) and not context.executing_eagerly():
    return
  shape.assert_is_compatible_with([_get_state_size(int(alg))])


def _make_state_from_seed(seed, alg):
  return _make_1d_state(_get_state_size(alg), seed)


def _convert_alg_to_int(alg):
  """Converts algorithm to an integer.

  Args:
    alg: can be one of these types: integer, Algorithm, Tensor, string. Allowed
      strings are "philox" and "threefry".

  Returns:
    An integer, unless the input is a Tensor in which case a Tensor is
    returned.
  """
  if isinstance(alg, six.integer_types):
    return alg
  if isinstance(alg, Algorithm):
    return alg.value
  if isinstance(alg, ops.Tensor):
    return alg
  if isinstance(alg, str):
    if alg == "philox":
      return RNG_ALG_PHILOX
    elif alg == "threefry":
      return RNG_ALG_THREEFRY
    else:
      raise ValueError("Unknown algorithm name: %s" % alg)
  else:
    raise TypeError("Can't convert algorithm %s of type %s to int" %
                    (alg, type(alg)))


@tf_export("random.create_rng_state", "random.experimental.create_rng_state")
def create_rng_state(seed, alg):
  """Creates an RNG state from an integer or a vector.

  Example:

  >>> tf.random.create_rng_state(
  ...     1234, "philox")
  array([1234, 0, 0])
  >>> tf.random.create_rng_state(
  ...     [12, 34], "threefry")
  array([12, 34])

  Args:
    seed: an integer or 1-D numpy array.
    alg: the RNG algorithm. Can be a string, an `Algorithm` or an integer.

  Returns:
    a 1-D numpy array whose size depends on the algorithm.
  """
  alg = _convert_alg_to_int(alg)
  return _make_state_from_seed(seed, alg)


def _shape_tensor(shape):
  """Converts to an int32 or int64 tensor, defaulting to int64 if empty."""
  if isinstance(shape, (tuple, list)) and not shape:
    dtype = dtypes.int64
  else:
    dtype = None
  return ops.convert_to_tensor(shape, dtype=dtype, name="shape")


def _convert_to_state_tensor(t):
  if isinstance(t, list):
    # to avoid out-of-range error from ops.convert_to_tensor
    t = list(map(_uint_to_int, t))
  return ops.convert_to_tensor(t, dtype=STATE_TYPE)


def get_replica_id():
  rctx = ds_context.get_replica_context()
  if rctx is None:
    return None
  return rctx.replica_id_in_sync_group


@tf_export("random.Generator", "random.experimental.Generator")
class Generator(tracking.AutoTrackable):
  """Random-number generator.

  Example:

  Creating a generator from a seed:

  >>> g = tf.random.Generator.from_seed(1234)
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[ 0.9356609 ,  1.0854305 , -0.93788373],
         [-0.5061547 ,  1.3169702 ,  0.7137579 ]], dtype=float32)>

  Creating a generator from a non-deterministic state:

  >>> g = tf.random.Generator.from_non_deterministic_state()
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>

  All the constructors allow explicitly choosing a Random-Number-Generation
  (RNG) algorithm. Supported algorithms are `"philox"` and `"threefry"`. For
  example:

  >>> g = tf.random.Generator.from_seed(123, alg="philox")
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=
  array([[ 0.8673864 , -0.29899067, -0.9310337 ],
         [-1.5828488 ,  1.2481191 , -0.6770643 ]], dtype=float32)>

  CPU, GPU and TPU with the same algorithm and seed will generate the same
  integer random numbers. Floating-point results (such as the output of
  `normal`) may have small numerical discrepancies between different devices.

  This class uses a `tf.Variable` to manage its internal state. Every time
  random numbers are generated, the state of the generator will change. For
  example:

  >>> g = tf.random.Generator.from_seed(1234)
  >>> g.state
  <tf.Variable ... numpy=array([1234, 0, 0])>
  >>> g.normal(shape=(2, 3))
  <...>
  >>> g.state
  <tf.Variable ... numpy=array([2770, 0, 0])>

  The shape of the state is algorithm-specific.

  There is also a global generator:

  >>> g = tf.random.get_global_generator()
  >>> g.normal(shape=(2, 3))
  <tf.Tensor: shape=(2, 3), dtype=float32, numpy=...>

  When creating a generator inside a `tf.distribute.Strategy` scope, each
  replica will get a different stream of random numbers.

  Note: `tf.distribute.experimental.CentralStorageStrategy` and
  `tf.distribute.experimental.ParameterServerStrategy` are not supported yet.

  For example, in this code:

  ```
  strat = tf.distribute.MirroredStrategy(devices=["cpu:0", "cpu:1"])
  with strat.scope():
    g = tf.random.Generator.from_seed(1)
    def f():
      return g.normal([])
    results = strat.run(f).values
  ```

  `results[0]` and `results[1]` will have different values.

  If the generator is seeded (e.g. created via `Generator.from_seed`), the
  random numbers will be determined by the seed, even though different
  replicas get different numbers. One can think of a random number generated
  on a replica as a hash of the replica ID and a "master" random number that
  may be common to all replicas. Hence, the whole system is still
  deterministic.

  (Note that the random numbers on different replicas are not correlated,
  even if they are deterministically determined by the same seed. They are
  not correlated in the sense that no matter what statistics one calculates
  on them, there won't be any discernible correlation.)

  Generators can be freely saved and restored using `tf.train.Checkpoint`.
  The checkpoint can be restored in a distribution strategy with a different
  number of replicas than the original strategy. If a replica ID is present
  in both the original and the new distribution strategy, its state will be
  properly restored (i.e. the random-number stream from the restored point
  will be the same as that from the saving point) unless the replicas have
  already diverged in their RNG call traces before saving (e.g. one replica
  has made one RNG call while another has made two RNG calls). There is no
  such guarantee if the generator is saved in a strategy scope and restored
  outside of any strategy scope, or vice versa.
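
  For example (a sketch; the checkpoint prefix is just an illustrative path):

  ```python
  g = tf.random.Generator.from_seed(1)
  ckpt = tf.train.Checkpoint(generator=g)
  path = ckpt.save("/tmp/rng_ckpt")
  x = g.normal([])    # advances the state
  ckpt.restore(path)  # rolls the state back
  y = g.normal([])    # same draw as x
  ```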
  """

  @classmethod
  def from_state(cls, state, alg):
    """Creates a generator from a state.

    See `__init__` for description of `state` and `alg`.

    Args:
      state: the new state.
      alg: the RNG algorithm.

    Returns:
      The new generator.
    """
    return cls(alg=alg, state=state)

  @classmethod
  def from_seed(cls, seed, alg=None):
    """Creates a generator from a seed.

    A seed is a 1024-bit unsigned integer represented either as a Python
    integer or a vector of integers. Seeds shorter than 1024 bits will be
    padded. The padding, the internal structure of a seed and the way a seed
    is converted to a state are all opaque (unspecified). The only specified
    semantics of seeds is that two different seeds are likely to produce two
    independent generators (but there is no guarantee).

    Args:
      seed: the seed for the RNG.
      alg: (optional) the RNG algorithm. If None, it will be auto-selected.
        See `__init__` for its possible values.

    Returns:
      The new generator.
    """
    if alg is None:
      # TODO(b/170668986): more sophisticated algorithm selection
      alg = DEFAULT_ALGORITHM
    alg = _convert_alg_to_int(alg)
    state = create_rng_state(seed, alg)
    return cls(state=state, alg=alg)

  @classmethod
  def from_non_deterministic_state(cls, alg=None):
    """Creates a generator by non-deterministically initializing its state.

    The source of the non-determinism will be platform- and time-dependent.

    Args:
      alg: (optional) the RNG algorithm. If None, it will be auto-selected.
        See `__init__` for its possible values.

    Returns:
      The new generator.
    """
    if alg is None:
      # TODO(b/170668986): more sophisticated algorithm selection
      alg = DEFAULT_ALGORITHM
    alg = _convert_alg_to_int(alg)
    state = non_deterministic_ints(shape=[_get_state_size(alg)],
                                   dtype=SEED_TYPE)
    return cls(state=state, alg=alg)

  @classmethod
  def from_key_counter(cls, key, counter, alg):
    """Creates a generator from a key and a counter.

    This constructor only applies if the algorithm is a counter-based
    algorithm. See method `key` for the meaning of "key" and "counter".
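
    For example (a sketch; for the Philox algorithm the state layout is
    `[counter, counter, key]`, i.e. the key occupies the last position):

    ```python
    g = tf.random.Generator.from_key_counter(
        key=1234, counter=[0, 0], alg="philox")
    g.state  # a variable holding [0, 0, 1234]
    ```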

    Args:
      key: the key for the RNG, a scalar of type STATE_TYPE.
      counter: a vector of dtype STATE_TYPE representing the initial counter
        for the RNG, whose length is algorithm-specific.
      alg: the RNG algorithm. If None, it will be auto-selected. See
        `__init__` for its possible values.

    Returns:
      The new generator.
    """
    counter = _convert_to_state_tensor(counter)
    key = _convert_to_state_tensor(key)
    alg = _convert_alg_to_int(alg)
    counter.shape.assert_is_compatible_with([_get_state_size(alg) - 1])
    key.shape.assert_is_compatible_with([])
    key = array_ops.reshape(key, [1])
    state = array_ops.concat([counter, key], 0)
    return cls(state=state, alg=alg)

  def __init__(self, copy_from=None, state=None, alg=None):
    """Creates a generator.

    The new generator will be initialized in one of the following ways, with
    decreasing precedence:
    (1) If `copy_from` is not None, the new generator is initialized by
        copying information from another generator.
    (2) If `state` and `alg` are not None (they must be set together), the
        new generator is initialized by a state.

    Args:
      copy_from: a generator to be copied from.
      state: a vector of dtype STATE_TYPE representing the initial state of
        the RNG, whose length and semantics are algorithm-specific. If it's a
        variable, the generator will reuse it instead of creating a new
        variable.
      alg: the RNG algorithm. Possible values are
        `tf.random.Algorithm.PHILOX` for the Philox algorithm and
        `tf.random.Algorithm.THREEFRY` for the ThreeFry algorithm
        (see paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
        [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]).
        The string names `"philox"` and `"threefry"` can also be used.
        Note `PHILOX` guarantees the same numbers are produced (given
        the same random state) across all architectures (CPU, GPU, XLA etc).
    """
    # TODO(b/175072242): Remove distribution-strategy dependencies in this
    # file.
    if ds_context.has_strategy():
      self._distribution_strategy = ds_context.get_strategy()
    else:
      self._distribution_strategy = None
    if copy_from is not None:
      # All other arguments should be None
      assert (alg or state) is None
      self._state_var = self._create_variable(copy_from.state,
                                              dtype=STATE_TYPE,
                                              trainable=False)
      self._alg = copy_from.algorithm
    else:
      assert alg is not None and state is not None
      if ds_context.has_strategy():
        strat_name = type(ds_context.get_strategy()).__name__
        # TODO(b/174610856): Support CentralStorageStrategy and
        # ParameterServerStrategy.
        if "CentralStorage" in strat_name or "ParameterServer" in strat_name:
          raise ValueError("%s is not supported yet" % strat_name)
      alg = _convert_alg_to_int(alg)
      if isinstance(state, variables.Variable):
        _check_state_shape(state.shape, alg)
        self._state_var = state
      else:
        state = _convert_to_state_tensor(state)
        _check_state_shape(state.shape, alg)
        self._state_var = self._create_variable(state, dtype=STATE_TYPE,
                                                trainable=False)
      self._alg = alg

  def _create_variable(self, *args, **kwargs):
    """Creates a variable.

    Args:
      *args: positional arguments passed along to `variables.Variable`.
      **kwargs: keyword arguments passed along to `variables.Variable`.

    Returns:
      The created variable.
    """
    return variables.Variable(*args, **kwargs)

  def reset(self, state):
    """Resets the generator by a new state.

    See `__init__` for the meaning of "state".

    Args:
      state: the new state.
    """
    state = _convert_to_state_tensor(state)
    state.shape.assert_is_compatible_with([_get_state_size(self.algorithm)])
    self._state_var.assign(state)

  def reset_from_seed(self, seed):
    """Resets the generator by a new seed.

    See `from_seed` for the meaning of "seed".

    Args:
      seed: the new seed.
    """
    state = create_rng_state(seed, self.algorithm)
    self._state_var.assign(state)

  def reset_from_key_counter(self, key, counter):
    """Resets the generator by a new key-counter pair.

    See `from_key_counter` for the meaning of "key" and "counter".

    Args:
      key: the new key.
      counter: the new counter.
    """
    counter = _convert_to_state_tensor(counter)
    key = _convert_to_state_tensor(key)
    counter.shape.assert_is_compatible_with(
        [_get_state_size(self.algorithm) - 1])
    key.shape.assert_is_compatible_with([])
    key = array_ops.reshape(key, [1])
    state = array_ops.concat([counter, key], 0)
    self._state_var.assign(state)

  @property
  def state(self):
    """The internal state of the RNG."""
    return self._state_var

  @property
  def algorithm(self):
    """The RNG algorithm id (a Python integer or scalar integer Tensor)."""
    return self._alg

  def _standard_normal(self, shape, dtype):
    if compat.forward_compatible(2020, 10, 25):
      key, counter = self._prepare_key_counter(shape)
      return gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
    return gen_stateful_random_ops.stateful_standard_normal_v2(
        self.state.handle, self.algorithm, shape, dtype=dtype)

  @property
  def key(self):
    """The 'key' part of the state of a counter-based RNG.

    For a counter-based RNG algorithm such as Philox and ThreeFry (as
    described in paper 'Parallel Random Numbers: As Easy as 1, 2, 3'
    [https://www.thesalmons.org/john/random123/papers/random123sc11.pdf]),
    the RNG state consists of two parts: counter and key. The output is
    generated via the formula: output=hash(key, counter), i.e. a hashing of
    the counter parametrized by the key. Two RNGs with two different keys can
    be thought of as generating two independent random-number streams (a
    stream is formed by increasing the counter).
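
    For example (a sketch; the key is the last element of the state, as
    described above):

    ```python
    g1 = tf.random.Generator.from_key_counter(
        key=1, counter=[0, 0], alg="philox")
    g2 = tf.random.Generator.from_key_counter(
        key=2, counter=[0, 0], alg="philox")
    g1.key  # a scalar int64 Tensor with value 1
    g2.key  # a scalar int64 Tensor with value 2
    # g1 and g2 generate independent random-number streams.
    ```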

    Returns:
      A scalar which is the 'key' part of the state, if the RNG algorithm is
      counter-based; otherwise it raises a ValueError.
    """
    alg = self.algorithm
    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
      return self._state_var[-1]
    else:
      raise ValueError("Unsupported algorithm id: %s" % alg)

  # TODO(wangpeng): Add "Returns" section to docstring once new version kicks
  # in.
  # pylint: disable=g-doc-return-or-yield
  def skip(self, delta):
    """Advances the counter of a counter-based RNG.
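
    For example (a sketch relying on the `skip(n)`/`normal([n])` equivalence
    described under `delta` below):

    ```python
    g1 = tf.random.Generator.from_seed(1)
    g2 = tf.random.Generator.from_seed(1)
    g1.normal([3])  # advances g1's counter
    g2.skip(3)      # g2's state now matches g1's
    ```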

    Args:
      delta: the amount of advancement. The state of the RNG after
        `skip(n)` will be the same as that after `normal([n])`
        (or any other distribution). The actual increment added to the
        counter is an unspecified implementation detail.
    """
    if compat.forward_compatible(2020, 10, 25):
      return self._skip(delta)
    gen_stateful_random_ops.rng_skip(
        self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
        math_ops.cast(delta, dtypes.int64))
  # pylint: enable=g-doc-return-or-yield

  def _skip_single_var(self, var, delta):
    # TODO(wangpeng): Cache the cast algorithm instead of casting every time.
    return gen_stateful_random_ops.rng_read_and_skip(
        var.handle, alg=math_ops.cast(self.algorithm, dtypes.int32),
        delta=math_ops.cast(delta, dtypes.uint64))

  def _skip(self, delta):
    def update_fn(v):
      return self._skip_single_var(v, delta)
    # TODO(b/170515001): Always call strategy.extended.update once calling it
    # from both replica context and cross-replica context is supported.
    if values_util.is_saving_non_distributed():
      # Assumes replica context with replica_id=0, since we only save the
      # first replica.
      return update_fn(self.state)
    if self._distribution_strategy is not None:
      with ds_context.enter_or_assert_strategy(self._distribution_strategy):
        if ds_context.in_cross_replica_context():
          # Code that operates on all replicas of a variable cannot be saved
          # without retracing.
          values_util.mark_as_unsaveable()
          # In cross-replica context we need to use strategy.extended.update.
          return ds_context.get_strategy().extended.update(
              self.state, update_fn)
    return update_fn(self.state)

  def _preprocess_key(self, key):
    if self._distribution_strategy is None:
      return key
    with ds_context.enter_or_assert_strategy(self._distribution_strategy):
      replica_id = get_replica_id()
      if replica_id is not None:
        replica_id = array_ops.stack([replica_id, 0], axis=0)
        replica_id = math_ops.cast(replica_id, dtypes.uint64)
        # Conceptually: key = hash(key, replica_id)
        key = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
            shape=[1], key=key, counter=replica_id, dtype=dtypes.uint64,
            alg=self.algorithm)
    return key

  def _prepare_key_counter(self, shape):
    delta = math_ops.reduce_prod(shape)
    counter_key = self.skip(delta)
    counter_size = _get_counter_size(self.algorithm)
    counter = array_ops.bitcast(counter_key[:counter_size], dtypes.uint64)
    key = array_ops.bitcast(counter_key[counter_size:counter_size + 1],
                            dtypes.uint64)
    key = self._preprocess_key(key)
    return key, counter

  # The following functions return a tensor and as a side effect update
  # self._state_var.
  def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
             name=None):
    """Outputs random values from a normal distribution.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random normal values.
    """
    with ops.name_scope(name, "stateful_normal", [shape, mean, stddev]) as name:
      shape = _shape_tensor(shape)
      mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
      rnd = self._standard_normal(shape, dtype=dtype)
      return math_ops.add(rnd * stddev, mean, name=name)

  def _truncated_normal(self, shape, dtype):
    if compat.forward_compatible(2020, 10, 25):
      key, counter = self._prepare_key_counter(shape)
      return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
          shape=shape,
          key=key,
          counter=counter,
          dtype=dtype,
          alg=self.algorithm)
    return gen_stateful_random_ops.stateful_truncated_normal(
        self.state.handle, self.algorithm, shape, dtype=dtype)

  def truncated_normal(self, shape,
                       mean=0.0,
                       stddev=1.0,
                       dtype=dtypes.float32,
                       name=None):
    """Outputs random values from a truncated normal distribution.

    The generated values follow a normal distribution with specified mean and
    standard deviation, except that values whose magnitude is more than
    2 standard deviations from the mean are dropped and re-picked.

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
        truncated normal distribution.
      stddev: A 0-D Tensor or Python value of type `dtype`. The standard
        deviation of the normal distribution, before truncation.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random truncated normal
      values.
    """
    with ops.name_scope(
        name, "truncated_normal", [shape, mean, stddev]) as name:
      shape_tensor = _shape_tensor(shape)
      mean_tensor = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
      stddev_tensor = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
      rnd = self._truncated_normal(shape_tensor, dtype=dtype)
      mul = rnd * stddev_tensor
      return math_ops.add(mul, mean_tensor, name=name)

  def _uniform(self, shape, dtype):
    if compat.forward_compatible(2020, 10, 25):
      key, counter = self._prepare_key_counter(shape)
      return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
          shape=shape,
          key=key,
          counter=counter,
          dtype=dtype,
          alg=self.algorithm)
    return gen_stateful_random_ops.stateful_uniform(
        self.state.handle, self.algorithm, shape=shape, dtype=dtype)

  def _uniform_full_int(self, shape, dtype, name=None):
    if compat.forward_compatible(2020, 10, 25):
      key, counter = self._prepare_key_counter(shape)
      return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
          shape=shape,
          key=key,
          counter=counter,
          dtype=dtype,
          alg=self.algorithm,
          name=name)
    return gen_stateful_random_ops.stateful_uniform_full_int(
        self.state.handle, self.algorithm, shape=shape,
        dtype=dtype, name=name)

  def uniform(self, shape, minval=0, maxval=None,
              dtype=dtypes.float32, name=None):
    """Outputs random values from a uniform distribution.

    The generated values follow a uniform distribution in the range
    `[minval, maxval)`. The lower bound `minval` is included in the range,
    while the upper bound `maxval` is excluded. (For floating-point numbers,
    especially low-precision types like `bfloat16`, the result may sometimes
    include `maxval` because of rounding.)

    For floats, the default range is `[0, 1)`. For ints, at least `maxval`
    must be specified explicitly.

    In the integer case, the random integers are slightly biased unless
    `maxval - minval` is an exact power of two. The bias is small for values
    of `maxval - minval` significantly smaller than the range of the output
    (either `2**32` or `2**64`).

    For full-range random integers, pass `minval=None` and `maxval=None` with
    an integer `dtype` (for integer dtypes, `minval` and `maxval` must be
    both `None` or both not `None`).
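
    For example (a sketch of the three modes described above):

    ```python
    g = tf.random.Generator.from_seed(1234)
    g.uniform(shape=[2])                             # floats in [0, 1)
    g.uniform(shape=[2], maxval=10, dtype=tf.int32)  # ints in [0, 10)
    g.uniform(shape=[2], minval=None, maxval=None,
              dtype=tf.int32)                        # full int32 range
    ```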

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      minval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it
        needs to be a scalar). The lower bound (included) on the range of
        random values to generate. Pass `None` for full-range integers.
        Defaults to 0.
      maxval: A Tensor or Python value of type `dtype`, broadcastable with
        `shape` (for integer types, broadcasting is not supported, so it
        needs to be a scalar). The upper bound (excluded) on the range of
        random values to generate. Pass `None` for full-range integers.
        Defaults to 1 if `dtype` is floating point.
      dtype: The type of the output.
      name: A name for the operation (optional).

    Returns:
      A tensor of the specified shape filled with random uniform values.

    Raises:
      ValueError: If `dtype` is integral and `maxval` is not specified.
    """
    dtype = dtypes.as_dtype(dtype)
    if dtype.is_integer:
      if (minval is None) != (maxval is None):
        raise ValueError("For integer dtype {}, minval and maxval must be "
                         "both `None` or both non-`None`; got minval={} and "
                         "maxval={}".format(dtype, minval, maxval))
    elif maxval is None:
      maxval = 1
    with ops.name_scope(name, "stateful_uniform",
                        [shape, minval, maxval]) as name:
      shape = _shape_tensor(shape)
      if dtype.is_integer and minval is None:
        return self._uniform_full_int(shape=shape, dtype=dtype, name=name)
      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
      if dtype.is_integer:
        if compat.forward_compatible(2020, 10, 25):
          key, counter = self._prepare_key_counter(shape)
          return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
              shape=shape,
              key=key,
              counter=counter,
              minval=minval,
              maxval=maxval,
              alg=self.algorithm,
              name=name)
        return gen_stateful_random_ops.stateful_uniform_int(
            self.state.handle, self.algorithm, shape=shape,
            minval=minval, maxval=maxval, name=name)
      else:
        rnd = self._uniform(shape=shape, dtype=dtype)
        return math_ops.add(rnd * (maxval - minval), minval, name=name)

  def uniform_full_int(self, shape, dtype=dtypes.uint64, name=None):
    """Uniform distribution on an integer type's entire range.

    This method is the same as setting `minval` and `maxval` to `None` in the
    `uniform` method.
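
    For example (a sketch):

    ```python
    g = tf.random.Generator.from_seed(1234)
    g.uniform_full_int(shape=[2], dtype=tf.int32)  # may be any int32 values
    ```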

    Args:
      shape: the shape of the output.
      dtype: (optional) the integer type, defaulting to uint64.
      name: (optional) the name of the node.

    Returns:
      A tensor of random numbers of the required shape.
    """
    dtype = dtypes.as_dtype(dtype)
    with ops.name_scope(name, "stateful_uniform_full_int",
                        [shape]) as name:
      shape = _shape_tensor(shape)
      return self._uniform_full_int(shape=shape, dtype=dtype, name=name)

  def binomial(self, shape, counts, probs, dtype=dtypes.int32, name=None):
    """Outputs random values from a binomial distribution.

    The generated values follow a binomial distribution with specified count
    and probability of success parameters.

    Example:

    ```python
    counts = [10., 20.]
    # Probability of success.
    probs = [0.8]

    rng = tf.random.Generator.from_seed(seed=234)
    binomial_samples = rng.binomial(shape=[2], counts=counts, probs=probs)

    counts = ...  # Shape [3, 1, 2]
    probs = ...   # Shape [1, 4, 2]
    shape = [3, 4, 3, 4, 2]
    rng = tf.random.Generator.from_seed(seed=1717)
    # Sample shape will be [3, 4, 3, 4, 2]
    binomial_samples = rng.binomial(shape=shape, counts=counts, probs=probs)
    ```

    Args:
      shape: A 1-D integer Tensor or Python array. The shape of the output
        tensor.
      counts: Tensor. The counts of the binomial distribution. Must be
        broadcastable with `probs`, and broadcastable with the rightmost
        dimensions of `shape`.
      probs: Tensor. The probability of success for the
        binomial distribution. Must be broadcastable with `counts` and
        broadcastable with the rightmost dimensions of `shape`.
      dtype: The type of the output. Default: tf.int32
      name: A name for the operation (optional).

    Returns:
      samples: A Tensor of the specified shape filled with random binomial
        values. For each i, each samples[i, ...] is an independent draw from
        the binomial distribution on counts[i] trials with probability of
        success probs[i].
    """
    dtype = dtypes.as_dtype(dtype)
    with ops.name_scope(name, "binomial", [shape, counts, probs]) as name:
      counts = ops.convert_to_tensor(counts, name="counts")
      probs = ops.convert_to_tensor(probs, name="probs")
      shape_tensor = _shape_tensor(shape)
      return gen_stateful_random_ops.stateful_random_binomial(
          self.state.handle,
          self.algorithm,
          shape=shape_tensor,
          counts=counts,
          probs=probs,
          dtype=dtype,
          name=name)

  # TODO(wangpeng): implement other distributions

  def _make_int64_keys(self, shape=()):
    # New independent keys are generated via
    # `new_key[i] = hash(old_key, counter+i)`, which is exactly what
    # `uniform_full_int(dtype=int64)` does for PhiloxRandom_64_128_128 and
    # ThreeFry_64_64_64.
    return self.uniform_full_int(shape=shape, dtype=dtypes.int64)

  def make_seeds(self, count=1):
    """Generates seeds for stateless random ops.

    For example:

    ```python
    seeds = get_global_generator().make_seeds(count=10)
    for i in range(10):
      seed = seeds[:, i]
      numbers = tf.random.stateless_normal(shape=[2, 3], seed=seed)
      ...
    ```

    Args:
      count: the number of seed pairs (note that stateless random ops need a
        pair of seeds to invoke).

    Returns:
      A tensor of shape [2, count] and dtype int64.
    """
    alg = self.algorithm
    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
      keys = self._make_int64_keys(shape=[count])
      # The two seeds for stateless random ops don't have individual
      # semantics and are scrambled together, so setting one to zero is fine.
      zeros = array_ops.zeros_like(keys)
      return array_ops.stack([keys, zeros])
    else:
      raise ValueError("Unsupported algorithm id: %s" % alg)

  def split(self, count=1):
    """Returns a list of independent `Generator` objects.

    Two generators are independent of each other in the sense that the
    random-number streams they generate don't have statistically detectable
    correlations. The new generators are also independent of the old one.
    The old generator's state will be changed (like other random-number
    generating methods), so two calls of `split` will return different
    new generators.

    For example:

    ```python
    gens = get_global_generator().split(count=10)
    for gen in gens:
      numbers = gen.normal(shape=[2, 3])
      # ...
    gens2 = get_global_generator().split(count=10)
    # gens2 will be different from gens
    ```

    The new generators will be put on the current device (possibly different
    from the old generator's), for example:

    ```python
    with tf.device("/device:CPU:0"):
      gen = Generator.from_seed(1234)  # gen is on CPU
    with tf.device("/device:GPU:0"):
      gens = gen.split(count=10)  # gens are on GPU
    ```

    Args:
      count: the number of generators to return.

    Returns:
      A list (length `count`) of `Generator` objects independent of each
      other. The new generators have the same RNG algorithm as the old one.
    """
    def _key_to_state(alg, key):
      # Padding with zeros on the left. The zeros will be the counter.
      return [0] * (_get_state_size(alg) - 1) + [key]

    alg = self.algorithm
    if alg == RNG_ALG_PHILOX or alg == RNG_ALG_THREEFRY:
      keys = self._make_int64_keys(shape=[count])
      return [Generator(state=_key_to_state(alg, key), alg=alg)
              for key in keys.numpy()]
    else:
      raise ValueError("Unsupported algorithm id: %s" % alg)


# It's not safe to create TF ops before `init_google` is called, so this is
# initialized to None and gets a value the first time `get_global_generator`
# is called.
global_generator = None


@tf_export("random.get_global_generator",
           "random.experimental.get_global_generator")
def get_global_generator():
  """Retrieves the global generator.

  This function will create the global generator the first time it is called,
  and the generator will be placed at the default device at that time, so one
  needs to be careful about where this function is first called. Using a
  generator placed on a less-ideal device will incur a performance regression.
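
  For example (a sketch of the placement concern described above; this
  assumes the enclosing device scope is still in effect when the generator
  is created):

  ```python
  with tf.device("CPU:0"):
    tf.random.get_global_generator()  # if this is the first call, the
                                      # global generator lives on CPU:0
  ```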

  Returns:
    The global `tf.random.Generator` object.
  """
  global global_generator
  if global_generator is None:
    with ops.init_scope():
      global_generator = Generator.from_non_deterministic_state()
  return global_generator


@tf_export("random.set_global_generator",
           "random.experimental.set_global_generator")
def set_global_generator(generator):
  """Replaces the global generator with another `Generator` object.

  Resetting the global generator in this way entails creating a new
  `Generator` object (and the `Variable` object within it), which does not
  work well with `tf.function` because (1) `tf.function` puts restrictions on
  `Variable` creation, so `set_global_generator` can't be freely used inside
  `tf.function`; (2) redirecting a global variable to a new object is
  problematic with `tf.function` because the old object may be captured by a
  `tf.function`-ed function and still be used by it. A `tf.function`-ed
  function only keeps weak references to variables, so deleting a variable
  and then calling that function again may raise an error, as demonstrated by
  random_test.py/RandomTest.testResetGlobalGeneratorBadWithDefun.
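
  For example (a sketch of the capture hazard described above):

  ```python
  @tf.function
  def f():
    return tf.random.get_global_generator().normal([])

  f()  # `f` captures the current global generator's variable
  tf.random.set_global_generator(tf.random.Generator.from_seed(1))
  # If the old generator object has been deleted by now, calling `f` again
  # may raise, because `f` only holds a weak reference to the old variable.
  ```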

  Args:
    generator: the new `Generator` object.
  """
  global global_generator
  global_generator = generator