# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stateless random ops which take seed as a tensor input."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import enum
import numpy as np
import six

from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


# The stateless RNG ops listed here have no gradient. Registering them with
# `ops.NotDifferentiable` tells gradient computation not to look for a
# gradient function for these op types.
ops.NotDifferentiable("StatelessMultinomial")
ops.NotDifferentiable("StatelessRandomBinomial")
ops.NotDifferentiable("StatelessRandomNormal")
ops.NotDifferentiable("StatelessRandomPoisson")
ops.NotDifferentiable("StatelessRandomUniform")
ops.NotDifferentiable("StatelessRandomUniformInt")
ops.NotDifferentiable("StatelessRandomUniformFullInt")
ops.NotDifferentiable("StatelessTruncatedNormal")


# The V2 stateless ops, which take an explicit key/counter pair instead of a
# seed, have no gradient either.
ops.NotDifferentiable("StatelessRandomNormalV2")
ops.NotDifferentiable("StatelessRandomUniformV2")
ops.NotDifferentiable("StatelessRandomUniformIntV2")
ops.NotDifferentiable("StatelessRandomUniformFullIntV2")
ops.NotDifferentiable("StatelessTruncatedNormalV2")


@tf_export("random.Algorithm", "random.experimental.Algorithm")
class Algorithm(enum.Enum):
  """RNG algorithms understood by the stateless random ops.

  `AUTO_SELECT` lets the system choose an algorithm based on the device type.
  """
  # The numbers here must match framework/rng_alg.h
  PHILOX = 1
  THREEFRY = 2
  AUTO_SELECT = 3


def convert_alg_to_int(alg):
  """Converts algorithm to an integer.

  Args:
    alg: can be one of these types: integer, Algorithm, Tensor, string. Allowed
      strings are "philox", "threefry" (also "three-fry" or "three_fry") and
      "auto_select" (also "autoselect" or "auto-select").

  Returns:
    An integer, unless the input is a Tensor in which case a Tensor is returned.
    Integers and Tensors are passed through unchanged, without validation.

  Raises:
    ValueError: If `alg` is a string that does not name a supported algorithm.
    TypeError: If `alg` is of any other unsupported type.
  """
  if isinstance(alg, six.integer_types):
    return alg
  if isinstance(alg, Algorithm):
    return alg.value
  if isinstance(alg, ops.Tensor):
    return alg
  if isinstance(alg, str):
    if alg == "philox":
      return Algorithm.PHILOX.value
    elif alg in ("threefry", "three-fry", "three_fry"):
      return Algorithm.THREEFRY.value
    elif alg in ("autoselect", "auto-select", "auto_select"):
      return Algorithm.AUTO_SELECT.value
    else:
      raise ValueError(
          f"Argument `alg` got unsupported string value {alg}. Supported "
          f"string values are 'philox' for the Philox algorithm, 'threefry' "
          f"for the ThreeFry algorithm, and 'auto_select' for auto-selection.")
  else:
    raise TypeError(
        f"Can't convert argument `alg` (of value {alg} and type {type(alg)}) "
        f"to int.")


def _resolve_alg(alg):
  """Replaces the auto-select algorithm id with a concrete algorithm.

  Args:
    alg: An integer algorithm id (see `Algorithm`).

  Returns:
    The result of the `stateless_random_get_alg` op if `alg` equals
    `Algorithm.AUTO_SELECT.value`; otherwise `alg` unchanged.
  """
  if alg == Algorithm.AUTO_SELECT.value:
    return gen_stateless_random_ops_v2.stateless_random_get_alg()
  return alg


