# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stateless random ops which take seed as a tensor input."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.compat import compat
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


# Register these stateless random ops as having no gradient. The tuple order
# matches the order in which they were historically registered one call at a
# time.
for _op_type in (
    # Ops whose Python wrappers in this module are called with `seed` directly.
    "StatelessMultinomial",
    "StatelessRandomBinomial",
    "StatelessRandomNormal",
    "StatelessRandomPoisson",
    "StatelessRandomUniform",
    "StatelessRandomUniformInt",
    "StatelessRandomUniformFullInt",
    "StatelessTruncatedNormal",
    # V2 op; its wrapper in this module is called with a key/counter/alg triple.
    "StatelessRandomNormalV2",
):
  ops.NotDifferentiable(_op_type)
# Do not leak the loop variable into the module namespace.
del _op_type
ops.NotDifferentiable("StatelessRandomUniformV2")
ops.NotDifferentiable("StatelessRandomUniformIntV2")
ops.NotDifferentiable("StatelessRandomUniformFullIntV2")
ops.NotDifferentiable("StatelessTruncatedNormalV2")


@tf_export("random.experimental.stateless_split")
@dispatch.add_dispatch_support
def split(seed, num=2):
  """Splits an RNG seed into `num` new seeds by adding a leading axis.

  Example:

  >>> seed = [1, 2]
  >>> new_seeds = tf.random.experimental.stateless_split(seed, num=3)
  >>> print(new_seeds)
  tf.Tensor(
  [[1105988140 1738052849]
   [-335576002  370444179]
   [  10670227 -246211131]], shape=(3, 2), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=new_seeds[0, :])
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([-0.59835213, -0.9578608 ,
  0.9002807 ], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    num: optional, a positive integer or scalar tensor indicating the number of
      seeds to produce (default 2).

  Returns:
    A tensor with shape [num, 2] representing `num` new seeds. It will have the
    same dtype as `seed` (if `seed` doesn't have an explicit dtype, the dtype
    will be determined by `tf.convert_to_tensor`).
  """
  seed = ops.convert_to_tensor(seed)
  # The new seeds are full-range integers of the seed's own dtype, drawn
  # statelessly from the old seed.
  return stateless_random_uniform(shape=[num, 2], seed=seed, dtype=seed.dtype,
                                  minval=None, maxval=None)


@tf_export("random.experimental.stateless_fold_in")
@dispatch.add_dispatch_support
def fold_in(seed, data):
  """Folds in data to an RNG seed to form a new RNG seed.

  For example, in a distributed-training setting, suppose we have a master seed
  and a replica ID. We want to fold the replica ID into the master seed to
  form a "replica seed" to be used by that replica later on, so that different
  replicas will generate different random numbers but the reproducibility of the
  whole system can still be controlled by the master seed:

  >>> master_seed = [1, 2]
  >>> replica_id = 3
  >>> replica_seed = tf.random.experimental.stateless_fold_in(
  ...   master_seed, replica_id)
  >>> print(replica_seed)
  tf.Tensor([1105988140          3], shape=(2,), dtype=int32)
  >>> tf.random.stateless_normal(shape=[3], seed=replica_seed)
  <tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.03197195, 0.8979765 ,
  0.13253039], dtype=float32)>

  Args:
    seed: an RNG seed (a tensor with shape [2] and dtype `int32` or
      `int64`). (When using XLA, only `int32` is allowed.)
    data: an `int32` or `int64` scalar representing data to be folded in to the
      seed.

  Returns:
    A new RNG seed that is a deterministic function of the inputs and is
    statistically safe for producing a stream of new pseudo-random values. It
    will have the same dtype as `data` (if `data` doesn't have an explicit
    dtype, the dtype will be determined by `tf.convert_to_tensor`).
  """
  data = ops.convert_to_tensor(data)
  # First seed component: one full-range integer derived from `seed`; the
  # second component is `data` itself.
  seed1 = stateless_random_uniform(shape=[], seed=seed, dtype=data.dtype,
                                   minval=None, maxval=None)
  return array_ops.stack([seed1, data])


