1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.data.Dataset.rejection_resample()`.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from absl.testing import parameterized 21import numpy as np 22 23from tensorflow.python.data.kernel_tests import test_base 24from tensorflow.python.data.ops import dataset_ops 25from tensorflow.python.framework import combinations 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import errors 28from tensorflow.python.ops import math_ops 29from tensorflow.python.ops import random_ops 30from tensorflow.python.ops import string_ops 31from tensorflow.python.platform import test 32from tensorflow.python.util import compat 33 34 35class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase): 36 37 @combinations.generate( 38 combinations.times(test_base.default_test_combinations(), 39 combinations.combine(initial_known=[True, False]))) 40 def testDistribution(self, initial_known): 41 classes = np.random.randint(5, size=(10000,)) # Uniformly sampled 42 target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] 43 initial_dist = [0.2] * 5 if initial_known else None 44 classes = math_ops.cast(classes, dtypes.int64) # needed for Windows build. 45 dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle( 46 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat() 47 48 get_next = self.getNext( 49 dataset.rejection_resample( 50 target_dist=target_dist, 51 initial_dist=initial_dist, 52 class_func=lambda c, _: c, 53 seed=27)) 54 55 returned = [] 56 while len(returned) < 2000: 57 returned.append(self.evaluate(get_next())) 58 59 returned_classes, returned_classes_and_data = zip(*returned) 60 _, returned_data = zip(*returned_classes_and_data) 61 self.assertAllEqual([compat.as_bytes(str(c)) for c in returned_classes], 62 returned_data) 63 total_returned = len(returned_classes) 64 class_counts = np.array( 65 [len([True for v in returned_classes if v == c]) for c in range(5)]) 66 returned_dist = class_counts / total_returned 67 self.assertAllClose(target_dist, returned_dist, atol=1e-2) 68 69 @combinations.generate( 70 combinations.times(test_base.default_test_combinations(), 71 combinations.combine(only_initial_dist=[True, False]))) 72 def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist): 73 init_dist = [0.5, 0.5] 74 target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0] 75 num_classes = len(init_dist) 76 # We don't need many samples to test that this works. 77 num_samples = 100 78 data_np = np.random.choice(num_classes, num_samples, p=init_dist) 79 80 dataset = dataset_ops.Dataset.from_tensor_slices(data_np) 81 82 # Reshape distribution. 83 dataset = dataset.rejection_resample( 84 class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist) 85 86 get_next = self.getNext(dataset) 87 88 returned = [] 89 with self.assertRaises(errors.OutOfRangeError): 90 while True: 91 returned.append(self.evaluate(get_next())) 92 93 @combinations.generate(test_base.default_test_combinations()) 94 def testRandomClasses(self): 95 init_dist = [0.25, 0.25, 0.25, 0.25] 96 target_dist = [0.0, 0.0, 0.0, 1.0] 97 num_classes = len(init_dist) 98 # We don't need many samples to test a dirac-delta target distribution. 99 num_samples = 100 100 data_np = np.random.choice(num_classes, num_samples, p=init_dist) 101 102 dataset = dataset_ops.Dataset.from_tensor_slices(data_np) 103 104 # Apply a random mapping that preserves the data distribution. 105 def _remap_fn(_): 106 return math_ops.cast( 107 random_ops.random_uniform([1]) * num_classes, dtypes.int32)[0] 108 109 dataset = dataset.map(_remap_fn) 110 111 # Reshape distribution. 112 dataset = dataset.rejection_resample( 113 class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist) 114 115 get_next = self.getNext(dataset) 116 117 returned = [] 118 with self.assertRaises(errors.OutOfRangeError): 119 while True: 120 returned.append(self.evaluate(get_next())) 121 122 classes, _ = zip(*returned) 123 bincount = np.bincount( 124 np.array(classes), minlength=num_classes).astype( 125 np.float32) / len(classes) 126 127 self.assertAllClose(target_dist, bincount, atol=1e-2) 128 129 @combinations.generate(test_base.default_test_combinations()) 130 def testExhaustion(self): 131 init_dist = [0.5, 0.5] 132 target_dist = [0.9, 0.1] 133 dataset = dataset_ops.Dataset.range(10000) 134 dataset = dataset.rejection_resample( 135 class_func=lambda x: x % 2, 136 target_dist=target_dist, 137 initial_dist=init_dist) 138 139 get_next = self.getNext(dataset) 140 returned = [] 141 with self.assertRaises(errors.OutOfRangeError): 142 while True: 143 returned.append(self.evaluate(get_next())) 144 145 classes, _ = zip(*returned) 146 bincount = np.bincount( 147 np.array(classes), minlength=len(init_dist)).astype( 148 np.float32) / len(classes) 149 150 self.assertAllClose(target_dist, bincount, atol=1e-2) 151 152 @parameterized.parameters( 153 ("float32", "float64"), 154 ("float64", "float32"), 155 ("float64", "float64"), 156 ("float64", None), 157 ) 158 def testOtherDtypes(self, target_dtype, init_dtype): 159 target_dist = np.array([0.5, 0.5], dtype=target_dtype) 160 161 if init_dtype is None: 162 init_dist = None 163 else: 164 init_dist = np.array([0.5, 0.5], dtype=init_dtype) 165 166 dataset = dataset_ops.Dataset.range(10) 167 dataset = dataset.rejection_resample( 168 class_func=lambda x: x % 2, 169 target_dist=target_dist, 170 initial_dist=init_dist) 171 get_next = self.getNext(dataset) 172 self.evaluate(get_next()) 173 174 175if __name__ == "__main__": 176 test.main() 177