# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stateless random ops which take seed as a tensor input."""
import enum
import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_random_index_shuffle_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


# Register each of these stateless RNG op types with
# `ops.NotDifferentiable`, in the listed order. The loop variable is deleted
# afterwards so it does not linger as a module attribute.
for _stateless_op_type in (
    "StatelessMultinomial",
    "StatelessRandomBinomial",
    "StatelessRandomNormal",
    "StatelessRandomPoisson",
    "StatelessRandomUniform",
    "StatelessRandomUniformInt",
    "StatelessRandomUniformFullInt",
    "StatelessTruncatedNormal",
):
  ops.NotDifferentiable(_stateless_op_type)
del _stateless_op_type
43ops.NotDifferentiable("StatelessRandomNormalV2") 44ops.NotDifferentiable("StatelessRandomUniformV2") 45ops.NotDifferentiable("StatelessRandomUniformIntV2") 46ops.NotDifferentiable("StatelessRandomUniformFullIntV2") 47ops.NotDifferentiable("StatelessTruncatedNormalV2") 48ops.NotDifferentiable("StatelessRandomShuffle") 49ops.NotDifferentiable("RandomIndexShuffle") 50 51 52@tf_export("random.Algorithm", "random.experimental.Algorithm") 53class Algorithm(enum.Enum): 54 # The numbers here must match framework/rng_alg.h 55 PHILOX = 1 56 THREEFRY = 2 57 AUTO_SELECT = 3 58 59 60def convert_alg_to_int(alg): 61 """Converts algorithm to an integer. 62 63 Args: 64 alg: can be one of these types: integer, Algorithm, Tensor, string. Allowed 65 strings are "philox" and "threefry". 66 67 Returns: 68 An integer, unless the input is a Tensor in which case a Tensor is returned. 69 """ 70 if isinstance(alg, int): 71 return alg 72 if isinstance(alg, Algorithm): 73 return alg.value 74 if isinstance(alg, ops.Tensor): 75 return alg 76 if isinstance(alg, str): 77 if alg == "philox": 78 return Algorithm.PHILOX.value 79 elif alg in ("threefry", "three-fry", "three_fry"): 80 return Algorithm.THREEFRY.value 81 elif alg in ("autoselect", "auto-select", "auto_select"): 82 return Algorithm.AUTO_SELECT.value 83 else: 84 raise ValueError( 85 f"Argument `alg` got unsupported string value {alg}. Supported " 86 f"string values are 'philox' for the Philox algorithm, 'threefry' " 87 f"for the ThreeFry algorithm, and 'auto_select' for auto-selection.") 88 else: 89 raise TypeError( 90 f"Can't convert argument `alg` (of value {alg} and type {type(alg)}) " 91 f"to int.") 92 93 94def _resolve_alg(alg): 95 if alg == Algorithm.AUTO_SELECT.value: 96 return gen_stateless_random_ops_v2.stateless_random_get_alg() 97 return alg 98 99 100def _get_key_counter(seed, alg): 101 """Calculates the key and counter to pass to raw RNG ops. 
102 103 This function calculates the key and counter that will be passed to 104 the raw RNG ops like `StatelessRandomUniformV2`. Depending on the 105 input `alg`, the key and counter may be scrambled or copied from 106 `seed`. If `alg` is `"auto_select"`, the key and counter will be 107 determined at runtime based on device type. 108 109 Args: 110 seed: An integer tensor of shape [2]. The seed to calculate the 111 key and counter from. 112 alg: The RNG algorithm. See `tf.random.stateless_uniform` for an 113 explanation. 114 115 Returns: 116 A pair (key, counter) suitable for V2 stateless RNG ops like 117 `StatelessRandomUniformV2`. 118 """ 119 if alg == Algorithm.AUTO_SELECT.value: 120 key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter( 121 seed) 122 elif alg == Algorithm.PHILOX.value: 123 key, counter = _philox_scramble_seed(seed) 124 elif alg == Algorithm.THREEFRY.value: 125 key = array_ops.reshape( 126 uint32s_to_uint64(math_ops.cast(seed, dtypes.uint32)), [1]) 127 counter = array_ops.zeros([1], dtypes.uint64) 128 else: 129 raise ValueError( 130 f"Argument `alg` got unsupported value {alg}. 
Supported values are " 131 f"{Algorithm.PHILOX.value} for the Philox algorithm, " 132 f"{Algorithm.THREEFRY.value} for the ThreeFry algorithm, and " 133 f"{Algorithm.AUTO_SELECT.value} for auto-selection.") 134 return key, counter 135 136 137def _get_key_counter_alg(seed, alg): 138 if alg is None: 139 alg = Algorithm.AUTO_SELECT.value 140 alg = convert_alg_to_int(alg) 141 key, counter = _get_key_counter(seed, alg) 142 if compat.forward_compatible(2021, 8, 11): 143 return key, counter, alg 144 else: 145 return key, counter, _resolve_alg(alg) 146 147 148def _philox_scramble_seed(seed): 149 # the same scrambling procedure as core/kernels/stateless_random_ops.cc 150 key = constant_op.constant([0x02461e293ec8f720], dtypes.uint64) 151 counter = math_ops.cast(seed, dtypes.uint64) 152 mix = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2( 153 [4], key=key, counter=counter, dtype=dtypes.uint32, 154 alg=Algorithm.PHILOX.value) 155 key = array_ops.reshape(uint32s_to_uint64(mix[:2]), [1]) 156 counter = array_ops.stack([0, uint32s_to_uint64(mix[2:])], axis=0) 157 return key, counter 158 159 160def uint32s_to_uint64(x): 161 return bitwise_ops.bitwise_or( 162 math_ops.cast(x[0], dtypes.uint64), 163 bitwise_ops.left_shift(math_ops.cast(x[1], dtypes.uint64), 164 constant_op.constant(32, dtypes.uint64))) 165 166 167@tf_export("random.experimental.stateless_split") 168@dispatch.add_dispatch_support 169def split(seed, num=2, alg="auto_select"): 170 """Splits an RNG seed into `num` new seeds by adding a leading axis. 
171 172 Example: 173 174 >>> seed = [1, 2] 175 >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3) 176 >>> print(new_seeds) 177 tf.Tensor( 178 [[1105988140 1738052849] 179 [-335576002 370444179] 180 [ 10670227 -246211131]], shape=(3, 2), dtype=int32) 181 >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :]) 182 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.59835213, -0.9578608 , 183 0.9002807 ], dtype=float32)> 184 185 Args: 186 seed: an RNG seed (a tensor with shape [2] and dtype `int32` or 187 `int64`). (When using XLA, only `int32` is allowed.) 188 num: optional, a positive integer or scalar tensor indicating the number of 189 seeds to produce (default 2). 190 alg: The RNG algorithm used to generate the random numbers. See 191 `tf.random.stateless_uniform` for a detailed explanation. 192 193 Returns: 194 A tensor with shape [num, 2] representing `num` new seeds. It will have the 195 same dtype as `seed` (if `seed` doesn't have an explict dtype, the dtype 196 will be determined by `tf.convert_to_tensor`). 197 """ 198 seed = ops.convert_to_tensor(seed) 199 return stateless_random_uniform(shape=[num, 2], seed=seed, dtype=seed.dtype, 200 minval=None, maxval=None, alg=alg) 201 202 203@tf_export("random.experimental.stateless_fold_in") 204@dispatch.add_dispatch_support 205def fold_in(seed, data, alg="auto_select"): 206 """Folds in data to an RNG seed to form a new RNG seed. 207 208 For example, in a distributed-training setting, suppose we have a master seed 209 and a replica ID. We want to fold the replica ID into the master seed to 210 form a "replica seed" to be used by that replica later on, so that different 211 replicas will generate different random numbers but the reproducibility of the 212 whole system can still be controlled by the master seed: 213 214 >>> master_seed = [1, 2] 215 >>> replica_id = 3 216 >>> replica_seed = tf.random.experimental.stateless_fold_in( 217 ... 
master_seed, replica_id) 218 >>> print(replica_seed) 219 tf.Tensor([1105988140 3], shape=(2,), dtype=int32) 220 >>> tf.random.stateless_normal(shape=[3], seed=replica_seed) 221 <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.03197195, 0.8979765 , 222 0.13253039], dtype=float32)> 223 224 Args: 225 seed: an RNG seed (a tensor with shape [2] and dtype `int32` or 226 `int64`). (When using XLA, only `int32` is allowed.) 227 data: an `int32` or `int64` scalar representing data to be folded in to the 228 seed. 229 alg: The RNG algorithm used to generate the random numbers. See 230 `tf.random.stateless_uniform` for a detailed explanation. 231 232 Returns: 233 A new RNG seed that is a deterministic function of the inputs and is 234 statistically safe for producing a stream of new pseudo-random values. It 235 will have the same dtype as `data` (if `data` doesn't have an explict dtype, 236 the dtype will be determined by `tf.convert_to_tensor`). 237 """ 238 data = ops.convert_to_tensor(data) 239 seed1 = stateless_random_uniform(shape=[], seed=seed, dtype=data.dtype, 240 minval=None, maxval=None, alg=alg) 241 return array_ops.stack([seed1, data]) 242 243 244@tf_export("random.experimental.index_shuffle") 245@dispatch.add_dispatch_support 246def index_shuffle(index, seed, max_index): 247 """Outputs the position of `index` in a permutation of [0, ..., max_index]. 248 249 For each possible `seed` and `max_index` there is one pseudorandom permutation 250 of the sequence S=[0, ..., max_index]. Instead of materializing the full array 251 we can compute the new position of any single element in S. This can be useful 252 for very large `max_index`s. 253 254 The input `index` and output can be used as indices to shuffle a vector. 
255 For example: 256 257 >>> vector = tf.constant(['e0', 'e1', 'e2', 'e3']) 258 >>> indices = tf.random.experimental.index_shuffle(tf.range(4), [5, 9], 3) 259 >>> shuffled_vector = tf.gather(vector, indices) 260 >>> print(shuffled_vector) 261 tf.Tensor([b'e2' b'e0' b'e1' b'e3'], shape=(4,), dtype=string) 262 263 More usefully, it can be used in a streaming (aka online) scenario such as 264 `tf.data`, where each element of `vector` is processed individually and the 265 whole `vector` is never materialized in memory. 266 267 >>> dataset = tf.data.Dataset.range(10) 268 >>> dataset = dataset.map( 269 ... lambda idx: tf.random.experimental.index_shuffle(idx, [5, 8], 9)) 270 >>> print(list(dataset.as_numpy_iterator())) 271 [3, 8, 0, 1, 2, 7, 6, 9, 4, 5] 272 273 This operation is stateless (like other `tf.random.stateless_*` functions), 274 meaning the output is fully determined by the `seed` (other inputs being 275 equal). 276 Each `seed` choice corresponds to one permutation, so when calling this 277 function 278 multiple times for the same shuffling, please make sure to use the same 279 `seed`. For example: 280 281 >>> seed = [5, 9] 282 >>> idx0 = tf.random.experimental.index_shuffle(0, seed, 3) 283 >>> idx1 = tf.random.experimental.index_shuffle(1, seed, 3) 284 >>> idx2 = tf.random.experimental.index_shuffle(2, seed, 3) 285 >>> idx3 = tf.random.experimental.index_shuffle(3, seed, 3) 286 >>> shuffled_vector = tf.gather(vector, [idx0, idx1, idx2, idx3]) 287 >>> print(shuffled_vector) 288 tf.Tensor([b'e2' b'e0' b'e1' b'e3'], shape=(4,), dtype=string) 289 290 Args: 291 index: An integer scalar tensor or vector with values in [0, `max_index`]. 292 It can be seen as either a value `v` in the sequence `S`=[0, ..., 293 `max_index`] to be permutated, or as an index of an element `e` in a 294 shuffled vector. 295 seed: A tensor of shape [2] or [n, 2] with dtype int32/uint32/int64/uint64. 296 The RNG seed. If the rank is unknown during graph building it must be 1 at 297 runtime. 
298 max_index: A non-negative tensor with the same shape and dtype as `index`. 299 The upper bound (inclusive). 300 301 Returns: 302 If all inputs were scalar (shape [2] for `seed`) the output will be a scalar 303 with the same dtype as `index`. The output can be seen as the new position 304 of `v` in `S`, or as the index of `e` in the vector before shuffling. 305 If one or multiple inputs were vectors (shape [n, 2] for `seed`) then the 306 output will be a vector of the same size which each element shuffled 307 independently. Scalar values are broadcasted in this case. 308 """ 309 # We expect users to pass a seed with shape [2] to be consistent with other 310 # stateless_* ops, but the raw op expects shape [3]. 311 seed = ops.convert_to_tensor(seed) 312 # Pad the first dimension with an arbitrary number since our raw op expects 313 # shape [3]. 314 if seed.shape.rank is None: 315 paddings = [[1, 0]] 316 else: 317 paddings = [[1, 0]] + (seed.shape.rank - 1) * [[0, 0]] 318 seed = array_ops.pad(seed, paddings, constant_values=498247692) 319 return gen_random_index_shuffle_ops.random_index_shuffle( 320 index, seed=seed, max_index=max_index) 321 322 323@tf_export("random.experimental.stateless_shuffle") 324@dispatch.add_dispatch_support 325def stateless_shuffle(value, seed, alg="auto_select", name=None): 326 """Randomly and deterministically shuffles a tensor along its first dimension. 327 328 The tensor is shuffled along dimension 0, such that each `value[j]` is mapped 329 to one and only one `output[i]`. 
For example, a mapping that might occur for a 330 3x2 tensor is: 331 332 ```python 333 [[1, 2], [[5, 6], 334 [3, 4], ==> [1, 2], 335 [5, 6]] [3, 4]] 336 ``` 337 338 >>> v = tf.constant([[1, 2], [3, 4], [5, 6]]) 339 >>> shuffled = tf.random.experimental.stateless_shuffle(v, seed=[8, 9]) 340 >>> print(shuffled) 341 tf.Tensor( 342 [[5 6] 343 [1 2] 344 [3 4]], shape=(3, 2), dtype=int32) 345 346 This is a stateless version of `tf.random.shuffle`: if run twice with the 347 same `value` and `seed`, it will produce the same result. The 348 output is consistent across multiple runs on the same hardware (and between 349 CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU 350 hardware. 351 352 Args: 353 value: A Tensor to be shuffled. 354 seed: A shape [2] Tensor. The seed to the random number generator. Must have 355 dtype `int32` or `int64`. 356 alg: The RNG algorithm used to generate the random numbers. See 357 `tf.random.stateless_uniform` for a detailed explanation. 358 name: A name for the operation. 359 360 Returns: 361 A tensor of same shape and type as `value`, shuffled along its first 362 dimension. 363 """ 364 with ops.name_scope(name, "stateless_shuffle", [value, seed]) as name: 365 key, counter, alg = _get_key_counter_alg(seed, alg) 366 return gen_stateless_random_ops_v2.stateless_shuffle( 367 value, key=key, counter=counter, alg=alg) 368 369 370@tf_export("random.stateless_uniform") 371@dispatch.add_dispatch_support 372def stateless_random_uniform(shape, 373 seed, 374 minval=0, 375 maxval=None, 376 dtype=dtypes.float32, 377 name=None, 378 alg="auto_select"): 379 """Outputs deterministic pseudorandom values from a uniform distribution. 380 381 This is a stateless version of `tf.random.uniform`: if run twice with the 382 same seeds and shapes, it will produce the same pseudorandom numbers. 
The 383 output is consistent across multiple runs on the same hardware (and between 384 CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU 385 hardware. 386 387 The generated values follow a uniform distribution in the range 388 `[minval, maxval)`. The lower bound `minval` is included in the range, while 389 the upper bound `maxval` is excluded. 390 391 For floats, the default range is `[0, 1)`. For ints, at least `maxval` must 392 be specified explicitly. 393 394 In the integer case, the random integers are slightly biased unless 395 `maxval - minval` is an exact power of two. The bias is small for values of 396 `maxval - minval` significantly smaller than the range of the output (either 397 `2**32` or `2**64`). 398 399 For full-range (i.e. inclusive of both max and min) random integers, pass 400 `minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype 401 either both `minval` and `maxval` must be `None` or neither may be `None`. For 402 example: 403 ```python 404 ints = tf.random.stateless_uniform( 405 [10], seed=(2, 3), minval=None, maxval=None, dtype=tf.int32) 406 ``` 407 408 Args: 409 shape: A 1-D integer Tensor or Python array. The shape of the output tensor. 410 seed: A shape [2] Tensor, the seed to the random number generator. Must have 411 dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) 412 minval: A Tensor or Python value of type `dtype`, broadcastable with 413 `shape` (for integer types, broadcasting is not supported, so it needs to 414 be a scalar). The lower bound on the range of random values to 415 generate. Pass `None` for full-range integers. Defaults to 0. 416 maxval: A Tensor or Python value of type `dtype`, broadcastable with 417 `shape` (for integer types, broadcasting is not supported, so it needs to 418 be a scalar). The upper bound on the range of random values to generate. 419 Defaults to 1 if `dtype` is floating point. Pass `None` for full-range 420 integers. 
421 dtype: The type of the output: `float16`, `bfloat16`, `float32`, `float64`, 422 `int32`, or `int64`. For unbounded uniform ints (`minval`, `maxval` both 423 `None`), `uint32` and `uint64` may be used. Defaults to `float32`. 424 name: A name for the operation (optional). 425 alg: The RNG algorithm used to generate the random numbers. Valid 426 choices are `"philox"` for [the Philox 427 algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf), 428 `"threefry"` for [the ThreeFry 429 algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf), 430 and `"auto_select"` (default) for the system to automatically 431 select an algorithm based the device type. Values of 432 `tf.random.Algorithm` can also be used. Note that with 433 `"auto_select"`, the outputs of this function may change when 434 it is running on a different device. 435 436 Returns: 437 A tensor of the specified shape filled with random uniform values. 438 439 Raises: 440 ValueError: If `dtype` is integral and only one of `minval` or `maxval` is 441 specified. 442 """ 443 dtype = dtypes.as_dtype(dtype) 444 accepted_dtypes = (dtypes.float16, dtypes.bfloat16, dtypes.float32, 445 dtypes.float64, dtypes.int32, dtypes.int64, dtypes.uint32, 446 dtypes.uint64) 447 if dtype not in accepted_dtypes: 448 raise ValueError( 449 f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are " 450 f"{accepted_dtypes}.") 451 if dtype.is_integer: 452 if (minval is None) != (maxval is None): 453 raise ValueError( 454 f"For integer `dtype` argument {dtype}, argument `minval` and " 455 f"`maxval` must be both None or not None. Got `minval`={minval} and " 456 f"`maxval`={maxval}.") 457 if minval is not None and dtype in (dtypes.uint32, dtypes.uint64): 458 raise ValueError( 459 f"Argument `dtype` got invalid value {dtype} when argument `minval` " 460 f"is not None. 
Please don't use unsigned integers in this case.") 461 elif maxval is None: 462 maxval = 1 463 with ops.name_scope(name, "stateless_random_uniform", 464 [shape, seed, minval, maxval]) as name: 465 shape = tensor_util.shape_tensor(shape) 466 if dtype.is_integer and minval is None: 467 key, counter, alg = _get_key_counter_alg(seed, alg) 468 result = ( 469 gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2( 470 shape, key=key, counter=counter, dtype=dtype, alg=alg, name=name)) 471 else: 472 minval = ops.convert_to_tensor(minval, dtype=dtype, name="min") 473 maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max") 474 if dtype.is_integer: 475 key, counter, alg = _get_key_counter_alg(seed, alg) 476 result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2( 477 shape, 478 key=key, 479 counter=counter, 480 minval=minval, 481 maxval=maxval, 482 alg=alg, 483 name=name) 484 else: 485 key, counter, alg = _get_key_counter_alg(seed, alg) 486 rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2( 487 shape, key=key, counter=counter, dtype=dtype, alg=alg) 488 result = math_ops.add(rnd * (maxval - minval), minval, name=name) 489 tensor_util.maybe_set_static_shape(result, shape) 490 return result 491 492 493@tf_export("random.stateless_binomial") 494@dispatch.add_dispatch_support 495def stateless_random_binomial(shape, 496 seed, 497 counts, 498 probs, 499 output_dtype=dtypes.int32, 500 name=None): 501 """Outputs deterministic pseudorandom values from a binomial distribution. 502 503 The generated values follow a binomial distribution with specified count and 504 probability of success parameters. 505 506 This is a stateless version of `tf.random.Generator.binomial`: if run twice 507 with the same seeds and shapes, it will produce the same pseudorandom numbers. 508 The output is consistent across multiple runs on the same hardware (and 509 between CPU and GPU), but may change between versions of TensorFlow or on 510 non-CPU/GPU hardware. 
511 512 Example: 513 514 ```python 515 counts = [10., 20.] 516 # Probability of success. 517 probs = [0.8] 518 519 binomial_samples = tf.random.stateless_binomial( 520 shape=[2], seed=[123, 456], counts=counts, probs=probs) 521 522 counts = ... # Shape [3, 1, 2] 523 probs = ... # Shape [1, 4, 2] 524 shape = [3, 4, 3, 4, 2] 525 # Sample shape will be [3, 4, 3, 4, 2] 526 binomial_samples = tf.random.stateless_binomial( 527 shape=shape, seed=[123, 456], counts=counts, probs=probs) 528 ``` 529 530 Args: 531 shape: A 1-D integer Tensor or Python array. The shape of the output tensor. 532 seed: A shape [2] Tensor, the seed to the random number generator. Must have 533 dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) 534 counts: Tensor. The counts of the binomial distribution. Must be 535 broadcastable with `probs`, and broadcastable with the rightmost 536 dimensions of `shape`. 537 probs: Tensor. The probability of success for the binomial distribution. 538 Must be broadcastable with `counts` and broadcastable with the rightmost 539 dimensions of `shape`. 540 output_dtype: The type of the output. Default: tf.int32 541 name: A name for the operation (optional). 542 543 Returns: 544 samples: A Tensor of the specified shape filled with random binomial 545 values. For each i, each samples[..., i] is an independent draw from 546 the binomial distribution on counts[i] trials with probability of 547 success probs[i]. 
548 549 """ 550 with ops.name_scope(name, "stateless_random_binomial", 551 [shape, seed, counts, probs]) as name: 552 shape = tensor_util.shape_tensor(shape) 553 probs = ops.convert_to_tensor( 554 probs, dtype_hint=dtypes.float32, name="probs") 555 counts = ops.convert_to_tensor( 556 counts, dtype_hint=probs.dtype, name="counts") 557 result = gen_stateless_random_ops.stateless_random_binomial( 558 shape=shape, seed=seed, counts=counts, probs=probs, dtype=output_dtype) 559 tensor_util.maybe_set_static_shape(result, shape) 560 return result 561 562 563@tf_export("random.stateless_gamma") 564@dispatch.add_dispatch_support 565def stateless_random_gamma(shape, 566 seed, 567 alpha, 568 beta=None, 569 dtype=dtypes.float32, 570 name=None): 571 """Outputs deterministic pseudorandom values from a gamma distribution. 572 573 The generated values follow a gamma distribution with specified concentration 574 (`alpha`) and inverse scale (`beta`) parameters. 575 576 This is a stateless version of `tf.random.gamma`: if run twice with the same 577 seeds and shapes, it will produce the same pseudorandom numbers. The output is 578 consistent across multiple runs on the same hardware (and between CPU and 579 GPU), 580 but may change between versions of TensorFlow or on non-CPU/GPU hardware. 581 582 A slight difference exists in the interpretation of the `shape` parameter 583 between `stateless_gamma` and `gamma`: in `gamma`, the `shape` is always 584 prepended to the shape of the broadcast of `alpha` with `beta`; whereas in 585 `stateless_gamma` the `shape` parameter must always encompass the shapes of 586 each of `alpha` and `beta` (which must broadcast together to match the 587 trailing dimensions of `shape`). 588 589 Note: Because internal calculations are done using `float64` and casting has 590 `floor` semantics, we must manually map zero outcomes to the smallest 591 possible positive floating-point value, i.e., `np.finfo(dtype).tiny`. 
This 592 means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise 593 should. This bias can only happen for small values of `alpha`, i.e., 594 `alpha << 1` or large values of `beta`, i.e., `beta >> 1`. 595 596 The samples are differentiable w.r.t. alpha and beta. 597 The derivatives are computed using the approach described in 598 (Figurnov et al., 2018). 599 600 Example: 601 602 ```python 603 samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5]) 604 # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents 605 # the samples drawn from each distribution 606 607 samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5]) 608 # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] 609 # represents the 7x5 samples drawn from each of the two distributions 610 611 alpha = tf.constant([[1.], [3.], [5.]]) 612 beta = tf.constant([[3., 4.]]) 613 samples = tf.random.stateless_gamma( 614 [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta) 615 # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions. 616 617 with tf.GradientTape() as tape: 618 tape.watch([alpha, beta]) 619 loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma( 620 [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta))) 621 dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta]) 622 # unbiased stochastic derivatives of the loss function 623 alpha.shape == dloss_dalpha.shape # True 624 beta.shape == dloss_dbeta.shape # True 625 ``` 626 627 Args: 628 shape: A 1-D integer Tensor or Python array. The shape of the output tensor. 629 seed: A shape [2] Tensor, the seed to the random number generator. Must have 630 dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) 631 alpha: Tensor. The concentration parameter of the gamma distribution. Must 632 be broadcastable with `beta`, and broadcastable with the rightmost 633 dimensions of `shape`. 634 beta: Tensor. 
The inverse scale parameter of the gamma distribution. Must be 635 broadcastable with `alpha` and broadcastable with the rightmost dimensions 636 of `shape`. 637 dtype: Floating point dtype of `alpha`, `beta`, and the output. 638 name: A name for the operation (optional). 639 640 Returns: 641 samples: A Tensor of the specified shape filled with random gamma values. 642 For each i, each `samples[..., i] is an independent draw from the gamma 643 distribution with concentration alpha[i] and scale beta[i]. 644 645 """ 646 with ops.name_scope(name, "stateless_random_gamma", 647 [shape, seed, alpha, beta]) as name: 648 shape = tensor_util.shape_tensor(shape) 649 alpha = ops.convert_to_tensor(alpha, dtype=dtype, name="alpha") 650 beta = ops.convert_to_tensor( 651 beta if beta is not None else 1, name="beta", dtype=dtype) 652 broadcast_shape = array_ops.broadcast_dynamic_shape( 653 array_ops.shape(alpha), array_ops.shape(beta)) 654 alpha_broadcast = array_ops.broadcast_to(alpha, broadcast_shape) 655 result = math_ops.maximum( 656 np.finfo(alpha.dtype.as_numpy_dtype).tiny, 657 gen_stateless_random_ops.stateless_random_gamma_v2( 658 shape, seed=seed, alpha=alpha_broadcast) / beta) 659 tensor_util.maybe_set_static_shape(result, shape) 660 return result 661 662 663@tf_export("random.stateless_poisson") 664@dispatch.add_dispatch_support 665def stateless_random_poisson(shape, 666 seed, 667 lam, 668 dtype=dtypes.int32, 669 name=None): 670 """Outputs deterministic pseudorandom values from a Poisson distribution. 671 672 The generated values follow a Poisson distribution with specified rate 673 parameter. 674 675 This is a stateless version of `tf.random.poisson`: if run twice with the same 676 seeds and shapes, it will produce the same pseudorandom numbers. The output is 677 consistent across multiple runs on the same hardware, but may change between 678 versions of TensorFlow or on non-CPU/GPU hardware. 
679 680 A slight difference exists in the interpretation of the `shape` parameter 681 between `stateless_poisson` and `poisson`: in `poisson`, the `shape` is always 682 prepended to the shape of `lam`; whereas in `stateless_poisson` the shape of 683 `lam` must match the trailing dimensions of `shape`. 684 685 Example: 686 687 ```python 688 samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15]) 689 # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents 690 # the samples drawn from each distribution 691 692 samples = tf.random.stateless_poisson([7, 5, 2], seed=[12, 34], lam=[5, 15]) 693 # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1] 694 # represents the 7x5 samples drawn from each of the two distributions 695 696 rate = tf.constant([[1.], [3.], [5.]]) 697 samples = tf.random.stateless_poisson([30, 3, 1], seed=[12, 34], lam=rate) 698 # samples has shape [30, 3, 1], with 30 samples each of 3x1 distributions. 699 ``` 700 701 Args: 702 shape: A 1-D integer Tensor or Python array. The shape of the output tensor. 703 seed: A shape [2] Tensor, the seed to the random number generator. Must have 704 dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) 705 lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Shape 706 must match the rightmost dimensions of `shape`. 707 dtype: Dtype of the samples (int or float dtypes are permissible, as samples 708 are discrete). Default: int32. 709 name: A name for the operation (optional). 710 711 Returns: 712 samples: A Tensor of the specified shape filled with random Poisson values. 713 For each i, each `samples[..., i]` is an independent draw from the Poisson 714 distribution with rate `lam[i]`. 
715 716 """ 717 with ops.name_scope(name, "stateless_random_poisson", 718 [shape, seed, lam]) as name: 719 shape = tensor_util.shape_tensor(shape) 720 result = gen_stateless_random_ops.stateless_random_poisson( 721 shape, seed=seed, lam=lam, dtype=dtype) 722 tensor_util.maybe_set_static_shape(result, shape) 723 return result 724 725 726@tf_export("random.stateless_normal") 727@dispatch.add_dispatch_support 728def stateless_random_normal(shape, 729 seed, 730 mean=0.0, 731 stddev=1.0, 732 dtype=dtypes.float32, 733 name=None, 734 alg="auto_select"): 735 """Outputs deterministic pseudorandom values from a normal distribution. 736 737 This is a stateless version of `tf.random.normal`: if run twice with the 738 same seeds and shapes, it will produce the same pseudorandom numbers. The 739 output is consistent across multiple runs on the same hardware (and between 740 CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU 741 hardware. 742 743 Args: 744 shape: A 1-D integer Tensor or Python array. The shape of the output tensor. 745 seed: A shape [2] Tensor, the seed to the random number generator. Must have 746 dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) 747 mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal 748 distribution. 749 stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation 750 of the normal distribution. 751 dtype: The float type of the output: `float16`, `bfloat16`, `float32`, 752 `float64`. Defaults to `float32`. 753 name: A name for the operation (optional). 754 alg: The RNG algorithm used to generate the random numbers. See 755 `tf.random.stateless_uniform` for a detailed explanation. 756 757 Returns: 758 A tensor of the specified shape filled with random normal values. 
759 """ 760 with ops.name_scope(name, "stateless_random_normal", 761 [shape, seed, mean, stddev]) as name: 762 shape = tensor_util.shape_tensor(shape) 763 mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") 764 stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") 765 key, counter, alg = _get_key_counter_alg(seed, alg) 766 rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2( 767 shape, key=key, counter=counter, dtype=dtype, alg=alg) 768 result = math_ops.add(rnd * stddev, mean, name=name) 769 tensor_util.maybe_set_static_shape(result, shape) 770 return result 771 772 773@tf_export("random.stateless_truncated_normal") 774@dispatch.add_dispatch_support 775def stateless_truncated_normal(shape, 776 seed, 777 mean=0.0, 778 stddev=1.0, 779 dtype=dtypes.float32, 780 name=None, 781 alg="auto_select"): 782 """Outputs deterministic pseudorandom values, truncated normally distributed. 783 784 This is a stateless version of `tf.random.truncated_normal`: if run twice with 785 the same seeds and shapes, it will produce the same pseudorandom numbers. The 786 output is consistent across multiple runs on the same hardware (and between 787 CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU 788 hardware. 789 790 The generated values follow a normal distribution with specified mean and 791 standard deviation, except that values whose magnitude is more than 2 standard 792 deviations from the mean are dropped and re-picked. 793 794 Args: 795 shape: A 1-D integer Tensor or Python array. The shape of the output tensor. 796 seed: A shape [2] Tensor, the seed to the random number generator. Must have 797 dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) 798 mean: A 0-D Tensor or Python value of type `dtype`. The mean of the 799 truncated normal distribution. 800 stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation 801 of the normal distribution, before truncation. 
802 dtype: The type of the output. 803 name: A name for the operation (optional). 804 alg: The RNG algorithm used to generate the random numbers. See 805 `tf.random.stateless_uniform` for a detailed explanation. 806 807 Returns: 808 A tensor of the specified shape filled with random truncated normal values. 809 """ 810 with ops.name_scope(name, "stateless_truncated_normal", 811 [shape, seed, mean, stddev]) as name: 812 shape = tensor_util.shape_tensor(shape) 813 mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean") 814 stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev") 815 key, counter, alg = _get_key_counter_alg(seed, alg) 816 rnd = gen_stateless_random_ops_v2.stateless_truncated_normal_v2( 817 shape, key=key, counter=counter, dtype=dtype, alg=alg) 818 result = math_ops.add(rnd * stddev, mean, name=name) 819 tensor_util.maybe_set_static_shape(result, shape) 820 return result 821 822 823@tf_export(v1=["random.stateless_multinomial"]) 824@dispatch.add_dispatch_support 825@deprecation.deprecated( 826 date=None, instructions="Use `tf.random.stateless_categorical` instead.") 827def stateless_multinomial(logits, 828 num_samples, 829 seed, 830 output_dtype=dtypes.int64, 831 name=None): 832 """Draws deterministic pseudorandom samples from a multinomial distribution. 833 834 This is a stateless version of `tf.random.categorical`: if run twice with the 835 same seeds and shapes, it will produce the same pseudorandom numbers. The 836 output is consistent across multiple runs on the same hardware (and between 837 CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU 838 hardware. 839 840 Example: 841 842 ```python 843 # samples has shape [1, 5], where each value is either 0 or 1 with equal 844 # probability. 845 samples = tf.random.stateless_categorical( 846 tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17]) 847 ``` 848 849 Args: 850 logits: 2-D Tensor with shape `[batch_size, num_classes]`. 
Each slice 851 `[i, :]` represents the unnormalized log-probabilities for all classes. 852 num_samples: 0-D. Number of independent samples to draw for each row slice. 853 seed: A shape [2] Tensor, the seed to the random number generator. Must have 854 dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) 855 output_dtype: The integer type of the output: `int32` or `int64`. Defaults 856 to `int64`. 857 name: Optional name for the operation. 858 859 Returns: 860 The drawn samples of shape `[batch_size, num_samples]`. 861 """ 862 with ops.name_scope(name, "stateless_multinomial", [logits, seed]): 863 return stateless_multinomial_categorical_impl(logits, num_samples, 864 output_dtype, seed) 865 866 867@tf_export("random.stateless_categorical") 868@dispatch.add_dispatch_support 869def stateless_categorical(logits, 870 num_samples, 871 seed, 872 dtype=dtypes.int64, 873 name=None): 874 """Draws deterministic pseudorandom samples from a categorical distribution. 875 876 This is a stateless version of `tf.categorical`: if run twice with the 877 same seeds and shapes, it will produce the same pseudorandom numbers. The 878 output is consistent across multiple runs on the same hardware (and between 879 CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU 880 hardware. 881 882 883 Example: 884 885 ```python 886 # samples has shape [1, 5], where each value is either 0 or 1 with equal 887 # probability. 888 samples = tf.random.stateless_categorical( 889 tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17]) 890 ``` 891 892 Args: 893 logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice 894 `[i, :]` represents the unnormalized log-probabilities for all classes. 895 num_samples: 0-D. Number of independent samples to draw for each row slice. 896 seed: A shape [2] Tensor, the seed to the random number generator. Must have 897 dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) 
898 dtype: The integer type of the output: `int32` or `int64`. Defaults to 899 `int64`. 900 name: Optional name for the operation. 901 902 Returns: 903 The drawn samples of shape `[batch_size, num_samples]`. 904 """ 905 with ops.name_scope(name, "stateless_categorical", [logits, seed]): 906 return stateless_multinomial_categorical_impl(logits, num_samples, dtype, 907 seed) 908 909 910def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed): 911 """Implementation for stateless multinomial/categorical ops (v1/v2).""" 912 logits = ops.convert_to_tensor(logits, name="logits") 913 dtype = dtypes.as_dtype(dtype) if dtype else dtypes.int64 914 accepted_dtypes = (dtypes.int32, dtypes.int64) 915 if dtype not in accepted_dtypes: 916 raise ValueError( 917 f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are " 918 f"{accepted_dtypes}.") 919 return gen_stateless_random_ops.stateless_multinomial( 920 logits, num_samples, seed, output_dtype=dtype) 921 922 923@dispatch.add_dispatch_support 924@tf_export("random.stateless_parameterized_truncated_normal") 925def stateless_parameterized_truncated_normal(shape, 926 seed, 927 means=0.0, 928 stddevs=1.0, 929 minvals=-2.0, 930 maxvals=2.0, 931 name=None): 932 """Outputs random values from a truncated normal distribution. 933 934 The generated values follow a normal distribution with specified mean and 935 standard deviation, except that values whose magnitude is more than 2 standard 936 deviations from the mean are dropped and re-picked. 937 938 939 Examples: 940 941 Sample from a Truncated normal, with deferring shape parameters that 942 broadcast. 943 944 >>> means = 0. 945 >>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3])) 946 >>> minvals = [-1., -2., -1000.] 947 >>> maxvals = [[10000.], [1.]] 948 >>> y = tf.random.stateless_parameterized_truncated_normal( 949 ... shape=[10, 2, 3], seed=[7, 17], 950 ... 
means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals) 951 >>> y.shape 952 TensorShape([10, 2, 3]) 953 954 Args: 955 shape: A 1-D integer `Tensor` or Python array. The shape of the output 956 tensor. 957 seed: A shape [2] Tensor, the seed to the random number generator. Must have 958 dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.) 959 means: A `Tensor` or Python value of type `dtype`. The mean of the truncated 960 normal distribution. This must broadcast with `stddevs`, `minvals` and 961 `maxvals`, and the broadcasted shape must be dominated by `shape`. 962 stddevs: A `Tensor` or Python value of type `dtype`. The standard deviation 963 of the truncated normal distribution. This must broadcast with `means`, 964 `minvals` and `maxvals`, and the broadcasted shape must be dominated by 965 `shape`. 966 minvals: A `Tensor` or Python value of type `dtype`. The minimum value of 967 the truncated normal distribution. This must broadcast with `means`, 968 `stddevs` and `maxvals`, and the broadcasted shape must be dominated by 969 `shape`. 970 maxvals: A `Tensor` or Python value of type `dtype`. The maximum value of 971 the truncated normal distribution. This must broadcast with `means`, 972 `stddevs` and `minvals`, and the broadcasted shape must be dominated by 973 `shape`. 974 name: A name for the operation (optional). 975 976 Returns: 977 A tensor of the specified shape filled with random truncated normal values. 
978 """ 979 with ops.name_scope(name, "stateless_parameterized_truncated_normal", 980 [shape, means, stddevs, minvals, maxvals]) as name: 981 shape_tensor = tensor_util.shape_tensor(shape) 982 means_tensor = ops.convert_to_tensor(means, name="means") 983 stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs") 984 minvals_tensor = ops.convert_to_tensor(minvals, name="minvals") 985 maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals") 986 rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal( 987 shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor, 988 maxvals_tensor) 989 tensor_util.maybe_set_static_shape(rnd, shape) 990 return rnd 991