def _get_key_counter_alg(seed):
  """Converts a `seed` into the `(key, counter, alg)` inputs of the V2 ops.

  The `*_v2` kernels in `gen_stateless_random_ops_v2` are called with a
  key/counter/algorithm triple rather than a raw seed. Once the 2021-03-01
  forward-compatibility window has passed, the key/counter pair and the
  algorithm come from two separate ops. Before that, the older fused
  `stateless_random_get_key_counter_alg` op is used.

  Args:
    seed: A shape [2] `int32`/`int64` seed tensor.

  Returns:
    A tuple `(key, counter, alg)`.
  """
  if compat.forward_compatible(2021, 3, 1):
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed)
    alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
    return key, counter, alg
  else:
    return gen_stateless_random_ops_v2.stateless_random_get_key_counter_alg(
        seed)


@tf_export("random.stateless_uniform")
@dispatch.add_dispatch_support
def stateless_random_uniform(shape,
                             seed,
                             minval=0,
                             maxval=None,
                             dtype=dtypes.float32,
                             name=None):
  """Outputs deterministic pseudorandom values from a uniform distribution.

  This is a stateless version of `tf.random.uniform`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  The generated values follow a uniform distribution in the range
  `[minval, maxval)`. The lower bound `minval` is included in the range, while
  the upper bound `maxval` is excluded.

  For floats, the default range is `[0, 1)`.  For ints, at least `maxval` must
  be specified explicitly.

  In the integer case, the random integers are slightly biased unless
  `maxval - minval` is an exact power of two.  The bias is small for values of
  `maxval - minval` significantly smaller than the range of the output (either
  `2**32` or `2**64`).

  For full-range (i.e. inclusive of both max and min) random integers, pass
  `minval=None` and `maxval=None` with an integer `dtype`. For an integer dtype
  either both `minval` and `maxval` must be `None` or neither may be `None`. For
  example:
  ```python
  ints = tf.random.stateless_uniform(
      [10], seed=(2, 3), minval=None, maxval=None, dtype=tf.int32)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    minval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The lower bound on the range of random values to
      generate. Pass `None` for full-range integers. Defaults to 0.
    maxval: A Tensor or Python value of type `dtype`, broadcastable with
      `shape` (for integer types, broadcasting is not supported, so it needs to
      be a scalar). The upper bound on the range of random values to generate.
      Defaults to 1 if `dtype` is floating point. Pass `None` for full-range
      integers.
    dtype: The type of the output: `float16`, `bfloat16`, `float32`, `float64`,
      `int32`, or `int64`. For unbounded uniform ints (`minval`, `maxval` both
      `None`), `uint32` and `uint64` may be used.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random uniform values.

  Raises:
    ValueError: If `dtype` is not one of the supported types listed above; if
      `dtype` is integral and only one of `minval` or `maxval` is specified; or
      if `dtype` is `uint32`/`uint64` and bounds are specified.
  """
  dtype = dtypes.as_dtype(dtype)
  if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
                   dtypes.float64, dtypes.int32, dtypes.int64, dtypes.uint32,
                   dtypes.uint64):
    raise ValueError("Invalid dtype %r" % dtype)
  if dtype.is_integer:
    if (minval is None) != (maxval is None):
      raise ValueError("For integer dtype {}, minval and maxval must be both "
                       "`None` or both non-`None`.".format(dtype))
    # Unsigned dtypes are only supported by the full-range (unbounded) kernel.
    if minval is not None and dtype in (dtypes.uint32, dtypes.uint64):
      raise ValueError("Invalid dtype for bounded uniform integers: %r" % dtype)
  elif maxval is None:
    # Floating point: default upper bound gives the documented [0, 1) range.
    maxval = 1
  with ops.name_scope(name, "stateless_random_uniform",
                      [shape, seed, minval, maxval]) as name:
    shape = tensor_util.shape_tensor(shape)
    if dtype.is_integer and minval is None:
      # Full-range integers: no bounds to convert.
      if compat.forward_compatible(2020, 10, 25):
        key, counter, alg = _get_key_counter_alg(seed)
        result = (gen_stateless_random_ops_v2
                  .stateless_random_uniform_full_int_v2(
                      shape, key=key, counter=counter, dtype=dtype, alg=alg,
                      name=name))
      else:
        result = gen_stateless_random_ops.stateless_random_uniform_full_int(
            shape, seed=seed, dtype=dtype, name=name)
    else:
      minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
      maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
      if dtype.is_integer:
        if compat.forward_compatible(2020, 10, 25):
          key, counter, alg = _get_key_counter_alg(seed)
          result = gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
              shape, key=key, counter=counter, minval=minval, maxval=maxval,
              alg=alg, name=name)
        else:
          result = gen_stateless_random_ops.stateless_random_uniform_int(
              shape, seed=seed, minval=minval, maxval=maxval, name=name)
      else:
        if compat.forward_compatible(2020, 10, 25):
          key, counter, alg = _get_key_counter_alg(seed)
          rnd = gen_stateless_random_ops_v2.stateless_random_uniform_v2(
              shape, key=key, counter=counter, dtype=dtype, alg=alg)
        else:
          rnd = gen_stateless_random_ops.stateless_random_uniform(
              shape, seed=seed, dtype=dtype)
        # Affinely map the unit-interval sample onto [minval, maxval).
        result = math_ops.add(rnd * (maxval - minval), minval, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_binomial")
@dispatch.add_dispatch_support
def stateless_random_binomial(shape,
                              seed,
                              counts,
                              probs,
                              output_dtype=dtypes.int32,
                              name=None):
  """Outputs deterministic pseudorandom values from a binomial distribution.

  The generated values follow a binomial distribution with specified count and
  probability of success parameters.

  This is a stateless version of `tf.random.Generator.binomial`: if run twice
  with the same seeds and shapes, it will produce the same pseudorandom numbers.
  The output is consistent across multiple runs on the same hardware (and
  between CPU and GPU), but may change between versions of TensorFlow or on
  non-CPU/GPU hardware.

  Example:

  ```python
  counts = [10., 20.]
  # Probability of success.
  probs = [0.8]

  binomial_samples = tf.random.stateless_binomial(
      shape=[2], seed=[123, 456], counts=counts, probs=probs)

  counts = ... # Shape [3, 1, 2]
  probs = ...  # Shape [1, 4, 2]
  shape = [3, 4, 3, 4, 2]
  # Sample shape will be [3, 4, 3, 4, 2]
  binomial_samples = tf.random.stateless_binomial(
      shape=shape, seed=[123, 456], counts=counts, probs=probs)
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    counts: Tensor. The counts of the binomial distribution. Must be
      broadcastable with `probs`, and broadcastable with the rightmost
      dimensions of `shape`.
    probs: Tensor. The probability of success for the binomial distribution.
      Must be broadcastable with `counts` and broadcastable with the rightmost
      dimensions of `shape`.
    output_dtype: The type of the output. Default: tf.int32
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random binomial
      values.  For each i, each `samples[..., i]` is an independent draw from
      the binomial distribution on `counts[i]` trials with probability of
      success `probs[i]`.
  """
  with ops.name_scope(name, "stateless_random_binomial",
                      [shape, seed, counts, probs]) as name:
    shape = tensor_util.shape_tensor(shape)
    probs = ops.convert_to_tensor(
        probs, dtype_hint=dtypes.float32, name="probs")
    # `counts` follows the dtype chosen for `probs` unless it has its own.
    counts = ops.convert_to_tensor(
        counts, dtype_hint=probs.dtype, name="counts")
    result = gen_stateless_random_ops.stateless_random_binomial(
        shape=shape, seed=seed, counts=counts, probs=probs, dtype=output_dtype)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_gamma")
@dispatch.add_dispatch_support
def stateless_random_gamma(shape,
                           seed,
                           alpha,
                           beta=None,
                           dtype=dtypes.float32,
                           name=None):
  """Outputs deterministic pseudorandom values from a gamma distribution.

  The generated values follow a gamma distribution with specified concentration
  (`alpha`) and inverse scale (`beta`) parameters.

  This is a stateless version of `tf.random.gamma`: if run twice with the same
  seeds and shapes, it will produce the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware (and between CPU and
  GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_gamma` and `gamma`: in `gamma`, the `shape` is always
  prepended to the shape of the broadcast of `alpha` with `beta`; whereas in
  `stateless_gamma` the `shape` parameter must always encompass the shapes of
  each of `alpha` and `beta` (which must broadcast together to match the
  trailing dimensions of `shape`).

  Note: Because internal calculations are done using `float64` and casting has
  `floor` semantics, we must manually map zero outcomes to the smallest
  possible positive floating-point value, i.e., `np.finfo(dtype).tiny`.  This
  means that `np.finfo(dtype).tiny` occurs more frequently than it otherwise
  should.  This bias can only happen for small values of `alpha`, i.e.,
  `alpha << 1` or large values of `beta`, i.e., `beta >> 1`.

  The samples are differentiable w.r.t. alpha and beta.
  The derivatives are computed using the approach described in
  (Figurnov et al., 2018).

  Example:

  ```python
  samples = tf.random.stateless_gamma([10, 2], seed=[12, 34], alpha=[0.5, 1.5])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_gamma([7, 5, 2], seed=[12, 34], alpha=[.5, 1.5])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  alpha = tf.constant([[1.], [3.], [5.]])
  beta = tf.constant([[3., 4.]])
  samples = tf.random.stateless_gamma(
      [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)
  # samples has shape [30, 3, 2], with 30 samples each of 3x2 distributions.

  with tf.GradientTape() as tape:
    tape.watch([alpha, beta])
    loss = tf.reduce_mean(tf.square(tf.random.stateless_gamma(
        [30, 3, 2], seed=[12, 34], alpha=alpha, beta=beta)))
  dloss_dalpha, dloss_dbeta = tape.gradient(loss, [alpha, beta])
  # unbiased stochastic derivatives of the loss function
  alpha.shape == dloss_dalpha.shape  # True
  beta.shape == dloss_dbeta.shape  # True
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    alpha: Tensor. The concentration parameter of the gamma distribution. Must
      be broadcastable with `beta`, and broadcastable with the rightmost
      dimensions of `shape`.
    beta: Tensor. The inverse scale parameter of the gamma distribution. Must be
      broadcastable with `alpha` and broadcastable with the rightmost dimensions
      of `shape`. Defaults to 1 when `None`.
    dtype: Floating point dtype of `alpha`, `beta`, and the output.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random gamma values.
      For each i, each `samples[..., i]` is an independent draw from the gamma
      distribution with concentration `alpha[i]` and inverse scale (rate)
      `beta[i]`.
  """
  with ops.name_scope(name, "stateless_random_gamma",
                      [shape, seed, alpha, beta]) as name:
    shape = tensor_util.shape_tensor(shape)
    alpha = ops.convert_to_tensor(alpha, dtype=dtype, name="alpha")
    beta = ops.convert_to_tensor(
        beta if beta is not None else 1, name="beta", dtype=dtype)
    # The kernel only takes `alpha`, so give it one alpha per (alpha, beta)
    # pair; `beta` is applied afterwards by division.
    broadcast_shape = array_ops.broadcast_dynamic_shape(
        array_ops.shape(alpha), array_ops.shape(beta))
    alpha_broadcast = array_ops.broadcast_to(alpha, broadcast_shape)
    # Clamp zero outcomes to `tiny`; see the Note in the docstring.
    result = math_ops.maximum(
        np.finfo(alpha.dtype.as_numpy_dtype).tiny,
        gen_stateless_random_ops.stateless_random_gamma_v2(
            shape, seed=seed, alpha=alpha_broadcast) / beta)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_poisson")
@dispatch.add_dispatch_support
def stateless_random_poisson(shape,
                             seed,
                             lam,
                             dtype=dtypes.int32,
                             name=None):
  """Outputs deterministic pseudorandom values from a Poisson distribution.

  The generated values follow a Poisson distribution with specified rate
  parameter.

  This is a stateless version of `tf.random.poisson`: if run twice with the same
  seeds and shapes, it will produce the same pseudorandom numbers. The output is
  consistent across multiple runs on the same hardware, but may change between
  versions of TensorFlow or on non-CPU/GPU hardware.

  A slight difference exists in the interpretation of the `shape` parameter
  between `stateless_poisson` and `poisson`: in `poisson`, the `shape` is always
  prepended to the shape of `lam`; whereas in `stateless_poisson` the shape of
  `lam` must match the trailing dimensions of `shape`.

  Example:

  ```python
  samples = tf.random.stateless_poisson([10, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [10, 2], where each slice [:, 0] and [:, 1] represents
  # the samples drawn from each distribution

  samples = tf.random.stateless_poisson([7, 5, 2], seed=[12, 34], lam=[5, 15])
  # samples has shape [7, 5, 2], where each slice [:, :, 0] and [:, :, 1]
  # represents the 7x5 samples drawn from each of the two distributions

  rate = tf.constant([[1.], [3.], [5.]])
  samples = tf.random.stateless_poisson([30, 3, 1], seed=[12, 34], lam=rate)
  # samples has shape [30, 3, 1], with 30 samples each of 3x1 distributions.
  ```

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    lam: Tensor. The rate parameter "lambda" of the Poisson distribution. Shape
      must match the rightmost dimensions of `shape`.
    dtype: Dtype of the samples (int or float dtypes are permissible, as samples
      are discrete). Default: int32.
    name: A name for the operation (optional).

  Returns:
    samples: A Tensor of the specified shape filled with random Poisson values.
      For each i, each `samples[..., i]` is an independent draw from the Poisson
      distribution with rate `lam[i]`.
  """
  with ops.name_scope(name, "stateless_random_poisson",
                      [shape, seed, lam]) as name:
    shape = tensor_util.shape_tensor(shape)
    result = gen_stateless_random_ops.stateless_random_poisson(
        shape, seed=seed, lam=lam, dtype=dtype)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_normal")
@dispatch.add_dispatch_support
def stateless_random_normal(shape,
                            seed,
                            mean=0.0,
                            stddev=1.0,
                            dtype=dtypes.float32,
                            name=None):
  """Outputs deterministic pseudorandom values from a normal distribution.

  This is a stateless version of `tf.random.normal`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
      distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution.
    dtype: The type of the output.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random normal values.
  """
  with ops.name_scope(name, "stateless_random_normal",
                      [shape, seed, mean, stddev]) as name:
    shape = tensor_util.shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    # NOTE(review): this switches to the V2 kernel at a later
    # forward-compatibility date (2021-03-01) than the uniform and
    # truncated-normal ops (2020-10-25); presumably intentional.
    if compat.forward_compatible(2021, 3, 1):
      key, counter, alg = _get_key_counter_alg(seed)
      rnd = gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape, key=key, counter=counter, dtype=dtype, alg=alg)
    else:
      rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
    # Scale and shift the kernel's output to the requested mean/stddev.
    result = math_ops.add(rnd * stddev, mean, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export("random.stateless_truncated_normal")
@dispatch.add_dispatch_support
def stateless_truncated_normal(shape,
                               seed,
                               mean=0.0,
                               stddev=1.0,
                               dtype=dtypes.float32,
                               name=None):
  """Outputs deterministic pseudorandom values, truncated normally distributed.

  This is a stateless version of `tf.random.truncated_normal`: if run twice with
  the same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  The generated values follow a normal distribution with specified mean and
  standard deviation, except that values whose magnitude is more than 2 standard
  deviations from the mean are dropped and re-picked.

  Args:
    shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
      truncated normal distribution.
    stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
      of the normal distribution, before truncation.
    dtype: The type of the output.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  with ops.name_scope(name, "stateless_truncated_normal",
                      [shape, seed, mean, stddev]) as name:
    shape = tensor_util.shape_tensor(shape)
    mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
    stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
    if compat.forward_compatible(2020, 10, 25):
      key, counter, alg = _get_key_counter_alg(seed)
      rnd = gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
          shape, key=key, counter=counter, dtype=dtype, alg=alg)
    else:
      rnd = gen_stateless_random_ops.stateless_truncated_normal(
          shape, seed, dtype)
    # Scale and shift the kernel's output to the requested mean/stddev.
    result = math_ops.add(rnd * stddev, mean, name=name)
    tensor_util.maybe_set_static_shape(result, shape)
    return result


@tf_export(v1=["random.stateless_multinomial"])
@dispatch.add_dispatch_support
@deprecation.deprecated(
    date=None, instructions="Use `tf.random.stateless_categorical` instead.")
def stateless_multinomial(logits,
                          num_samples,
                          seed,
                          output_dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a multinomial distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Example:

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_categorical(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    output_dtype: integer type to use for the output. Defaults to int64.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
    return stateless_multinomial_categorical_impl(logits, num_samples,
                                                  output_dtype, seed)


@tf_export("random.stateless_categorical")
@dispatch.add_dispatch_support
def stateless_categorical(logits,
                          num_samples,
                          seed,
                          dtype=dtypes.int64,
                          name=None):
  """Draws deterministic pseudorandom samples from a categorical distribution.

  This is a stateless version of `tf.random.categorical`: if run twice with the
  same seeds and shapes, it will produce the same pseudorandom numbers.  The
  output is consistent across multiple runs on the same hardware (and between
  CPU and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
  hardware.

  Example:

  ```python
  # samples has shape [1, 5], where each value is either 0 or 1 with equal
  # probability.
  samples = tf.random.stateless_categorical(
      tf.math.log([[0.5, 0.5]]), 5, seed=[7, 17])
  ```

  Args:
    logits: 2-D Tensor with shape `[batch_size, num_classes]`.  Each slice
      `[i, :]` represents the unnormalized log-probabilities for all classes.
    num_samples: 0-D.  Number of independent samples to draw for each row slice.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    dtype: integer type to use for the output. Defaults to int64.
    name: Optional name for the operation.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  with ops.name_scope(name, "stateless_categorical", [logits, seed]):
    return stateless_multinomial_categorical_impl(logits, num_samples, dtype,
                                                  seed)


def stateless_multinomial_categorical_impl(logits, num_samples, dtype, seed):
  """Implementation for stateless multinomial/categorical ops (v1/v2).

  Shared by `stateless_multinomial` (v1) and `stateless_categorical` (v2),
  which differ only in the name of the output-dtype argument.

  Args:
    logits: 2-D `[batch_size, num_classes]` unnormalized log-probabilities.
    num_samples: 0-D number of samples to draw per row.
    dtype: Integer output dtype, forwarded as the op's `output_dtype`.
    seed: A shape [2] seed tensor.

  Returns:
    The drawn samples of shape `[batch_size, num_samples]`.
  """
  logits = ops.convert_to_tensor(logits, name="logits")
  return gen_stateless_random_ops.stateless_multinomial(
      logits, num_samples, seed, output_dtype=dtype)


@tf_export("random.stateless_parameterized_truncated_normal")
@dispatch.add_dispatch_support
def stateless_parameterized_truncated_normal(shape,
                                             seed,
                                             means=0.0,
                                             stddevs=1.0,
                                             minvals=-2.0,
                                             maxvals=2.0,
                                             name=None):
  """Outputs random values from a truncated normal distribution.

  The generated values follow a normal distribution with specified mean and
  standard deviation, truncated to the interval bounded below by `minvals` and
  above by `maxvals`: values outside those bounds are never produced. With the
  default arguments this is a standard normal truncated at -2 and 2.

  Examples:

  Sample from a truncated normal, with differing shape parameters that
  broadcast.

  >>> means = 0.
  >>> stddevs = tf.math.exp(tf.random.uniform(shape=[2, 3]))
  >>> minvals = [-1., -2., -1000.]
  >>> maxvals = [[10000.], [1.]]
  >>> y = tf.random.stateless_parameterized_truncated_normal(
  ...   shape=[10, 2, 3], seed=[7, 17],
  ...   means=means, stddevs=stddevs, minvals=minvals, maxvals=maxvals)
  >>> y.shape
  TensorShape([10, 2, 3])

  Args:
    shape: A 1-D integer `Tensor` or Python array. The shape of the output
      tensor.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    means: A `Tensor` or Python value of type `dtype`. The mean of the truncated
      normal distribution. This must broadcast with `stddevs`, `minvals` and
      `maxvals`, and the broadcasted shape must be dominated by `shape`.
    stddevs: A `Tensor` or Python value of type `dtype`. The standard deviation
      of the truncated normal distribution. This must broadcast with `means`,
      `minvals` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    minvals: A `Tensor` or Python value of type `dtype`. The minimum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `maxvals`, and the broadcasted shape must be dominated by
      `shape`.
    maxvals: A `Tensor` or Python value of type `dtype`. The maximum value of
      the truncated normal distribution. This must broadcast with `means`,
      `stddevs` and `minvals`, and the broadcasted shape must be dominated by
      `shape`.
    name: A name for the operation (optional).

  Returns:
    A tensor of the specified shape filled with random truncated normal values.
  """
  # `seed` is included in the scope values like every other op in this module.
  with ops.name_scope(name, "stateless_parameterized_truncated_normal",
                      [shape, seed, means, stddevs, minvals, maxvals]) as name:
    shape_tensor = tensor_util.shape_tensor(shape)
    means_tensor = ops.convert_to_tensor(means, name="means")
    stddevs_tensor = ops.convert_to_tensor(stddevs, name="stddevs")
    minvals_tensor = ops.convert_to_tensor(minvals, name="minvals")
    maxvals_tensor = ops.convert_to_tensor(maxvals, name="maxvals")
    rnd = gen_stateless_random_ops.stateless_parameterized_truncated_normal(
        shape_tensor, seed, means_tensor, stddevs_tensor, minvals_tensor,
        maxvals_tensor)
    # Use the converted shape tensor, as the sibling ops do, so the static
    # shape can also be inferred when `shape` was given as a list of tensors.
    tensor_util.maybe_set_static_shape(rnd, shape_tensor)
    return rnd