def _get_key_counter(seed, alg):
  """Calculates the key and counter to pass to raw RNG ops.

  This function calculates the key and counter that will be passed to
  the raw RNG ops like `StatelessRandomUniformV2`. Depending on the
  input `alg`, the key and counter may be scrambled or copied from
  `seed`. If `alg` is `Algorithm.AUTO_SELECT.value`, the key and counter
  will be determined at runtime based on device type.

  Args:
    seed: An integer tensor of shape [2]. The seed to calculate the
      key and counter from.
    alg: The integer id of the RNG algorithm (see `Algorithm` and
      `tf.random.stateless_uniform` for an explanation).

  Returns:
    A pair (key, counter) suitable for V2 stateless RNG ops like
    `StatelessRandomUniformV2`.

  Raises:
    ValueError: If `alg` is not one of the `Algorithm` values.
  """
  if alg == Algorithm.AUTO_SELECT.value:
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed)
  elif alg == Algorithm.PHILOX.value:
    key, counter = _philox_scramble_seed(seed)
  elif alg == Algorithm.THREEFRY.value:
    # ThreeFry uses the seed bits directly as a single 64-bit key and starts
    # from a zero counter.
    key = array_ops.reshape(
        uint32s_to_uint64(math_ops.cast(seed, dtypes.uint32)), [1])
    counter = array_ops.zeros([1], dtypes.uint64)
  else:
    raise ValueError(
        f"Argument `alg` got unsupported value {alg}. Supported values are "
        f"{Algorithm.PHILOX.value} for the Philox algorithm, "
        f"{Algorithm.THREEFRY.value} for the ThreeFry algorithm, and "
        f"{Algorithm.AUTO_SELECT.value} for auto-selection.")
  return key, counter


def _get_key_counter_alg(seed, alg):
  """Computes the (key, counter, alg) triple consumed by the V2 RNG ops.

  Args:
    seed: An integer tensor of shape [2].
    alg: Anything accepted by `convert_alg_to_int`; `None` means auto-select.

  Returns:
    A tuple `(key, counter, alg)`. Once the forward-compatibility window
    (2021-08-11) has passed, `alg` is returned as an integer as-is; before
    that, an auto-select `alg` is replaced by the op-resolved algorithm
    (presumably so older runtimes never see the AUTO_SELECT id -- confirm).
  """
  if alg is None:
    alg = Algorithm.AUTO_SELECT.value
  alg = convert_alg_to_int(alg)
  key, counter = _get_key_counter(seed, alg)
  if compat.forward_compatible(2021, 8, 11):
    return key, counter, alg
  else:
    return key, counter, _resolve_alg(alg)


def _philox_scramble_seed(seed):
  """Derives a Philox key and counter from `seed` by scrambling it.

  Four uint32 values are drawn with Philox under a fixed key, using `seed`
  as the counter. The first two form the 64-bit key; the last two form the
  second element of a 2-element uint64 counter whose first element is 0.

  Args:
    seed: An integer tensor of shape [2].

  Returns:
    A pair (key, counter): key has shape [1], counter has shape [2].
  """
  # the same scrambling procedure as core/kernels/stateless_random_ops.cc
  key = constant_op.constant([0x02461e293ec8f720], dtypes.uint64)
  counter = math_ops.cast(seed, dtypes.uint64)
  mix = gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
      [4], key=key, counter=counter, dtype=dtypes.uint32,
      alg=Algorithm.PHILOX.value)
  key = array_ops.reshape(uint32s_to_uint64(mix[:2]), [1])
  counter = array_ops.stack([0, uint32s_to_uint64(mix[2:])], axis=0)
  return key, counter


def uint32s_to_uint64(x):
  """Packs two 32-bit values into one uint64.

  Args:
    x: A tensor whose first two elements are the low (`x[0]`) and high
      (`x[1]`) 32-bit halves.

  Returns:
    A scalar uint64 tensor equal to `x[0] | (x[1] << 32)`.
  """
  return bitwise_ops.bitwise_or(
      math_ops.cast(x[0], dtypes.uint64),
      bitwise_ops.left_shift(math_ops.cast(x[1], dtypes.uint64),
                             constant_op.constant(32, dtypes.uint64)))


