• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.rejection_resample()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.data.Dataset.rejection_resample()`."""

  def _drain(self, dataset):
    """Pulls elements from `dataset` until `OutOfRangeError` is raised.

    Args:
      dataset: The dataset to iterate over.

    Returns:
      A list of the evaluated elements, in the order they were produced.
    """
    get_next = self.getNext(dataset)
    elements = []
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        elements.append(self.evaluate(get_next()))
    return elements

  @staticmethod
  def _class_frequencies(classes, num_classes):
    """Returns the fraction of `classes` entries equal to each class id.

    Args:
      classes: A non-empty sequence of non-negative integer class ids.
      num_classes: Minimum length of the returned vector.

    Returns:
      A float32 vector whose i-th entry is the share of entries equal to `i`.
    """
    counts = np.bincount(np.array(classes), minlength=num_classes)
    return counts.astype(np.float32) / len(classes)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(initial_known=[True, False])))
  def testDistribution(self, initial_known):
    """Resampling a uniformly sampled 5-class stream matches `target_dist`.

    Args:
      initial_known: If True, the uniform initial distribution is passed to
        `rejection_resample()`; otherwise `initial_dist=None` is passed.
    """
    classes = np.random.randint(5, size=(10000,))  # Uniformly sampled
    target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
    initial_dist = [0.2] * 5 if initial_known else None
    classes = math_ops.cast(classes, dtypes.int64)  # needed for Windows build.
    dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
        200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()

    get_next = self.getNext(
        dataset.rejection_resample(
            target_dist=target_dist,
            initial_dist=initial_dist,
            class_func=lambda c, _: c,
            seed=27))

    # The input repeats forever, so take a fixed number of samples.
    returned = [self.evaluate(get_next()) for _ in range(2000)]

    # Each output element is (class, (class, class_as_string)).
    returned_classes, returned_classes_and_data = zip(*returned)
    _, returned_data = zip(*returned_classes_and_data)
    self.assertAllEqual([compat.as_bytes(str(c)) for c in returned_classes],
                        returned_data)
    returned_dist = self._class_frequencies(returned_classes, len(target_dist))
    self.assertAllClose(target_dist, returned_dist, atol=1e-2)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(only_initial_dist=[True, False])))
  def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
    """Resampling runs to exhaustion for two edge-case target distributions.

    Args:
      only_initial_dist: If True, the target distribution equals the initial
        one; otherwise the target puts all of its mass on class 1.
    """
    init_dist = [0.5, 0.5]
    target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
    num_classes = len(init_dist)
    # We don't need many samples to test that this works.
    num_samples = 100
    data_np = np.random.choice(num_classes, num_samples, p=init_dist)

    dataset = dataset_ops.Dataset.from_tensor_slices(data_np)

    # Reshape distribution.
    dataset = dataset.rejection_resample(
        class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist)

    # Only clean exhaustion is checked; the elements are not inspected.
    self._drain(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testRandomClasses(self):
    """A single-class target is reached when classes are remapped randomly."""
    init_dist = [0.25, 0.25, 0.25, 0.25]
    target_dist = [0.0, 0.0, 0.0, 1.0]
    num_classes = len(init_dist)
    # We don't need many samples to test a dirac-delta target distribution.
    num_samples = 100
    data_np = np.random.choice(num_classes, num_samples, p=init_dist)

    dataset = dataset_ops.Dataset.from_tensor_slices(data_np)

    # Apply a random mapping that preserves the data distribution.
    def _remap_fn(_):
      return math_ops.cast(
          random_ops.random_uniform([1]) * num_classes, dtypes.int32)[0]

    dataset = dataset.map(_remap_fn)

    # Reshape distribution.
    dataset = dataset.rejection_resample(
        class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist)

    returned = self._drain(dataset)

    classes, _ = zip(*returned)
    bincount = self._class_frequencies(classes, num_classes)

    self.assertAllClose(target_dist, bincount, atol=1e-2)

  @combinations.generate(test_base.default_test_combinations())
  def testExhaustion(self):
    """A finite range resampled by parity matches `target_dist` when drained."""
    init_dist = [0.5, 0.5]
    target_dist = [0.9, 0.1]
    dataset = dataset_ops.Dataset.range(10000)
    dataset = dataset.rejection_resample(
        class_func=lambda x: x % 2,
        target_dist=target_dist,
        initial_dist=init_dist)

    returned = self._drain(dataset)

    classes, _ = zip(*returned)
    bincount = self._class_frequencies(classes, len(init_dist))

    self.assertAllClose(target_dist, bincount, atol=1e-2)

  @parameterized.parameters(
      ("float32", "float64"),
      ("float64", "float32"),
      ("float64", "float64"),
      ("float64", None),
  )
  def testOtherDtypes(self, target_dtype, init_dtype):
    """Resampling produces an element for mixed-precision distributions.

    Args:
      target_dtype: NumPy dtype name used to build `target_dist`.
      init_dtype: NumPy dtype name used to build `initial_dist`, or None to
        pass `initial_dist=None`.
    """
    target_dist = np.array([0.5, 0.5], dtype=target_dtype)

    if init_dtype is None:
      init_dist = None
    else:
      init_dist = np.array([0.5, 0.5], dtype=init_dtype)

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.rejection_resample(
        class_func=lambda x: x % 2,
        target_dist=target_dist,
        initial_dist=init_dist)
    get_next = self.getNext(dataset)
    # Evaluating a single element without error is all this test checks.
    self.evaluate(get_next())


# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
  test.main()