1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for stateless random ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22 23import numpy as np 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import random_seed 27from tensorflow.python.framework import test_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import random_ops 30from tensorflow.python.ops import stateless_random_ops as stateless 31from tensorflow.python.platform import test 32 33 34def invert_philox(key, value): 35 """Invert the Philox bijection.""" 36 key = np.array(key, dtype=np.uint32) 37 value = np.array(value, dtype=np.uint32) 38 step = np.array([0x9E3779B9, 0xBB67AE85], dtype=np.uint32) 39 for n in range(10)[::-1]: 40 key0, key1 = key + n * step 41 v0 = value[3] * 0x991a7cdb & 0xffffffff 42 v2 = value[1] * 0x6d7cae67 & 0xffffffff 43 hi0 = v0 * 0xD2511F53 >> 32 44 hi1 = v2 * 0xCD9E8D57 >> 32 45 v1 = hi1 ^ value[0] ^ key0 46 v3 = hi0 ^ value[2] ^ key1 47 value = v0, v1, v2, v3 48 return np.array(value) 49 50 51class StatelessOpsTest(test.TestCase): 52 53 def _test_match(self, cases): 54 # Stateless ops should be the same as stateful ops on the first call 55 # after seed scrambling. 56 cases = tuple(cases) 57 key = 0x3ec8f720, 0x02461e29 58 for seed in (7, 17), (11, 5), (2, 3): 59 preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64) 60 preseed = preseed[::2] | preseed[1::2] << 32 61 random_seed.set_random_seed(seed[0]) 62 with test_util.use_gpu(): 63 for stateless_op, stateful_op in cases: 64 stateful = stateful_op(seed=seed[1]) 65 pure = stateless_op(seed=preseed) 66 self.assertAllEqual(self.evaluate(stateful), self.evaluate(pure)) 67 68 def _test_determinism(self, cases): 69 # Stateless values should be equal iff the seeds are equal (roughly) 70 cases = tuple(cases) 71 with self.test_session(use_gpu=True): 72 for seed_type in [dtypes.int32, dtypes.int64]: 73 seed_t = array_ops.placeholder(seed_type, shape=[2]) 74 seeds = [(x, y) for x in range(5) for y in range(5)] * 3 75 for stateless_op, _ in cases: 76 pure = stateless_op(seed=seed_t) 77 values = [ 78 (seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds 79 ] 80 for s0, v0 in values: 81 for s1, v1 in values: 82 self.assertEqual(s0 == s1, np.all(v0 == v1)) 83 84 def _float_cases(self, shape_dtypes=(None,)): 85 float_cases = ( 86 # Uniform distribution, with and without range 87 (stateless.stateless_random_uniform, random_ops.random_uniform, {}), 88 (stateless.stateless_random_uniform, random_ops.random_uniform, 89 dict(minval=2.2, maxval=7.1)), 90 # Normal distribution, with and without mean+stddev 91 (stateless.stateless_random_normal, random_ops.random_normal, {}), 92 (stateless.stateless_random_normal, random_ops.random_normal, 93 dict(mean=2, stddev=3)), 94 # Truncated normal distribution, with and without mean+stddev 95 (stateless.stateless_truncated_normal, random_ops.truncated_normal, {}), 96 (stateless.stateless_truncated_normal, random_ops.truncated_normal, 97 dict(mean=3, stddev=4)), 98 ) 99 for dtype in dtypes.float16, dtypes.float32, dtypes.float64: 100 for shape_dtype in shape_dtypes: 101 for shape in (), (3,), (2, 5): 102 if shape_dtype is not None: 103 shape = constant_op.constant(shape, dtype=shape_dtype) 104 for stateless_op, stateful_op, kwds in float_cases: 105 kwds = dict(shape=shape, dtype=dtype, **kwds) 106 yield (functools.partial(stateless_op, **kwds), 107 functools.partial(stateful_op, **kwds)) 108 109 def _int_cases(self, shape_dtypes=(None,)): 110 for shape_dtype in shape_dtypes: 111 for shape in (), (3,), (2, 5): 112 if shape_dtype is not None: 113 shape = constant_op.constant(shape, dtype=shape_dtype) 114 for dtype in dtypes.int32, dtypes.int64: 115 kwds = dict(minval=2, maxval=11111, dtype=dtype, shape=shape) 116 yield (functools.partial(stateless.stateless_random_uniform, **kwds), 117 functools.partial(random_ops.random_uniform, **kwds)) 118 119 def _multinomial_cases(self): 120 num_samples = 10 121 for logits_dtype in np.float16, np.float32, np.float64: 122 for output_dtype in dtypes.int32, dtypes.int64: 123 for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2], 124 [0.25, 0.75]]): 125 kwds = dict( 126 logits=constant_op.constant(logits, dtype=logits_dtype), 127 num_samples=num_samples, 128 output_dtype=output_dtype) 129 yield (functools.partial(stateless.stateless_multinomial, **kwds), 130 functools.partial(random_ops.multinomial, **kwds)) 131 132 @test_util.run_deprecated_v1 133 def testMatchFloat(self): 134 self._test_match(self._float_cases()) 135 136 @test_util.run_deprecated_v1 137 def testMatchInt(self): 138 self._test_match(self._int_cases()) 139 140 @test_util.run_deprecated_v1 141 def testMatchMultinomial(self): 142 self._test_match(self._multinomial_cases()) 143 144 @test_util.run_deprecated_v1 145 def testDeterminismFloat(self): 146 self._test_determinism( 147 self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64))) 148 149 @test_util.run_deprecated_v1 150 def testDeterminismInt(self): 151 self._test_determinism( 152 self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64))) 153 154 @test_util.run_deprecated_v1 155 def testDeterminismMultinomial(self): 156 self._test_determinism(self._multinomial_cases()) 157 158 159if __name__ == '__main__': 160 test.main() 161