# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateful_random_ops.py."""

import os
import re

from absl.testing import parameterized
import numpy as np
from tensorflow.python.checkpoint import checkpoint as tracking_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as random_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_stateful_random_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import stateful_random_ops as random
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


g_seeded = None
g_unseeded = None
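# Note on the globals above: tf.function may trace a Python function more than
# once, but variables (including a Generator's state) must be created only on
# the first trace. Tests below therefore stash Generators in these globals (or
# in one-element lists used as mutable cells) and guard creation with an
# `is None` check. A minimal sketch of the pattern (`make_gen` is a
# placeholder for any of the constructors exercised below):
#
#   g_cell = [None]
#   @def_function.function
#   def sample():
#     if g_cell[0] is None:  # create the variable only on the first trace
#       g_cell[0] = make_gen()
#     return g_cell[0].normal([])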


GPU_FLOATS = [dtypes.float16, dtypes.float32, dtypes.float64]
CPU_FLOATS = GPU_FLOATS + [dtypes.bfloat16]
FLOATS = GPU_FLOATS
INTS = [dtypes.int32, dtypes.int64]


class StatefulRandomOpsTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(StatefulRandomOpsTest, self).setUp()
    physical_devices = config.list_physical_devices("CPU")
    config.set_logical_device_configuration(
        physical_devices[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])

  def testCreateRNGStateIntSeed(self):
    """Tests `create_rng_state` when `seed` is int."""
    # using leading 'F' to test overflow tolerance
    state = random.create_rng_state(0xFFFF222233334444FFAA666677778888,
                                    random.RNG_ALG_PHILOX)
    self.assertAllEqual(
        list(map(random._uint_to_int,
                 [0xFFAA666677778888, 0xFFFF222233334444] +
                 [0] * (random.PHILOX_STATE_SIZE - 2))),
        state)
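    # A note on the expected value above: the 128-bit seed is split into
    # 64-bit words little-endian style, so the low word 0xFFAA666677778888
    # lands in state[0], the high word in state[1], and the rest of the
    # Philox state is zero-filled.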

  def assertAllDifferent(self, tensors):
    """Checks that there are no duplicate elements anywhere among the tensors.

    Args:
      tensors: a list of tensors. They can have different shapes.
    """
    tensors = [array_ops.reshape(t, shape=[-1]) for t in tensors]
    ls = array_ops.concat(tensors, axis=0).numpy().tolist()
    self.assertAllEqual(len(ls), len(set(ls)))

  @test_util.run_v2_only
  def testNonDeterministicInts(self):
    """Tests that non_deterministic_ints returns different results every time.

    This test is flaky, but with very low probability of failing.
    """
    shape = [2, 3]
    dtype = dtypes.int64
    a = random.non_deterministic_ints(shape=shape, dtype=dtype)
    self.assertAllEqual(shape, a.shape)
    self.assertEqual(dtype, a.dtype)
    b = random.non_deterministic_ints(shape, dtype=dtype)
    self.assertAllDifferent([a, b])

  @test_util.run_v2_only
  def testBatchSeeds(self):
    """Test for batch seeds."""
    shape = [2, 3]
    count = 6
    gen = random.Generator.from_seed(1234)
    keys1 = gen._make_int64_keys(shape=shape)
    keys2 = gen._make_int64_keys(shape=shape)
    self.assertAllDifferent([keys1, keys2])
    seeds1 = gen.make_seeds(count=count)
    seeds2 = gen.make_seeds(count=count)
    self.assertAllDifferent([seeds1[0, :], seeds2[0, :]])
    gens = gen.split(count=count)
    self.assertAllEqual(count, len(gens))
    randoms = [g.uniform_full_int(shape=shape, dtype=dtypes.int32)
               for g in gens]
    self.assertAllDifferent(randoms)
    # Tests graph mode.
    @def_function.function
    def f():
      return gen.make_seeds(count=count)
    for _ in range(3):
      f()

  def assertRegex(self, pattern, text):
    # Note: argument order is (pattern, text), the reverse of unittest's
    # assertRegex(text, expected_regex); this override matches the call
    # sites below.
    self.assertTrue(
        re.search(pattern, text),
        "Can't find pattern '%s' in text '%s'" % (pattern, text))

  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testCrossDeviceSplit(self):
    """Tests that a CPU RNG can split into RNGs on GPU."""
    with ops.device("/device:CPU:0"):
      gen = random.Generator.from_seed(1234)  # gen is on CPU
      self.assertRegex("CPU", gen.state.device)
    with ops.device(test_util.gpu_device_name()):
      gens = gen.split(count=10)  # gens are on GPU
      self.assertRegex("GPU", gens[0].state.device)

  @test_util.run_v2_only
  def testSplitInFunction(self):
    g = random.Generator.from_seed(1)
    new_g = [None]  # using list as mutable cells
    @def_function.function
    def f():
      if new_g[0] is None:  # avoid creating variable in 2nd trace
        new_g[0] = g.split(2)
      return [new_g[0][i].normal([]) for i in range(2)]
    f()

  def testFnVars(self):
    """Tests that RNG variable is added to ConcreteFunction.variables."""
    rng = random.Generator.from_seed(0)
    @def_function.function
    def f():
      return rng.normal([])

    concrete = f.get_concrete_function()
    self.assertIn(rng.state, concrete.variables)

  @test_util.run_v2_only
  def testReset(self):
    shape = [2, 3]
    gen = random.Generator.from_seed(0)
    for resetter in [
        lambda g: g.reset(state=[1, 2, 3]),
        lambda g: g.reset_from_seed(1234),
        lambda g: g.reset_from_key_counter(key=1, counter=[2, 3]),
    ]:
      resetter(gen)
      expected_normal = gen.normal(shape)
      @def_function.function
      def f(resetter):
        resetter(gen)
        return gen.normal(shape)
      def check_results(expected_normal, v):
        self.assertAllEqual(expected_normal, v)
      check_results(expected_normal, f(resetter))
      check_results(expected_normal, f(resetter))

  @test_util.run_v2_only
  def testGeneratorCreation(self):
    """Tests generator creation, in both eager and tf.function.

    The interaction between Generator creation and defun should be the same as
    tf.Variable.
    """
    shape = [2, 3]
    alg = random.RNG_ALG_PHILOX
    for constructor in [
        lambda: random.Generator(state=[1, 2, 3], alg=alg),
        lambda: random.Generator.from_seed(1234),
        lambda: random.Generator.from_key_counter(  # pylint: disable=g-long-lambda
            key=1, counter=[2, 3], alg=alg),
    ]:
      gen = constructor()
      # Tests tf.function
      expected_normal1 = gen.normal(shape)
      expected_normal2 = gen.normal(shape)
      global g_seeded
      g_seeded = None
      @def_function.function
      def f(constructor):
        global g_seeded
        # defun'ed function should only create variables once
        if g_seeded is None:
          g_seeded = constructor()
        return g_seeded.normal(shape)
      def check_results(expected_normal, v):
        self.assertAllEqual(expected_normal, v)
      check_results(expected_normal1, f(constructor))
      check_results(expected_normal2, f(constructor))

  @test_util.run_v2_only
  def testCreateGeneratorFromSymbolic(self):
    g = [None, None, None]  # using list as mutable cells
    @def_function.function
    def f(scalar, vector2, vector3):
      if g[0] is None:  # avoid creating variable in 2nd trace
        g[0] = random.Generator.from_seed(scalar)
        g[0].reset_from_seed(scalar)  # also test reset
        g[1] = random.Generator.from_state(vector3, random.RNG_ALG_PHILOX)
        g[1].reset(vector3)
        g[2] = random.Generator.from_key_counter(
            scalar, vector2, random.RNG_ALG_PHILOX)
        g[2].reset_from_key_counter(scalar, vector2)
      return [g[i].normal([]) for i in range(3)]
    args = (1, [2, 2], [3, 3, 3])
    args = [constant_op.constant(v) for v in args]
    f(*args)

  @parameterized.parameters([
      ("philox", random.RNG_ALG_PHILOX, random.Algorithm.PHILOX),
      ("threefry", random.RNG_ALG_THREEFRY, random.Algorithm.THREEFRY)])
  @test_util.run_v2_only
  def testAlg(self, name, int_id, enum_id):
    g_by_name = random.Generator.from_seed(1234, name)
    g_by_int = random.Generator.from_seed(1234, int_id)
    g_by_enum = random.Generator.from_seed(1234, enum_id)
    self.assertEqual(g_by_name.algorithm, g_by_int.algorithm)
    self.assertEqual(g_by_name.algorithm, g_by_enum.algorithm)

  @test_util.run_v2_only
  def testGeneratorCreationWithVar(self):
    """Tests creating generator with a variable."""
    alg = random.RNG_ALG_PHILOX
    state = [1, 2, 3]
    var = variables.Variable(state, dtype=random.STATE_TYPE)
    g = random.Generator(state=state, alg=alg)
    g_var = random.Generator(state=var, alg=alg)
    shape = [2, 3]
    g.normal(shape)
    g_var.normal(shape)
    self.assertAllEqual(g.state.read_value(), var.read_value())
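    # The assertion above relies on Generator(state=var) adopting `var` as its
    # state variable rather than copying it: drawing from g_var advances `var`
    # in place, so after one draw from each generator the two states match.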

  @test_util.run_v2_only
  def testGeneratorCreationUnseeded(self):
    """Tests generator creation, the unseeded case."""
    shape = [2, 3]
    global g_unseeded
    g_unseeded = None
    @def_function.function
    def f():
      global g_unseeded
      # defun'ed function should only create variables once
      if g_unseeded is None:
        g_unseeded = random.Generator.from_non_deterministic_state()
      return g_unseeded.normal(shape)
    self.assertAllEqual(shape, f().shape)

  @test_util.run_v2_only
  def testGeneratorCopy(self):
    """Tests copying a generator."""
    g = random.Generator.from_seed(0)
    g_copy = random.Generator(g)
    self.assertAllEqual(g.algorithm, g_copy.algorithm)
    self.assertAllEqual(g.state.read_value(), g_copy.state.read_value())
    # Do the same in tf.function
    global g_seeded
    g_seeded = None
    @def_function.function
    def f():
      global g_seeded
      # defun'ed function should only create variables once
      if g_seeded is None:
        g_seeded = random.Generator(g)
      self.assertAllEqual(g.algorithm, g_seeded.algorithm)
      self.assertAllEqual(g.state.read_value(), g_seeded.state.read_value())
    f()

  @test_util.run_v1_only(
      ("This test is specifically for checking TF1 compatibility. "
       "It cannot run under TF2."))
  def testTF1(self):
    seed = 1234
    shape = [2, 3]
    expected_normal1 = constant_op.constant(
        [[0.9356609, 1.0854305, -0.93788373],
         [-0.50615472, 1.31697023, 0.71375787]], dtype=dtypes.float32)
    expected_normal2 = constant_op.constant(
        [[-0.3964749, 0.8369565, -0.30946946],
         [1.1206646, 1.00852597, -0.10185789]], dtype=dtypes.float32)
    with self.cached_session() as sess:
      gen1 = random.Generator.from_seed(seed)
      gen2 = random.Generator.from_non_deterministic_state()
      sess.run((gen1.state.initializer, gen2.state.initializer))
      r1 = gen1.normal(shape, dtype=dtypes.float32)
      r2 = gen2.normal(shape, dtype=dtypes.float32)
      def f():
        return sess.run((r1, r2))
      def check_results(expected_normal, v1, v2):
        self.assertAllClose(expected_normal, v1, rtol=1e-5, atol=1e-5)
        self.assertAllEqual(shape, v2.shape)
      check_results(expected_normal1, *f())
      check_results(expected_normal2, *f())

  @test_util.run_v2_only
  @test_util.also_run_as_tf_function
  def testEagerAndDefun(self):
    """A simple test to make sure the op works in eager and defunned mode."""
    random.get_global_generator().normal((3,))

  @test_util.run_v2_only
  def testOpSeedSelectionAfterSetSeed(self):
    """Tests that op-seed selection is reset after resetting global generator.

    Fixing GitHub issue 9171:
    https://github.com/tensorflow/tensorflow/issues/9171
    """
    shape = (3,)
    random.get_global_generator().reset_from_seed(1)
    a = random.get_global_generator().normal(shape)
    random.get_global_generator().reset_from_seed(1)
    b = random.get_global_generator().normal(shape)
    self.assertAllEqual(a, b)

    # Now do the above again using accelerated ('defun'ed) computation
    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)

    random.get_global_generator().reset_from_seed(1)
    c = f()
    random.get_global_generator().reset_from_seed(1)
    d = f()
    self.assertAllEqual(c, d)
    self.assertAllEqual(a, c)

  @test_util.run_v2_only
  def testOpSeedSelectionNotSensitive(self):
    """Test that op-seed selection is not sensitive to trivial changes.

    Trivial computation (i.e. graph) changes should not affect op-seed
    selection.

    Fixing b/32087099
    """
    def f(include_print):
      shape = constant_op.constant([5])
      if include_print:
        shape = logging_ops.Print(shape, [shape])
      return random.get_global_generator().normal(shape)

    def compare(fst_includes_print, snd_includes_print):
      random.get_global_generator().reset_from_seed(50)
      fst = f(fst_includes_print)
      random.get_global_generator().reset_from_seed(50)
      snd = f(snd_includes_print)
      self.assertAllEqual(fst, snd)
      # Now do the above again using accelerated (defunned) 'f'.
      # Running 'f' with two different Boolean arguments should cause
      # two different graphs to be generated, hence demonstrating the
      # insensitivity to graph changes.
      f_acc = def_function.function(f)
      random.get_global_generator().reset_from_seed(50)
      fst = f_acc(fst_includes_print)
      random.get_global_generator().reset_from_seed(50)
      snd = f_acc(snd_includes_print)
      self.assertAllEqual(fst, snd)

    compare(False, False)
    compare(True, True)
    compare(True, False)

  @test_util.run_v2_only
  def testKey(self):
    key = 1234
    gen = random.Generator(state=[0, 0, key], alg=random.RNG_ALG_PHILOX)
    got = gen.key
    self.assertAllEqual(key, got)
    @def_function.function
    def f():
      return gen.key
    got = f()
    self.assertAllEqual(key, got)
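    # For the Philox algorithm the state layout is
    # [counter_low, counter_high, key] (three int64s), so `gen.key` reads back
    # element 2 of the state passed to the constructor.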

  @test_util.run_v2_only
  def testSkip(self):
    key = 1234
    counter = 5678
    gen = random.Generator(state=[counter, 0, key], alg=random.RNG_ALG_PHILOX)
    delta = 432
    gen.skip(delta)
    new_counter = gen.state[0]
    self.assertAllEqual(counter + delta * 256, new_counter)
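    # The factor of 256 is a kernel implementation detail: each skip unit
    # reserves 256 samples, so skip(delta) advances the 128-bit Philox counter
    # (whose low 64 bits live in state[0]) by delta * 256.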

  def _sameAsOldRandomOps(self, device, floats):
    def compare(dtype, old, new):
      seed1, seed2 = 79, 25
      # Note how the old op's two seeds (seed1, seed2) map onto the new
      # generator's state [0, seed2, seed1], i.e. (counter_low, counter_high,
      # key).
      with ops.device(device):
        gen = random.Generator(state=[0, seed2, seed1],
                               alg=random.RNG_ALG_PHILOX)

      # create a graph for the old op in order to call it many times
      @def_function.function
      def run_old():
        with ops.device(device):
          return old(dtype, seed1, seed2)

      def run_new():
        with ops.device(device):
          return new(dtype, gen)

      for _ in range(5):
        self.assertAllEqual(run_old(), run_new())

    shape = constant_op.constant([4, 7])
    minval = 128
    maxval = 256

    # passing `dtype` around to compress go/gpylint-faq#cell-var-from-loop and
    # go/gpylint-faq#undefined-loop-variable
    def old_normal(dtype, seed1, seed2):
      return gen_random_ops.random_standard_normal(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_normal(dtype, gen):
      return gen._standard_normal(shape, dtype=dtype)
    def old_truncated_normal(dtype, seed1, seed2):
      return gen_random_ops.truncated_normal(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_truncated_normal(dtype, gen):
      return gen._truncated_normal(shape, dtype=dtype)
    def old_uniform_int(dtype, seed1, seed2):
      minval2 = constant_op.constant(minval, dtype=dtype)
      maxval2 = constant_op.constant(maxval, dtype=dtype)
      return gen_random_ops.random_uniform_int(
          shape, minval=minval2, maxval=maxval2, seed=seed1, seed2=seed2)
    def new_uniform_int(dtype, gen):
      return gen.uniform(shape, minval=minval, maxval=maxval, dtype=dtype)
    def old_uniform(dtype, seed1, seed2):
      return gen_random_ops.random_uniform(
          shape, dtype=dtype, seed=seed1, seed2=seed2)
    def new_uniform(dtype, gen):
      return gen._uniform(shape, dtype=dtype)

    for dtype in floats:
      compare(dtype, old_normal, new_normal)
      compare(dtype, old_truncated_normal, new_truncated_normal)
      compare(dtype, old_uniform, new_uniform)
    for dtype in INTS:
      compare(dtype, old_uniform_int, new_uniform_int)

  @test_util.run_v2_only
  def testSameAsOldRandomOpsCPU(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The CPU version.
    """
    self._sameAsOldRandomOps("/device:CPU:0", CPU_FLOATS)

  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testSameAsOldRandomOpsGPU(self):
    """Tests that the generated numbers are the same as the old random_ops.py.

    The GPU version.
    """
    self._sameAsOldRandomOps(test_util.gpu_device_name(), GPU_FLOATS)

  @parameterized.parameters(INTS + [dtypes.uint32, dtypes.uint64])
  @test_util.run_v2_only
  @test_util.run_cuda_only
  def testGPUEqualsCPU(self, dtype):
    """Tests that GPU and CPU generate the same integer outputs."""
    seed = 1234
    shape = [315, 49]
    with ops.device("/device:CPU:0"):
      cpu = random.Generator.from_seed(seed).uniform_full_int(
          shape=shape, dtype=dtype)
    with ops.device(test_util.gpu_device_name()):
      gpu = random.Generator.from_seed(seed).uniform_full_int(
          shape=shape, dtype=dtype)
    self.assertAllEqual(cpu, gpu)

  @parameterized.parameters(FLOATS + INTS)
  @test_util.run_v2_only
  def testUniformIsInRange(self, dtype):
    minval = 2
    maxval = 33
    size = 1000
    gen = random.Generator.from_seed(1234)
    x = gen.uniform(
        shape=[size], dtype=dtype, minval=minval, maxval=maxval).numpy()
    self.assertTrue(np.all(x >= minval))
    self.assertTrue(np.all(x < maxval))

  @parameterized.parameters(FLOATS)
  @test_util.run_v2_only
  def testNormalIsFinite(self, dtype):
    gen = random.Generator.from_seed(1234)
    x = gen.normal(shape=[10000], dtype=dtype).numpy()
    self.assertTrue(np.all(np.isfinite(x)))

  @parameterized.parameters(FLOATS + INTS)
  @test_util.run_v2_only
  def testDistributionOfUniform(self, dtype):
    """Use Pearson's Chi-squared test to test for uniformity."""
    n = 1000
    seed = 12
    gen = random.Generator.from_seed(seed)
    maxval = 1
    if dtype.is_integer:
      maxval = 100
    x = gen.uniform(shape=[n], maxval=maxval, dtype=dtype).numpy()
    if maxval > 1:
      # Normalize x to range [0, 1).
      x = x.astype(float) / maxval
    # Tests that the values are distributed amongst 10 bins with equal
    # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
    # p=0.05. This test is probabilistic and would be flaky if the random
    # seed were not fixed.
    val = random_test_util.chi_squared(x, 10)
    self.assertLess(val, 16.92)

  @parameterized.parameters(FLOATS)
  @test_util.run_v2_only
  def testDistributionOfNormal(self, dtype):
    """Use the Anderson-Darling test to check for normality."""
    n = 1000
    gen = random.Generator.from_seed(1234)
    x = gen.normal(shape=[n], dtype=dtype).numpy()
    # The constant 2.492 is the 5% critical value for the Anderson-Darling
    # test where the mean and variance are known. This test is probabilistic
    # so to avoid flakiness the seed is fixed.
    self.assertLess(
        random_test_util.anderson_darling(x.astype(float)), 2.492)

  @test_util.run_v2_only
  def testErrors(self):
    """Tests that proper errors are raised."""
    shape = [2, 3]
    gen = random.Generator.from_seed(1234)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        r"must have shape \[\], not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, [0, 0], shape)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        r"must have shape \[\], not"):
      gen_stateful_random_ops.rng_skip(
          gen.state.handle, gen.algorithm, [0, 0])
    with self.assertRaisesWithPredicateMatch(
        TypeError, "EagerTensor of dtype int64"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, 1.1, shape)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "Unsupported algorithm id"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          gen.state.handle, 123, shape)
    var = variables.Variable([0, 0], dtype=dtypes.int32)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "dtype of RNG state variable must be int64, not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    var = variables.Variable([[0]], dtype=dtypes.int64)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "RNG state must have one and only one dimension, not"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    var = variables.Variable([0], dtype=dtypes.int64)
    with self.assertRaisesWithPredicateMatch(
        errors.InvalidArgumentError,
        "For the Philox algorithm, the size of state must be at least"):
      gen_stateful_random_ops.stateful_standard_normal_v2(
          var.handle, random.RNG_ALG_PHILOX, shape)
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        "minval must be a scalar; got a tensor of shape "):
      @def_function.function
      def f():
        gen.uniform(shape=shape, minval=array_ops.zeros(shape, "int32"),
                    maxval=100, dtype="int32")
      f()
    with self.assertRaisesWithPredicateMatch(
        ValueError,
        "maxval must be a scalar; got a tensor of shape "):
      @def_function.function
      def f2():
        gen.uniform(
            shape=shape, minval=0, maxval=array_ops.ones(shape, "int32") * 100,
            dtype="int32")
      f2()

  @test_util.run_v2_only
  def testGetGlobalGeneratorWithXla(self):
    """Demonstrates using the global generator with XLA."""
    # This test was passing before because soft placement silently picked the
    # CPU kernel.
    # TODO(wangpeng): Remove this skip
    self.skipTest("NonDeterministicInts lacks XLA kernel.")

    if not config.list_physical_devices("XLA_CPU"):
      self.skipTest("No XLA_CPU device available.")

    random.set_global_generator(None)

    @def_function.function(jit_compile=True)
    def make_seed():
      generator = random.get_global_generator()
      state = array_ops.identity(generator.state, name="state")
      return generator.uniform_full_int((2,), dtypes.int32, name="seed"), state

    with ops.device("/device:XLA_CPU:0"):
      seed, state = make_seed()
      self.assertTrue(np.all(np.isfinite(seed.numpy())))
      random.get_global_generator().reset(state)
      self.assertAllEqual(make_seed()[0], seed)

  @test_util.run_v2_only
  def testSetGlobalGeneratorBadWithDefun(self):
    """Demonstrates set_global_generator does not affect compiled tf.function."""
    shape = (3,)

    @def_function.function
    def f():
      return random.get_global_generator().normal(shape)

    random.set_global_generator(random.Generator.from_seed(50))
    samples = f()
    # Resetting the global generator has no effect on the compiled tf.function.
    random.set_global_generator(random.Generator.from_seed(50))
    # New samples are returned.
    self.assertNotAllEqual(samples, f())
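    # The traced function captured the state variable of the generator that
    # was current at trace time, so installing a new global generator does not
    # rebind f; f keeps drawing from the old, already-advanced state, which is
    # why the second batch of samples differs despite the identical seed.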

  @test_util.run_v2_only
  def testFunctionArg(self):
    """Tests that RNG can be used as tf.function's argument."""
    shape = [2, 3]
    @def_function.function
    def f(gen):
      return gen.normal(shape)
    g1 = random.Generator.from_seed(1)
    g2 = random.Generator.from_seed(1)
    res1 = f(g1)
    res2 = g2.normal(shape)
    self.assertAllEqual(res1, res2)
    self.assertAllEqual(g1.state.read_value(), g2.state.read_value())

  @test_util.run_v2_only
  def testUniformFullInt(self):
    """Tests full-range int uniform."""
    shape = [3, 4]
    dtype = dtypes.int32
    g = random.Generator.from_seed(1)
    r1 = g.uniform(shape=shape, dtype=dtype, minval=None)
    g = random.Generator.from_seed(1)
    r2 = g.uniform_full_int(shape=shape, dtype=dtype)
    self.assertAllEqual(r1, r2)
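    # `uniform` with minval=None (maxval also defaulting to None) means
    # "cover the full range of the integer dtype", which is what
    # `uniform_full_int` does; hence the bit-for-bit equality.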

  @test_util.run_v2_only
  def testRestore(self):
    """Tests save and restore."""
    fname = os.path.join(self.get_temp_dir(), "checkpoint")
    g = random.Generator.from_seed(1)
    cp = tracking_util.Checkpoint(g=g)
    def write_restore_compare():
      cp.write(fname)
      r1 = g.uniform([], dtype=dtypes.uint32, minval=None)
      cp.restore(fname)
      r2 = g.uniform([], dtype=dtypes.uint32, minval=None)
      self.assertAllEqual(r1, r2)
    # Run multiple times so that cp.write is called in various RNG states
    for _ in range(2):
      write_restore_compare()

  @test_util.run_v2_only
  def testDeterministicOpsErrors(self):
    try:
      config.enable_op_determinism()
      random.set_global_generator(None)
      with self.assertRaisesWithPredicateMatch(
          RuntimeError,
          '"get_global_generator" cannot be called if determinism is enabled'):
        random.get_global_generator()
      random.set_global_generator(random.Generator.from_seed(50))
      random.get_global_generator()
      with self.assertRaisesWithPredicateMatch(
          RuntimeError,
          '"from_non_deterministic_state" cannot be called when determinism '
          "is enabled."):
        random.Generator.from_non_deterministic_state()
    finally:
      config.disable_op_determinism()


if __name__ == "__main__":
  config.set_soft_device_placement(False)
  test.main()