• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for stateless random ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23import numpy as np
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import random_seed
27from tensorflow.python.framework import test_util
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import random_ops
30from tensorflow.python.ops import stateless_random_ops as stateless
31from tensorflow.python.platform import test
32
33
def invert_philox(key, value):
  """Invert the Philox bijection.

  Runs the ten Philox rounds backwards: given the key `key` and an output
  block `value`, returns the block that the forward rounds map to `value`.

  All arithmetic is done on Python integers with explicit 32-bit masking, so
  the result does not depend on NumPy's scalar type-promotion rules. Under
  NumPy 2, `uint32 * int` stays uint32 and would silently drop the high half
  of the 64-bit products this inversion needs. Under NumPy 1.x it overflowed
  int64 instead.

  Args:
    key: two 32-bit words, the Philox key.
    value: four 32-bit words, the Philox output block.

  Returns:
    A `np.uint32` array of length 4 holding the inverted block.
  """
  mask = 0xffffffff
  key0_base, key1_base = np.array(key, dtype=np.uint32).tolist()
  value = np.array(value, dtype=np.uint32).tolist()
  for n in reversed(range(10)):
    # Key used by forward round n: the base key bumped n times by the Weyl
    # constants.
    key0 = (key0_base + n * 0x9E3779B9) & mask
    key1 = (key1_base + n * 0xBB67AE85) & mask
    # Recover the multiplied words by multiplying the low product halves by
    # the inverses of 0xD2511F53 / 0xCD9E8D57 modulo 2**32.
    v0 = (value[3] * 0x991a7cdb) & mask
    v2 = (value[1] * 0x6d7cae67) & mask
    # Recompute the high halves of the forward 32x32 -> 64-bit products;
    # exact because Python ints do not overflow.
    hi0 = (v0 * 0xD2511F53) >> 32
    hi1 = (v2 * 0xCD9E8D57) >> 32
    v1 = hi1 ^ value[0] ^ key0
    v3 = hi0 ^ value[2] ^ key1
    value = [v0, v1, v2, v3]
  return np.array(value, dtype=np.uint32)