@tf_export("random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2, alg="auto_select"):
  """Splits an RNG seed into `num` new seeds by adding a leading axis.

  Example:

  >>> seed = [1, 2]
  >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3)
  >>> print(new_seeds)
  tf.Tensor(
  [[1105988140 1738052849]
   [-335576002  370444179]
   [  10670227 -246211131]], shape=(3, 2), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :])
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.59835213, -0.9578608 ,
  0.9002807 ], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    num: optional, a positive integer or scalar tensor indicating the number of
      seeds to produce (default 2).
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A tensor with shape [num, 2] representing `num` new seeds. It will have the
    same dtype as `seed` (if `seed` doesn't have an explicit dtype, the dtype
    will be determined by `tf.convert_to_tensor`).
  """
  seed = ops.convert_to_tensor(seed)
  return stateless_random_uniform(shape=[num, 2], seed=seed, dtype=seed.dtype,
                                  minval=None, maxval=None, alg=alg)


@tf_export("random.experimental.stateless_fold_in")
@dispatch.add_dispatch_support
def fold_in(seed, data, alg="auto_select"):
  """Folds in data to an RNG seed to form a new RNG seed.

  For example, in a distributed-training setting, suppose we have a master seed
  and a replica ID. We want to fold the replica ID into the master seed to
  form a "replica seed" to be used by that replica later on, so that different
  replicas will generate different random numbers but the reproducibility of the
  whole system can still be controlled by the master seed:

  >>> master_seed = [1, 2]
  >>> replica_id = 3
  >>> replica_seed = tf.random.experimental.stateless_fold_in(
  ...   master_seed, replica_id)
  >>> print(replica_seed)
  tf.Tensor([1105988140          3], shape=(2,), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=replica_seed)
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.03197195, 0.8979765 ,
  0.13253039], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    data: an `int32` or `int64` scalar representing data to be folded into the
      seed.
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A new RNG seed that is a deterministic function of the inputs and is
    statistically safe for producing a stream of new pseudo-random values. It
    will have the same dtype as `data` (if `data` doesn't have an explicit dtype,
    the dtype will be determined by `tf.convert_to_tensor`).
  """
  data = ops.convert_to_tensor(data)
  seed1 = stateless_random_uniform(shape=[], seed=seed, dtype=data.dtype,
                                   minval=None, maxval=None, alg=alg)
  return array_ops.stack([seed1, data])


@tf_export("random.stateless_uniform")
@dispatch.add_dispatch_support
def stateless_random_uniform(shape,
                             seed,
                             minval=0,
                             maxval=None,
                             dtype=dtypes.float32,
                             name=None,
                             alg="auto_select"):
  """Outputs deterministic pseudorandom values from a uniform distribution.

  This is a stateless version of `tf.random.uniform`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers. The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  The generated values follow a uniform distribution in the range
  `[minval, maxval)`. The lower bound `minval` is included in the range, while
  the upper bound `maxval` is excluded.

  For floats, the default range is `[0, 1)`. For ints, at least `maxval` must
  be specified explicitly.

  In the integer case, the random integers are slightly biased unless
  `maxval - minval` is an exact power of two. The bias is small for values of
  `maxval - minval` significantly smaller than the range of the output (either
  `2**32` or `2**64`).

  For full-range (i.e. inclusive of both max and min) random integers, pass
  `minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
  either both `minval` and `maxval` must be `None` or neither may be `None`. For
  example:
  ```python
  ints = tf.random.stateless_uniform(
      [10], seed=(2, 3), minval=None, maxval=None, dtype=tf.int32)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    minval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The lower bound on the range of random values to
      generate. Pass `None` for full-range integers. Defaults to 0.
    maxval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The upper bound on the range of random values to generate.
      Defaults to 1 if `dtype` is floating point. Pass `None` for full-range
      integers.
    dtype: The type of the output: `float16`, `bfloat16`, `float32`, `float64`,
      `int32`, or `int64`. For unbounded uniform ints (`minval`, `maxval` both
      `None`), `uint32` and `uint64` may be used.
    name: A name for the operation (optional).
    alg: The RNG algorithm used to generate the random numbers. Valid
      choices are `"philox"` for [the Philox
      algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf),
      `"threefry"` for [the ThreeFry
      algorithm](https://www.thesalmons.org/john/random123/papers/random123sc11.pdf),
      and `"auto_select"` (default) for the system to automatically
      select an algorithm based on the device type. Values of
      `tf.random.Algorithm` can also be used. Note that with
      `"auto_select"`, the outputs of this function may change when
      it is running on a different device.

  Returns:
    A tensor of the specified shape filled with random uniform values.

  Raises:
    ValueError: If `dtype` is not one of the accepted dtypes, if `dtype` is
      integral and only one of `minval` or `maxval` is specified, or if an
      unsigned integer `dtype` is combined with explicit bounds.
  """
  dtype = dtypes.as_dtype(dtype)
  accepted_dtypes = (dtypes.float16, dtypes.bfloat16, dtypes.float32,
                     dtypes.float64, dtypes.int32, dtypes.int64, dtypes.uint32,
                     dtypes.uint64)
  if dtype not in accepted_dtypes:
    raise ValueError(
        f"Argument `dtype` got invalid value {dtype}. Accepted dtypes are "
        f"{accepted_dtypes}.")
  if dtype.is_integer:
    if (minval is None) != (maxval is None):
      raise ValueError(
          f"For integer `dtype` argument {dtype}, argument `minval` and "
          f"`maxval` must be both None or not None. Got `minval`={minval} and "
          f"`maxval`={maxval}.")
    if minval is not None and dtype in (dtypes.uint32, dtypes.uint64):
      raise ValueError(
          f"Argument `dtype` got invalid value {dtype} when argument `minval` "
          f"is not None. Please don't use unsigned integers in this case.")
  elif maxval is None:
    maxval = 1
  with ops.name_scope(name, "stateless_random_uniform",
                      [shape, seed, minval, maxval]) as name:
    shape = tensor_util.shape_tensor(shape)
    # Every branch below feeds the same key/counter/alg triple to a V2 op, so
    # compute it once.
    key, counter, alg = _get_key_counter_alg(seed, alg)
    if dtype.is_integer and minval is None:
      # Full-range integers: minval and maxval are both None here.
      result = (
          gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
              shape, key=key, counter=counter, dtype=dtype, alg=alg, name=name))
    else:
      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
      if dtype.is_integer:
        result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
            shape,
            key=key,
            counter=counter,
            minval=minval,
            maxval=maxval,
            alg=alg,
            name=name)
      else:
        # Floats: scale a sample from [0, 1) into [minval, maxval).
        rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
            shape, key=key, counter=counter, dtype=dtype, alg=alg)
        result = math_ops.add(rnd * (maxval - minval), minval, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_binomial")
@dispatch.add_dispatch_support
def stateless_random_binomial(shape,
                              seed,
                              counts,
                              probs,
                              output_dtype=dtypes.int32,
                              name=None):
  """Outputs deterministic pseudorandom values from a binomial distribution.

  The generated values follow a binomial distribution with specified count and
  probability of success parameters.

  This is a stateless version of `tf.random.Generator.binomial`: if run twice
  with the same seeds and shapes, it will produce the same pseudorandom numbers.
  The output is consistent across multiple runs on the same hardware (and
  between CPU and GPU), but may change between versions of TensorFlow or on
  non-CPU/GPU hardware.

  Example:

  ```python
  counts = [10., 20.]
  # Probability of success.
  probs = [0.8]

  binomial_samples = tf.random.stateless_binomial(
      shape=[2], seed=[123, 456], counts=counts, probs=probs)

  counts = ... # Shape [3, 1, 2]
  probs = ...  # Shape [1, 4, 2]
  shape = [3, 4, 3, 4, 2]
  # Sample shape will be [3, 4, 3, 4, 2]
  binomial_samples = tf.random.stateless_binomial(
      shape=shape, seed=[123, 456], counts=counts, probs=probs)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    counts: Tensor. The counts of the binomial distribution. Must be
      broadcastable with `probs`, and broadcastable with the rightmost
      dimensions of `shape`.
    probs: Tensor. The probability of success for the binomial distribution.
      Must be broadcastable with `counts` and broadcastable with the rightmost
      dimensions of `shape`.
    output_dtype: The type of the output. Default: tf.int32
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random binomial
      values.  For each i, each samples[..., i] is an independent draw from
      the binomial distribution on counts[i] trials with probability of
      success probs[i].
  """
  with ops.name_scope(name, "stateless_random_binomial",
                      [shape, seed, counts, probs]) as name:
    shape = tensor_util.shape_tensor(shape)
    probs = ops.convert_to_tensor(
        probs, dtype_hint=dtypes.float32, name="probs")
    counts = ops.convert_to_tensor(
        counts, dtype_hint=probs.dtype, name="counts")
    result = gen_stateless_random_ops.stateless_random_binomial(
        shape=shape, seed=seed, counts=counts, probs=probs, dtype=output_dtype)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_gamma")
@dispatch.add_dispatch_support
def stateless_random_gamma(shape,
                           seed,
                           alpha,
                           beta=None,
                           dtype=dtypes.float32,
                           name=None):
  """Outputs deterministic pseudorandom values from a gamma distribution.

  The generated values follow a gamma distribution with specified concentration
  (`alpha`) and inverse scale (`beta`) parameters.

  This is a stateless version of `tf.random.gamma`: if run twice with the same
  seeds and shapes, it will produce the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware (and between CPU and
  GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_gamma` and `gamma`: in `gamma`, the `shape` is always
  prepended to the shape of the broadcast of `alpha` with `beta`; whereas in
  `stateless_gamma` the `shape` parameter must always encompass the shapes of
  each of `alpha` and `beta` (which must broadcast together to match the
  trailing dimensions of `shape`).

  Note: Because internal calculations are done using `float64` and casting has
  `floor` semantics, we must manually map zero outcomes to the smallest
  possible positive floating-point value, i.e., `np.finfo(dtype).tiny`.  This
  means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
  should.  This bias can only happen for small values of `alpha`, i.e.,
  `alpha << 1` or large values of `beta`, i.e., `beta >> 1`.

  The samples are differentiable w.r.t. alpha and beta.
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  Example:

  ```python
  samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  alpha = tf.constant([[1.], [3.], [5.]])
  beta = tf.constant([[3., 4.]])
  samples = tf.random.stateless_gamma(
      [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)
  # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.

  with tf.GradientTape() as tape:
    tape.watch([alpha, beta])
    loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma(
        [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)))
  dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta])
  # unbiased stochastic derivatives of the loss function
  alpha.shape == dloss_dalpha.shape  # True
  beta.shape == dloss_dbeta.shape  # True
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    alpha: Tensor. The concentration parameter of the gamma distribution. Must
      be broadcastable with `beta`, and broadcastable with the rightmost
      dimensions of `shape`.
    beta: Tensor. The inverse scale parameter of the gamma distribution. Must be
      broadcastable with `alpha` and broadcastable with the rightmost dimensions
      of `shape`. `None` is treated as 1.
    dtype: Floating point dtype of `alpha`, `beta`, and the output.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random gamma values.
      For each i, each `samples[..., i]` is an independent draw from the gamma
      distribution with concentration alpha[i] and inverse scale beta[i].
  """
  with ops.name_scope(name, "stateless_random_gamma",
                      [shape, seed, alpha, beta]) as name:
    shape = tensor_util.shape_tensor(shape)
    alpha = ops.convert_to_tensor(alpha, dtype=dtype, name="alpha")
    beta = ops.convert_to_tensor(
        beta if beta is not None else 1, name="beta", dtype=dtype)
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(alpha), array_ops.shape(beta))
    alpha_broadcast = array_ops.broadcast_to(alpha, broadcast_shape)
    # Clamp at the smallest positive value so that zero outcomes (see the Note
    # in the docstring) are never returned.
    result = math_ops.maximum(
        np.finfo(alpha.dtype.as_numpy_dtype).tiny,
        gen_stateless_random_ops.stateless_random_gamma_v2(
            shape, seed=seed, alpha=alpha_broadcast) / beta)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_poisson")
