# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tensorflow.python.data.util.random_seed.get_seed`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.util import random_seed as data_random_seed
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class RandomSeedTest(test.TestCase):
  """Checks the (graph_seed, op_seed) pairs produced by `get_seed`."""

  def _checkGetSeed(self, tinput, toutput):
    """Sets the graph seed, calls `get_seed`, and asserts on the result.

    Args:
      tinput: A `(graph_seed, op_seed)` pair. `graph_seed` is passed to
        `random_seed.set_random_seed` and `op_seed` (a Python int, a scalar
        int64 tensor, or None) is passed to `data_random_seed.get_seed`.
      toutput: The expected `(graph_seed, op_seed)` pair after evaluation.
    """
    random_seed.set_random_seed(tinput[0])
    try:
      g_seed, op_seed = data_random_seed.get_seed(tinput[1])
      g_seed = self.evaluate(g_seed)
      op_seed = self.evaluate(op_seed)
      msg = 'test_case = {0}, got {1}, want {2}'.format(
          tinput, (g_seed, op_seed), toutput)
      self.assertEqual((g_seed, op_seed), toutput, msg=msg)
    finally:
      # Always clear the seed so a failing case cannot leak seed state into
      # subsequent cases or tests.
      random_seed.set_random_seed(None)

  @test_util.run_in_graph_and_eager_modes
  def testRandomSeed(self):
    zero_t = constant_op.constant(0, dtype=dtypes.int64, name='zero')
    one_t = constant_op.constant(1, dtype=dtypes.int64, name='one')
    intmax_t = constant_op.constant(
        2**31 - 1, dtype=dtypes.int64, name='intmax')
    test_cases = [
        # Each test case is a tuple with input to get_seed:
        # (input_graph_seed, input_op_seed)
        # and output from get_seed:
        # (output_graph_seed, output_op_seed)
        ((None, None), (0, 0)),
        ((None, 1), (random_seed.DEFAULT_GRAPH_SEED, 1)),
        ((1, 1), (1, 1)),
        ((0, 0), (0, 2**31 - 1)),  # Avoid nondeterministic (0, 0) output
        ((2**31 - 1, 0), (0, 2**31 - 1)),  # Don't wrap to (0, 0) either
        ((0, 2**31 - 1), (0, 2**31 - 1)),  # Wrapping for the other argument
        # Once more, with tensor-valued arguments
        ((None, one_t), (random_seed.DEFAULT_GRAPH_SEED, 1)),
        ((1, one_t), (1, 1)),
        ((0, zero_t), (0, 2**31 - 1)),  # Avoid nondeterministic (0, 0) output
        ((2**31 - 1, zero_t), (0, 2**31 - 1)),  # Don't wrap to (0, 0) either
        ((0, intmax_t), (0, 2**31 - 1)),  # Wrapping for the other argument
    ]
    for tinput, toutput in test_cases:
      self._checkGetSeed(tinput, toutput)

    if not context.executing_eagerly():
      # With a graph seed but no op seed, the expected op seed is the default
      # graph's `_last_id`. It is read here, before `get_seed` runs, as in the
      # original test; setting the graph seed does not create ops.
      tinput = (1, None)
      toutput = (1, ops.get_default_graph()._last_id)  # pylint: disable=protected-access
      self._checkGetSeed(tinput, toutput)


if __name__ == '__main__':
  test.main()