# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateful_random_ops.py."""

import os

from absl.testing import parameterized
import numpy as np
from tensorflow.python.checkpoint import checkpoint as tracking_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as random_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import stateful_random_ops as random
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


# Module-level slots used by tf.function bodies below so that a Generator (and
# hence its state variable) is created only on the first trace.
g_seeded = None
g_unseeded = None


# Float dtypes exercised by the tests. bfloat16 is only in CPU_FLOATS, which is
# used for the CPU comparison against the old random ops.
GPU_FLOATS = [dtypes.float16, dtypes.float32, dtypes.float64]
CPU_FLOATS = GPU_FLOATS + [dtypes.bfloat16]
FLOATS = GPU_FLOATS
INTS = [dtypes.int32, dtypes.int64]


class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):
  """Tests for `stateful_random_ops.Generator` and related helpers."""

  def setUp(self):
    """Splits the first physical CPU into two logical CPU devices."""
    super(StatefulRandomOpsTest, self).setUp()
    physical_devices = config.list_physical_devices("CPU")
    config.set_logical_device_configuration(
        physical_devices[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])

  def testCreateRNGStateIntSeed(self):
    """Tests `create_rng_state` when `seed` is int."""
    # using leading 'F' to test overflow tolerance
    state = random.create_rng_state(0xFFFF222233334444FFAA666677778888,
                                    random.RNG_ALG_PHILOX)
    # The 128-bit seed is expected as two 64-bit words, low word first, with
    # the rest of the Philox state zero-filled.
    self.assertAllEqual(
        list(map(random._uint_to_int,
                 [0xFFAA666677778888, 0xFFFF222233334444] +
                 [0] * (random.PHILOX_STATE_SIZE - 2))),
        state)

  def assertAllDifferent(self, tensors):
    """Checks that there are no duplicate elements anywhere among the tensors.

    Args:
      tensors: a list of tensors. They can have different shapes.
    """
    tensors = [array_ops.reshape(t, shape=[-1]) for t in tensors]
    ls = array_ops.concat(tensors, axis=0).numpy().tolist()
    self.assertAllEqual(len(ls), len(set(ls)))

  @test_util.run_v2_only
  def testNonDeterministicInts(self):
    """Tests that non_deterministic_ints returns different results every time.

    This test is flaky, but with very low probability of failing.
    """
    shape = [2, 3]
    dtype = dtypes.int64
    a = random.non_deterministic_ints(shape=shape, dtype=dtype)
    self.assertAllEqual(shape, a.shape)
    self.assertEqual(dtype, a.dtype)
    b = random.non_deterministic_ints(shape, dtype=dtype)
    self.assertAllDifferent([a, b])

  @test_util.run_v2_only
  def testBatchSeeds(self):
    """Tests that batched keys, seeds and `split` generators are distinct.

    Also checks that `make_seeds` can be called repeatedly in a tf.function.
    """
    shape = [2, 3]
    count = 6
    gen = random.Generator.from_seed(1234)
    keys1 = gen._make_int64_keys(shape=shape)
    keys2 = gen._make_int64_keys(shape=shape)
    self.assertAllDifferent([keys1, keys2])
    seeds1 = gen.make_seeds(count=count)
    seeds2 = gen.make_seeds(count=count)
    self.assertAllDifferent([seeds1[0, :], seeds2[0, :]])
    gens = gen.split(count=count)
    self.assertAllEqual(count, len(gens))
    randoms = [g.uniform_full_int(shape=shape, dtype=dtypes.int32)
               for g in gens]
    self.assertAllDifferent(randoms)
    # Tests graph mode.
    @def_function.function
    def f():
      return gen.make_seeds(count=count)
    for _ in range(3):
      f()

  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testCrossDeviceSplit(self):
    """Tests that a CPU RNG can split into RNGs on GPU.
    """
    with ops.device("/device:CPU:0"):
      gen = random.Generator.from_seed(1234)  # gen is on CPU
    # unittest's assertRegex takes (text, regex).
    self.assertRegex(gen.state.device, "CPU")
    with ops.device(test_util.gpu_device_name()):
      gens = gen.split(count=10)  # gens are on GPU
    self.assertRegex(gens[0].state.device, "GPU")

  @test_util.run_v2_only
  def testSplitInFunction(self):
    """Tests that `Generator.split` can be called inside a tf.function."""
    g = random.Generator.from_seed(1)
    new_g = [None]  # using list as mutable cells
    @def_function.function
    def f():
      if new_g[0] is None:  # avoid creating variable in 2nd trace
        new_g[0] = g.split(2)
      return [new_g[0][i].normal([]) for i in range(2)]
    f()

  def testFnVars(self):
    """Tests that RNG variable is added to ConcreteFunction.variables."""
    rng = random.Generator.from_seed(0)
    @def_function.function
    def f():
      return rng.normal([])

    concrete = f.get_concrete_function()
    self.assertIn(rng.state, concrete.variables)

  @test_util.run_v2_only
  def testReset(self):
    """Tests that each reset method gives the same samples eagerly and in a
    tf.function."""
    shape = [2, 3]
    gen = random.Generator.from_seed(0)
    for resetter in [
        lambda g: g.reset(state=[1, 2, 3]),
        lambda g: g.reset_from_seed(1234),
        lambda g: g.reset_from_key_counter(key=1, counter=[2, 3]),
    ]:
      resetter(gen)
      expected_normal = gen.normal(shape)
      @def_function.function
      def f(resetter):
        resetter(gen)
        return gen.normal(shape)
      def check_results(expected_normal, v):
        self.assertAllEqual(expected_normal, v)
      check_results(expected_normal, f(resetter))
      check_results(expected_normal, f(resetter))

  @test_util.run_v2_only
  def testGeneratorCreation(self):
    """Tests generator creation, in both eager and tf.function.

    The interaction between Generator creation and tf.function should be the
    same as for tf.Variable.
    """
    shape = [2, 3]
    alg = random.RNG_ALG_PHILOX
    for constructor in [
        lambda: random.Generator(state=[1, 2, 3], alg=alg),
        lambda: random.Generator.from_seed(1234),
        lambda: random.Generator.from_key_counter(  # pylint: disable=g-long-lambda
            key=1, counter=[2, 3], alg=alg),
    ]:
      gen = constructor()
      # Tests tf.function
      expected_normal1 = gen.normal(shape)
      expected_normal2 = gen.normal(shape)
      global g_seeded
      g_seeded = None
      @def_function.function
      def f(constructor):
        global g_seeded
        # defun'ed function should only create variables once
        if g_seeded is None:
          g_seeded = constructor()
        return g_seeded.normal(shape)
      def check_results(expected_normal, v):
        self.assertAllEqual(expected_normal, v)
      check_results(expected_normal1, f(constructor))
      check_results(expected_normal2, f(constructor))

  @test_util.run_v2_only
  def testCreateGeneratorFromSymbolic(self):
    """Tests creating and resetting generators from symbolic tensors."""
    g = [None, None, None]  # using list as mutable cells
    @def_function.function
    def f(scalar, vector2, vector3):
      if g[0] is None:  # avoid creating variable in 2nd trace
        g[0] = random.Generator.from_seed(scalar)
        g[0].reset_from_seed(scalar)  # also test reset
        g[1] = random.Generator.from_state(vector3, random.RNG_ALG_PHILOX)
        g[1].reset(vector3)
        g[2] = random.Generator.from_key_counter(
            scalar, vector2, random.RNG_ALG_PHILOX)
        g[2].reset_from_key_counter(scalar, vector2)
      return [g[i].normal([]) for i in range(3)]
    args = (1, [2, 2], [3, 3, 3])
    args = [constant_op.constant(v) for v in args]
    f(*args)

  @parameterized.parameters([
      ("philox", random.RNG_ALG_PHILOX, random.Algorithm.PHILOX),
      ("threefry", random.RNG_ALG_THREEFRY, random.Algorithm.THREEFRY)])
  @test_util.run_v2_only
  def testAlg(self, name, int_id, enum_id):
    """Tests that the algorithm can be given by name, int id or enum."""
    g_by_name = random.Generator.from_seed(1234, name)
    g_by_int = random.Generator.from_seed(1234, int_id)
    g_by_enum = random.Generator.from_seed(1234, enum_id)
    self.assertEqual(g_by_name.algorithm, g_by_int.algorithm)
    self.assertEqual(g_by_name.algorithm, g_by_enum.algorithm)

  @test_util.run_v2_only
  def testGeneratorCreationWithVar(self):
    """Tests creating generator with a variable.
    """
    alg = random.RNG_ALG_PHILOX
    state = [1, 2, 3]
    var = variables.Variable(state, dtype=random.STATE_TYPE)
    g = random.Generator(state=state, alg=alg)
    g_var = random.Generator(state=var, alg=alg)
    shape = [2, 3]
    g.normal(shape)
    g_var.normal(shape)
    # g_var advanced `var` in place, so it must match g's own state.
    self.assertAllEqual(g.state.read_value(), var.read_value())

  @test_util.run_v2_only
  def testGeneratorCreationUnseeded(self):
    """Tests generator creation, the unseeded case."""
    shape = [2, 3]
    global g_unseeded
    g_unseeded = None
    @def_function.function
    def f():
      global g_unseeded
      # defun'ed function should only create variables once
      if g_unseeded is None:
        g_unseeded = random.Generator.from_non_deterministic_state()
      return g_unseeded.normal(shape)
    self.assertAllEqual(shape, f().shape)

  @test_util.run_v2_only
  def testGeneratorCopy(self):
    """Tests copying a generator."""
    g = random.Generator.from_seed(0)
    g_copy = random.Generator(g)
    self.assertAllEqual(g.algorithm, g_copy.algorithm)
    self.assertAllEqual(g.state.read_value(), g_copy.state.read_value())
    # Tests tf.function
    global g_seeded
    g_seeded = None
    # Do the same in tf.function
    @def_function.function
    def f():
      global g_seeded
      # defun'ed function should only create variables once
      if g_seeded is None:
        g_seeded = random.Generator(g)
      self.assertAllEqual(g.algorithm, g_seeded.algorithm)
      self.assertAllEqual(g.state.read_value(), g_seeded.state.read_value())
    f()

  @test_util.run_v1_only(
      ("This test is specifically for checking TF1 compatibility. "
       "It cannot run under TF2."))
  def testTF1(self):
    """Tests seeded and unseeded generators in a TF1 session."""
    seed = 1234
    shape = [2, 3]
    expected_normal1 = constant_op.constant(
        [[0.9356609, 1.0854305, -0.93788373],
         [-0.50615472, 1.31697023, 0.71375787]], dtype=dtypes.float32)
    expected_normal2 = constant_op.constant(
        [[-0.3964749, 0.8369565, -0.30946946],
         [1.1206646, 1.00852597, -0.10185789]], dtype=dtypes.float32)
    with self.cached_session() as sess:
      gen1 = random.Generator.from_seed(seed)
      gen2 = random.Generator.from_non_deterministic_state()
      sess.run((gen1.state.initializer, gen2.state.initializer))
      r1 = gen1.normal(shape, dtype=dtypes.float32)
      r2 = gen2.normal(shape, dtype=dtypes.float32)
      def f():
        return sess.run((r1, r2))
      def check_results(expected_normal, v1, v2):
        self.assertAllClose(expected_normal, v1, rtol=1e-5, atol=1e-5)
        self.assertAllEqual(shape, v2.shape)
      # Each sess.run advances the state, so the two runs see different values.
      check_results(expected_normal1, *f())
      check_results(expected_normal2, *f())

  @test_util.run_v2_only
  @test_util.also_run_as_tf_function
  def testEagerAndDefun(self):
    """A simple test to make sure the op works in eager and defunned mode."""
    random.get_global_generator().normal((3,))

  @test_util.run_v2_only
  def testOpSeedSelectionAfterSetSeed(self):
    """Tests that op-seed selection is reset after resetting global generator.

    Fixing GitHub issue 9171:
    https://github.com/tensorflow/tensorflow/issues/9171
    """
    shape = (3,)
    random.get_global_generator().reset_from_seed(1)
    a = random.get_global_generator().normal(shape)
    random.get_global_generator().reset_from_seed(1)
    b = random.get_global_generator().normal(shape)
    self.assertAllEqual(a, b)

    # Now do the above again using accelerated ('defun'ed) computation
    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)

    random.get_global_generator().reset_from_seed(1)
    c = f()
    random.get_global_generator().reset_from_seed(1)
    d = f()
    self.assertAllEqual(c, d)
    self.assertAllEqual(a, c)

  @test_util.run_v2_only
  def testOpSeedSelectionNotSensitive(self):
    """Test that op-seed selection is not sensitive to trivial changes.

    Test that op-seed selection is not sensitive to trivial computation
    (i.e. graph) changes.

    Fixing b/32087099
    """
    def f(include_print):
      shape = constant_op.constant([5])
      if include_print:
        shape = logging_ops.Print(shape, [shape])
      return random.get_global_generator().normal(shape)

    def compare(fst_includes_print, snd_includes_print):
      random.get_global_generator().reset_from_seed(50)
      fst = f(fst_includes_print)
      random.get_global_generator().reset_from_seed(50)
      snd = f(snd_includes_print)
      self.assertAllEqual(fst, snd)
      # Now do the above again using accelerated (defunned) 'f'.
      # Running 'f' with two different Boolean arguments should cause
      # two different graphs to be generated, hence demonstrating the
      # insensitivity to graph changes.
      f_acc = def_function.function(f)
      random.get_global_generator().reset_from_seed(50)
      fst = f_acc(fst_includes_print)
      random.get_global_generator().reset_from_seed(50)
      snd = f_acc(snd_includes_print)
      self.assertAllEqual(fst, snd)

    compare(False, False)
    compare(True, True)
    compare(True, False)

  @test_util.run_v2_only
  def testKey(self):
    """Tests that `Generator.key` returns the key stored in the state."""
    key = 1234
    gen = random.Generator(state=[0, 0, key], alg=random.RNG_ALG_PHILOX)
    got = gen.key
    self.assertAllEqual(key, got)
    @def_function.function
    def f():
      return gen.key
    got = f()
    self.assertAllEqual(key, got)

  @test_util.run_v2_only
  def testSkip(self):
    """Tests that `skip(delta)` advances the Philox counter by delta * 256."""
    key = 1234
    counter = 5678
    gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
    delta = 432
    gen.skip(delta)
    new_counter = gen.state[0]
    self.assertAllEqual(counter + delta * 256, new_counter)

  def _sameAsOldRandomOps(self, device, floats):
    """Checks that Generator ops reproduce the old seeded random ops.

    Args:
      device: device name on which both the old and new ops run.
      floats: float dtypes to test; `INTS` are always tested for uniform ints.
    """
    def compare(dtype, old, new):
      seed1, seed2 = 79, 25
      # note how the two seeds for the old op correspond to the seed for the new
      # op
      with ops.device(device):
        gen = random.Generator(state=[0, seed2, seed1],
                               alg=random.RNG_ALG_PHILOX)

      # create a graph for the old op in order to call it many times
      @def_function.function
      def run_old():
        with ops.device(device):
          return old(dtype, seed1, seed2)

      def run_new():
        with ops.device(device):
          return new(dtype, gen)

      for _ in range(5):
        self.assertAllEqual(run_old(), run_new())

    shape = constant_op.constant([4, 7])
    minval = 128
    maxval = 256

    # passing `dtype` around to suppress go/gpylint-faq#cell-var-from-loop and
    # go/gpylint-faq#undefined-loop-variable
    def old_normal(dtype, seed1, seed2):
      return gen_random_ops.random_standard_normal(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_normal(dtype, gen):
      return gen._standard_normal(shape, dtype=dtype)
    def old_truncated_normal(dtype, seed1, seed2):
      return gen_random_ops.truncated_normal(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_truncated_normal(dtype, gen):
      return gen._truncated_normal(shape, dtype=dtype)
    def old_uniform_int(dtype, seed1, seed2):
      minval2 = constant_op.constant(minval, dtype=dtype)
      maxval2 = constant_op.constant(maxval, dtype=dtype)
      return gen_random_ops.random_uniform_int(
          shape, minval=minval2, maxval=maxval2, seed=seed1, seed2=seed2)
    def new_uniform_int(dtype, gen):
      return gen.uniform(shape, minval=minval, maxval=maxval, dtype=dtype)
    def old_uniform(dtype, seed1, seed2):
      return gen_random_ops.random_uniform(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_uniform(dtype, gen):
      return gen._uniform(shape, dtype=dtype)

    for dtype in floats:
      compare(dtype, old_normal, new_normal)
      compare(dtype, old_truncated_normal, new_truncated_normal)
      compare(dtype, old_uniform, new_uniform)
    for dtype in INTS:
      compare(dtype, old_uniform_int, new_uniform_int)

  @test_util.run_v2_only
  def testSameAsOldRandomOpsCPU(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The CPU version.
    """
    self._sameAsOldRandomOps("/device:CPU:0", CPU_FLOATS)

  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testSameAsOldRandomOpsGPU(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The GPU version.
    """
    self._sameAsOldRandomOps(test_util.gpu_device_name(), GPU_FLOATS)

  @parameterized.parameters(INTS + [dtypes.uint32, dtypes.uint64])
  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testGPUEqualsCPU(self, dtype):
    """Tests that GPU and CPU generate the same integer outputs."""
    seed = 1234
    shape = [315, 49]
    with ops.device("/device:CPU:0"):
      cpu = random.Generator.from_seed(seed).uniform_full_int(
          shape=shape, dtype=dtype)
    with ops.device(test_util.gpu_device_name()):
      gpu = random.Generator.from_seed(seed).uniform_full_int(
          shape=shape, dtype=dtype)
    self.assertAllEqual(cpu, gpu)

  @parameterized.parameters(FLOATS + INTS)
  @test_util.run_v2_only
  def testUniformIsInRange(self, dtype):
    """Tests that uniform samples lie in the half-open range [minval, maxval)."""
    minval = 2
    maxval = 33
    size = 1000
    gen = random.Generator.from_seed(1234)
    x = gen.uniform(
        shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
    self.assertTrue(np.all(x >= minval))
    self.assertTrue(np.all(x < maxval))

  @parameterized.parameters(FLOATS)
  @test_util.run_v2_only
  def testNormalIsFinite(self, dtype):
    """Tests that normal samples contain no infs or NaNs."""
    gen = random.Generator.from_seed(1234)
    x = gen.normal(shape=[10000], dtype=dtype).numpy()
    self.assertTrue(np.all(np.isfinite(x)))

  @parameterized.parameters(FLOATS + INTS)
  @test_util.run_v2_only
  def testDistributionOfUniform(self, dtype):
    """Use Pearson's Chi-squared test to test for uniformity."""
    n = 1000
    seed = 12
    gen = random.Generator.from_seed(seed)
    maxval = 1
    if dtype.is_integer:
      maxval = 100
    x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
    if maxval > 1:
      # Normalize x to range [0, 1).
      x = x.astype(float) / maxval
    # Tests that the values are distributed amongst 10 bins with equal
    # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
    # p=0.05. This test is probabilistic and would be flaky if the random
    # seed were not fixed.
    val = random_test_util.chi_squared(x, 10)
    self.assertLess(val, 16.92)

  @parameterized.parameters(FLOATS)
  @test_util.run_v2_only
  def testDistributionOfNormal(self, dtype):
    """Use Anderson-Darling test to test distribution appears normal."""
    n = 1000
    gen = random.Generator.from_seed(1234)
    x = gen.normal(shape=[n], dtype=dtype).numpy()
    # The constant 2.492 is the 5% critical value for the Anderson-Darling
    # test where the mean and variance are known. This test is probabilistic
    # so to avoid flakiness the seed is fixed.
    self.assertLess(
        random_test_util.anderson_darling(x.astype(float)), 2.492)

  @test_util.run_v2_only
  def testErrors(self):
    """Tests that proper errors are raised.
    """
    shape = [2, 3]
    gen = random.Generator.from_seed(1234)
    # Non-scalar algorithm argument.
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        r"must have shape \[\], not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, [0, 0], shape)
    # Non-scalar skip delta.
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        r"must have shape \[\], not"):
      gen_stateful_random_ops.rng_skip(
          gen.state.handle, gen.algorithm, [0, 0])
    # Float algorithm id.
    with self.assertRaisesWithPredicateMatch(
        TypeError, "EagerTensor of dtype int64"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, 1.1, shape)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "Unsupported algorithm id"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, 123, shape)
    var = variables.Variable([0, 0], dtype=dtypes.int32)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "dtype of RNG state variable must be int64, not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    var = variables.Variable([[0]], dtype=dtypes.int64)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "RNG state must have one and only one dimension, not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    var = variables.Variable([0], dtype=dtypes.int64)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "For the Philox algorithm, the size of state must be at least"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        "minval must be a scalar; got a tensor of shape "):
      @def_function.function
      def f():
        gen.uniform(shape=shape, minval=array_ops.zeros(shape, "int32"),
                    maxval=100, dtype="int32")
      f()
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        "maxval must be a scalar; got a tensor of shape "):
      @def_function.function
      def f2():
        gen.uniform(
            shape=shape, minval=0, maxval=array_ops.ones(shape, "int32") * 100,
            dtype="int32")
      f2()

  @test_util.run_v2_only
  def testGetGlobalGeneratorWithXla(self):
    """Demonstrates using the global generator with XLA."""
    # This test was passing before because soft placement silently picked the
    # CPU kernel.
    # TODO(wangpeng): Remove this skip
    self.skipTest("NonDeterministicInts lacks XLA kernel.")

    if not config.list_physical_devices("XLA_CPU"):
      self.skipTest("No XLA_CPU device available.")

    random.set_global_generator(None)

    @def_function.function(jit_compile=True)
    def make_seed():
      generator = random.get_global_generator()
      state = array_ops.identity(generator.state, name="state")
      return generator.uniform_full_int((2,), dtypes.int32, name="seed"), state

    with ops.device("/device:XLA_CPU:0"):
      seed, state = make_seed()
      self.assertTrue(np.all(np.isfinite(seed.numpy())))
      random.get_global_generator().reset(state)
      self.assertAllEqual(make_seed()[0], seed)

  @test_util.run_v2_only
  def testSetGlobalGeneratorBadWithDefun(self):
    """Demonstrates set_global_generator does not affect compiled tf.function."""
    shape = (3,)

    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)

    random.set_global_generator(random.Generator.from_seed(50))
    samples = f()
    # Replacing the global generator has no effect on the already-traced
    # tf.function.
    random.set_global_generator(random.Generator.from_seed(50))
    # New samples are returned.
    self.assertNotAllEqual(samples, f())

  @test_util.run_v2_only
  def testFunctionArg(self):
    """Tests that RNG can be used as tf.function's argument.
    """
    shape = [2, 3]
    @def_function.function
    def f(gen):
      return gen.normal(shape)
    g1 = random.Generator.from_seed(1)
    g2 = random.Generator.from_seed(1)
    res1 = f(g1)
    res2 = g2.normal(shape)
    self.assertAllEqual(res1, res2)
    self.assertAllEqual(g1.state.read_value(), g2.state.read_value())

  @test_util.run_v2_only
  def testUniformFullInt(self):
    """Tests full-range int uniform.
    """
    shape = [3, 4]
    dtype = dtypes.int32
    g = random.Generator.from_seed(1)
    r1 = g.uniform(shape=shape, dtype=dtype, minval=None)
    g = random.Generator.from_seed(1)
    r2 = g.uniform_full_int(shape=shape, dtype=dtype)
    self.assertAllEqual(r1, r2)

  @test_util.run_v2_only
  def testRestore(self):
    """Tests save and restore.
    """
    fname = os.path.join(self.get_temp_dir(), "checkpoint")
    g = random.Generator.from_seed(1)
    cp = tracking_util.Checkpoint(g=g)
    def write_restore_compare():
      cp.write(fname)
      r1 = g.uniform([], dtype=dtypes.uint32, minval=None)
      cp.restore(fname)
      r2 = g.uniform([], dtype=dtypes.uint32, minval=None)
      self.assertAllEqual(r1, r2)
    # Run multiple times so that cp.write is called in various RNG states
    for _ in range(2):
      write_restore_compare()

  @test_util.run_v2_only
  def testDeterministicOpsErrors(self):
    """Tests RNG APIs that must fail when op determinism is enabled.

    With determinism on, `get_global_generator` raises unless a global
    generator was set explicitly, and `from_non_deterministic_state` always
    raises.
    """
    try:
      config.enable_op_determinism()
      random.set_global_generator(None)
      with self.assertRaisesWithPredicateMatch(
          RuntimeError,
          '"get_global_generator" cannot be called if determinism is enabled'):
        random.get_global_generator()
      random.set_global_generator(random.Generator.from_seed(50))
      random.get_global_generator()
      with self.assertRaisesWithPredicateMatch(
          RuntimeError,
          '"from_non_deterministic_state" cannot be called when determinism '
          "is enabled."):
        random.Generator.from_non_deterministic_state()
    finally:
      # Always restore the process-wide setting for subsequent tests.
      config.disable_op_determinism()


if __name__ == "__main__":
  # Disable soft placement so ops without a kernel on the requested device
  # fail instead of silently falling back to another device.
  config.set_soft_device_placement(False)
  test.main()