@dispatch.add_dispatch_support
def stateless_random_poisson(shape,
                             seed,
                             lam,
                             dtype=dtypes.int32,
                             name=None):
  """Outputs deterministic pseudorandom values from a Poisson distribution.

  The generated values follow a Poisson distribution with specified rate
  parameter.

  This is a stateless version of `tf.random.poisson`: if run twice with the same
  seeds and shapes, it will produce the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware, but may change between
  versions of TensorFlow or on non-CPU/GPU hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_poisson` and `poisson`: in `poisson`, the `shape` is always
  prepended to the shape of `lam`; whereas in `stateless_poisson` the shape of
  `lam` must match the trailing dimensions of `shape`.

  Example:

  ```python
  samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_poisson([7, 5, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  rate = tf.constant([[1.], [3.], [5.]])
  samples = tf.random.stateless_poisson([30, 3, 1], seed=[12, 34], lam=rate)
  # samples has shape [30, 3, 1], with 30 samples each of 3x1 distributions.
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Shape
      must match the rightmost dimensions of `shape`.
    dtype: Dtype of the samples (int or float dtypes are permissible, as samples
      are discrete). Default: int32.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random Poisson values.
      For each i, each `samples[..., i]` is an independent draw from the Poisson
      distribution with rate `lam[i]`.
  """
  with ops.name_scope(name, "stateless_random_poisson",
                      [shape, seed, lam]) as name:
    shape = tensor_util.shape_tensor(shape)
    result = gen_stateless_random_ops.stateless_random_poisson(
        shape, seed=seed, lam=lam, dtype=dtype)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_normal")