49
50
class StatelessOpsTest(test.TestCase):
  """Compares stateless random ops with their stateful counterparts.

  Cases cover uniform (float and int), normal, truncated normal and
  multinomial sampling. Two properties are checked:

  * match: with a pre-inverted seed, each stateless op yields the same values
    as the corresponding stateful op's first evaluation (`_test_match`);
  * determinism: stateless outputs are equal exactly when their seeds are
    equal (`_test_determinism`).
  """

  def _test_match(self, cases):
    """Asserts each stateless op reproduces its stateful twin's output.

    Args:
      cases: iterable of `(stateless_op, stateful_op)` pairs; each op is a
        callable that takes only a `seed` keyword argument.
    """
    # Stateless ops should be the same as stateful ops on the first call
    # after seed scrambling.
    cases = tuple(cases)
    key = 0x3ec8f720, 0x02461e29
    for seed in (7, 17), (11, 5), (2, 3):
      # Invert Philox on the (graph seed, 0, op seed, 0) block, then pack the
      # four 32-bit words into two 64-bit seeds: [w0 | w1 << 32, w2 | w3 << 32].
      # NOTE(review): presumably this mirrors the scrambling the stateful
      # kernels apply to their seeds -- confirm against the kernel source.
      preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
      preseed = preseed[::2] | preseed[1::2] << 32
      random_seed.set_random_seed(seed[0])
      with test_util.use_gpu():
        for stateless_op, stateful_op in cases:
          stateful = stateful_op(seed=seed[1])
          pure = stateless_op(seed=preseed)
          self.assertAllEqual(self.evaluate(stateful), self.evaluate(pure))

  def _test_determinism(self, cases):
    """Asserts stateless outputs agree exactly when their seeds agree.

    Args:
      cases: iterable of `(stateless_op, _)` pairs; only the stateless op is
        used, called with a `seed` placeholder.
    """
    # Stateless values should be equal iff the seeds are equal (roughly)
    cases = tuple(cases)
    # All 25 seeds of a 5x5 grid, each repeated three times so that repeated
    # evaluations of the same seed are compared as well. Loop-invariant, so
    # built once.
    seeds = [(x, y) for x in range(5) for y in range(5)] * 3
    with self.cached_session(use_gpu=True):
      for seed_type in [dtypes.int32, dtypes.int64]:
        seed_t = array_ops.placeholder(seed_type, shape=[2])
        for stateless_op, _ in cases:
          pure = stateless_op(seed=seed_t)
          values = [
              (seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds
          ]
          for s0, v0 in values:
            for s1, v1 in values:
              self.assertEqual(s0 == s1, np.all(v0 == v1))

  def _float_cases(self, shape_dtypes=(None,)):
    """Yields `(stateless_op, stateful_op)` pairs for float distributions.

    Covers uniform, normal and truncated normal sampling, each with default
    and explicit parameters, for float16/32/64 outputs and shapes (), (3,)
    and (2, 5).

    Args:
      shape_dtypes: dtypes in which to pass `shape` as a constant tensor;
        `None` means pass it as a plain Python tuple.

    Yields:
      Pairs of callables that take only a `seed` keyword argument.
    """
    float_cases = (
        # Uniform distribution, with and without range
        (stateless.stateless_random_uniform, random_ops.random_uniform, {}),
        (stateless.stateless_random_uniform, random_ops.random_uniform,
         dict(minval=2.2, maxval=7.1)),
        # Normal distribution, with and without mean+stddev
        (stateless.stateless_random_normal, random_ops.random_normal, {}),
        (stateless.stateless_random_normal, random_ops.random_normal,
         dict(mean=2, stddev=3)),
        # Truncated normal distribution, with and without mean+stddev
        (stateless.stateless_truncated_normal, random_ops.truncated_normal, {}),
        (stateless.stateless_truncated_normal, random_ops.truncated_normal,
         dict(mean=3, stddev=4)),
    )
    for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
      for shape_dtype in shape_dtypes:
        for dims in (), (3,), (2, 5):
          if shape_dtype is None:
            shape = dims
          else:
            shape = constant_op.constant(dims, dtype=shape_dtype)
          for stateless_op, stateful_op, extra_kwds in float_cases:
            kwds = dict(shape=shape, dtype=dtype, **extra_kwds)
            yield (functools.partial(stateless_op, **kwds),
                   functools.partial(stateful_op, **kwds))

  def _int_cases(self, shape_dtypes=(None,)):
    """Yields `(stateless_op, stateful_op)` pairs for integer uniform sampling.

    Covers int32/int64 outputs in [2, 11111) with shapes (), (3,) and (2, 5).

    Args:
      shape_dtypes: dtypes in which to pass `shape` as a constant tensor;
        `None` means pass it as a plain Python tuple.

    Yields:
      Pairs of callables that take only a `seed` keyword argument.
    """
    for shape_dtype in shape_dtypes:
      for dims in (), (3,), (2, 5):
        if shape_dtype is None:
          shape = dims
        else:
          shape = constant_op.constant(dims, dtype=shape_dtype)
        for dtype in dtypes.int32, dtypes.int64:
          kwds = dict(minval=2, maxval=11111, dtype=dtype, shape=shape)
          yield (functools.partial(stateless.stateless_random_uniform, **kwds),
                 functools.partial(random_ops.random_uniform, **kwds))

  def _multinomial_cases(self):
    """Yields `(stateless_op, stateful_op)` pairs for multinomial sampling.

    Draws 10 samples per row from one- and three-row logits, for
    float16/32/64 logits and int32/int64 outputs.

    Yields:
      Pairs of callables that take only a `seed` keyword argument.
    """
    num_samples = 10
    for logits_dtype in np.float16, np.float32, np.float64:
      for output_dtype in dtypes.int32, dtypes.int64:
        for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
                                                  [0.25, 0.75]]):
          kwds = dict(
              logits=constant_op.constant(logits, dtype=logits_dtype),
              num_samples=num_samples,
              output_dtype=output_dtype)
          yield (functools.partial(stateless.stateless_multinomial, **kwds),
                 functools.partial(random_ops.multinomial, **kwds))

  @test_util.run_deprecated_v1
  def testMatchFloat(self):
    self._test_match(self._float_cases())

  @test_util.run_deprecated_v1
  def testMatchInt(self):
    self._test_match(self._int_cases())

  @test_util.run_deprecated_v1
  def testMatchMultinomial(self):
    self._test_match(self._multinomial_cases())

  @test_util.run_deprecated_v1
  def testDeterminismFloat(self):
    self._test_determinism(
        self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))

  @test_util.run_deprecated_v1
  def testDeterminismInt(self):
    self._test_determinism(
        self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))

  @test_util.run_deprecated_v1
  def testDeterminismMultinomial(self):
    self._test_determinism(self._multinomial_cases())
158
if __name__ == '__main__':
  # Executed as a script: hand control to TensorFlow's test runner.
  test.main()
161