• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Dataset.rejection_resample()`."""
16from absl.testing import parameterized
17import numpy as np
18
19from tensorflow.python.data.kernel_tests import test_base
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.framework import combinations
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import errors
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import random_ops
26from tensorflow.python.ops import string_ops
27from tensorflow.python.platform import test
28from tensorflow.python.util import compat
29
30
class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.data.Dataset.rejection_resample()`.

  The tests below consume elements of the resampled dataset as
  `(class_value, original_element)` pairs, where `class_value` is what
  `class_func` computed for `original_element`.
  """

  @staticmethod
  def _classDistribution(classes, num_classes):
    """Returns the empirical frequency of each class id in `classes`.

    Args:
      classes: A non-empty sequence of non-negative integer class ids.
      num_classes: Minimum length of the returned vector. An id greater than
        or equal to `num_classes` lengthens the result, which makes the
        callers' `assertAllClose` shape check fail loudly.

    Returns:
      A numpy vector whose i-th entry is the fraction of `classes` equal to i.
    """
    counts = np.bincount(np.asarray(classes), minlength=num_classes)
    return counts / len(classes)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(initial_known=[True, False])))
  def testDistribution(self, initial_known):
    """Resamples uniform classes towards a skewed target distribution.

    The initial distribution is either supplied explicitly or passed as None,
    depending on `initial_known`.
    """
    num_classes = 5
    classes = np.random.randint(num_classes, size=(10000,))  # Uniformly sampled
    target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
    initial_dist = [0.2] * num_classes if initial_known else None
    classes = math_ops.cast(classes, dtypes.int64)  # needed for Windows build.
    dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
        200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()

    get_next = self.getNext(
        dataset.rejection_resample(
            target_dist=target_dist,
            initial_dist=initial_dist,
            class_func=lambda c, _: c,
            seed=27))

    # The input repeats forever, so take a fixed-size sample.
    returned = [self.evaluate(get_next()) for _ in range(2000)]

    returned_classes, returned_classes_and_data = zip(*returned)
    _, returned_data = zip(*returned_classes_and_data)
    # Each reported class must belong to the element it was emitted with.
    self.assertAllEqual([compat.as_bytes(str(c)) for c in returned_classes],
                        returned_data)
    returned_dist = self._classDistribution(returned_classes, num_classes)
    self.assertAllClose(target_dist, returned_dist, atol=1e-2)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(only_initial_dist=[True, False])))
  def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
    """Covers two edge cases of the target distribution.

    In one case the target equals the initial distribution. In the other, the
    target puts all of its mass on a single class.
    """
    init_dist = [0.5, 0.5]
    target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
    num_classes = len(init_dist)
    # We don't need many samples to test that this works.
    num_samples = 100
    data_np = np.random.choice(num_classes, num_samples, p=init_dist)

    dataset = dataset_ops.Dataset.from_tensor_slices(data_np)

    # Reshape distribution.
    dataset = dataset.rejection_resample(
        class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist)

    get_next = self.getNext(dataset)

    returned = []
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        returned.append(self.evaluate(get_next()))

    # `class_func` is the identity, so every class equals its element.
    classes, data = zip(*returned)
    self.assertAllEqual(classes, data)
    if only_initial_dist:
      # Target == initial distribution: nothing should be rejected.
      self.assertLen(returned, num_samples)
    else:
      # Class 0 has zero target probability and must never be emitted.
      self.assertNotIn(0, classes)

  @combinations.generate(test_base.default_test_combinations())
  def testRandomClasses(self):
    """Resamples randomly re-labelled elements to a dirac-delta target."""
    init_dist = [0.25, 0.25, 0.25, 0.25]
    target_dist = [0.0, 0.0, 0.0, 1.0]
    num_classes = len(init_dist)
    # We don't need many samples to test a dirac-delta target distribution.
    num_samples = 100
    data_np = np.random.choice(num_classes, num_samples, p=init_dist)

    dataset = dataset_ops.Dataset.from_tensor_slices(data_np)

    # Apply a random mapping that preserves the data distribution.
    def _remap_fn(_):
      return math_ops.cast(
          random_ops.random_uniform([1]) * num_classes, dtypes.int32)[0]

    dataset = dataset.map(_remap_fn)

    # Reshape distribution.
    dataset = dataset.rejection_resample(
        class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist)

    get_next = self.getNext(dataset)

    returned = []
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        returned.append(self.evaluate(get_next()))

    classes, _ = zip(*returned)
    self.assertAllClose(
        target_dist, self._classDistribution(classes, num_classes), atol=1e-2)

  @combinations.generate(test_base.default_test_combinations())
  def testExhaustion(self):
    """Drains a finite dataset and checks the resampled class frequencies."""
    init_dist = [0.5, 0.5]
    target_dist = [0.9, 0.1]
    dataset = dataset_ops.Dataset.range(10000)
    dataset = dataset.rejection_resample(
        class_func=lambda x: x % 2,
        target_dist=target_dist,
        initial_dist=init_dist)

    get_next = self.getNext(dataset)
    returned = []
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        returned.append(self.evaluate(get_next()))

    classes, _ = zip(*returned)
    self.assertAllClose(
        target_dist,
        self._classDistribution(classes, len(init_dist)),
        atol=1e-2)

  # Run under the same mode/API-version matrix as the other tests, keeping
  # exactly the original four (target_dtype, init_dtype) pairs.
  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(target_dtype="float32", init_dtype="float64") +
          combinations.combine(target_dtype="float64", init_dtype="float32") +
          combinations.combine(target_dtype="float64", init_dtype="float64") +
          combinations.combine(target_dtype="float64", init_dtype=None)))
  def testOtherDtypes(self, target_dtype, init_dtype):
    """Checks that non-float32 target/initial distributions are accepted."""
    target_dist = np.array([0.5, 0.5], dtype=target_dtype)

    if init_dtype is None:
      init_dist = None
    else:
      init_dist = np.array([0.5, 0.5], dtype=init_dtype)

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.rejection_resample(
        class_func=lambda x: x % 2,
        target_dist=target_dist,
        initial_dist=init_dist)
    get_next = self.getNext(dataset)
    class_value, element = self.evaluate(get_next())
    self.assertEqual(element % 2, class_value)
169
170
if __name__ == "__main__":
  # Executed directly (not imported): hand control to the TF test runner,
  # which discovers and runs every test case defined in this module.
  test.main()
173