@dispatch.add_dispatch_support
def stateless_random_normal(shape,
                            seed,
                            mean=0.0,
                            stddev=1.0,
                            dtype=dtypes.float32,
                            name=None,
                            alg="auto_select"):
  """Outputs deterministic pseudorandom values from a normal distribution.

  This is a stateless version of `tf.random.normal`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
      distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution.
    dtype: The type of the output.
    name: A name for the operation (optional).
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A tensor of the specified shape filled with random normal values.
  """
  with ops.name_scope(name, "stateless_random_normal",
                      [shape, seed, mean, stddev]) as name:
    shape = tensor_util.shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    key, counter, alg = _get_key_counter_alg(seed, alg)
    rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
        shape, key=key, counter=counter, dtype=dtype, alg=alg)
    result = math_ops.add(rnd * stddev, mean, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_truncated_normal")
@dispatch.add_dispatch_support
def stateless_truncated_normal(shape,
                               seed,
                               mean=0.0,
                               stddev=1.0,
                               dtype=dtypes.float32,
                               name=None,
                               alg="auto_select"):
  """Outputs deterministic pseudorandom values, truncated normally distributed.

  This is a stateless version of `tf.random.truncated_normal`: if run twice with
  the same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  The generated values follow a normal distribution with specified mean and
  standard deviation, except that values whose magnitude is more than 2 standard
  deviations from the mean are dropped and re-picked.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
      truncated normal distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution, before truncation.
    dtype: The type of the output.
    name: A name for the operation (optional).
    alg: The RNG algorithm used to generate the random numbers. See
      `tf.random.stateless_uniform` for a detailed explanation.

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  with ops.name_scope(name, "stateless_truncated_normal",
                      [shape, seed, mean, stddev]) as name:
    shape = tensor_util.shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    key, counter, alg = _get_key_counter_alg(seed, alg)
    rnd = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
        shape, key=key, counter=counter, dtype=dtype, alg=alg)
    result = math_ops.add(rnd * stddev, mean, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export(v1=["random.stateless_multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
    date=None, instructions="Use `tf.random.stateless_categorical` instead.")
def stateless_multinomial(logits,
                          num_samples,
                          seed,
                          output_dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a multinomial distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Example:

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_categorical(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    output_dtype: integer type to use for the output. Defaults to int64.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
    return stateless_multinomial_categorical_impl(logits, num_samples,
                                                  output_dtype, seed)


@tf_export("random.stateless_categorical")
@dispatch.add_dispatch_support
def stateless_categorical(logits,
                          num_samples,
                          seed,
                          dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a categorical distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Example:

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_categorical(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    dtype: integer type to use for the output. Defaults to int64.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_categorical", [logits, seed]):
    return stateless_multinomial_categorical_impl(logits, num_samples, dtype,
                                                  seed)


def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
  """Implementation for stateless multinomial/categorical ops (v1/v2).

  Args:
    logits: Unnormalized log-probabilities; converted to a Tensor here.
    num_samples: Number of samples to draw for each row of `logits`.
    dtype: Integer dtype of the returned sample indices.
    seed: A shape [2] seed, forwarded unchanged to
      `gen_stateless_random_ops.stateless_multinomial`.

  Returns:
    The tensor produced by `gen_stateless_random_ops.stateless_multinomial`.
  """
  logits = ops.convert_to_tensor(logits, name="logits")
  return gen_stateless_random_ops.stateless_multinomial(
      logits, num_samples, seed, output_dtype=dtype)


@dispatch.add_dispatch_support
@tf_export("random.stateless_parameterized_truncated_normal")
def stateless_parameterized_truncated_normal(shape,
                                             seed,
                                             means=0.0,
                                             stddevs=1.0,
                                             minvals=-2.0,
                                             maxvals=2.0,
                                             name=None):
  """Outputs random values from a truncated normal distribution.

  The generated values follow a normal distribution with the specified means
  and standard deviations, truncated to the interval `[minvals, maxvals]`.
  With the default arguments this is a standard normal truncated at 2
  standard deviations from the mean.

  Examples:

  Sample from a Truncated normal, with differing shape parameters that
  broadcast.

  >>> means = 0.
  >>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
  >>> minvals = [-1., -2., -1000.]
  >>> maxvals = [[10000.], [1.]]
  >>> y = tf.random.stateless_parameterized_truncated_normal(
  ...   shape=[10, 2, 3], seed=[7, 17],
  ...   means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
  >>> y.shape
  TensorShape([10, 2, 3])

  Args:
    shape: A 1-D integer `Tensor` or Python array. The shape of the output
      tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    means: A `Tensor` or Python value of type `dtype`. The mean of the truncated
      normal distribution. This must broadcast with `stddevs`, `minvals` and
      `maxvals`, and the broadcasted shape must be dominated by `shape`.
    stddevs: A `Tensor` or Python value of type `dtype`. The standard deviation
      of the truncated normal distribution. This must broadcast with `means`,
      `minvals` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    minvals: A `Tensor` or Python value of type `dtype`. The minimum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    maxvals: A `Tensor` or Python value of type `dtype`. The maximum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `minvals`, and the broadcasted shape must be dominated by
      `shape`.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  # `seed` is listed alongside the other inputs, matching the sibling ops.
  with ops.name_scope(name, "stateless_parameterized_truncated_normal",
                      [shape, seed, means, stddevs, minvals, maxvals]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    means_tensor = ops.convert_to_tensor(means, name="means")
    stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs")
    minvals_tensor = ops.convert_to_tensor(minvals, name="minvals")
    maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals")
    rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal(
        shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor,
        maxvals_tensor)
    tensor_util.maybe_set_static_shape(rnd, shape)
    return rnd