# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.rejection_resample()`."""
from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat


class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):

  def _consumeUntilExhausted(self, dataset):
    """Evaluates elements of `dataset` until `OutOfRangeError` is raised.

    Args:
      dataset: The dataset to iterate over.

    Returns:
      A list of the evaluated elements, in iteration order.
    """
    next_element = self.getNext(dataset)
    elements = []
    with self.assertRaises(errors.OutOfRangeError):
      while True:
        elements.append(self.evaluate(next_element()))
    return elements

  def _observedClassDistribution(self, elements, num_classes):
    """Computes the per-class frequencies of `(class, data)` elements.

    Args:
      elements: Evaluated `(class, data)` pairs produced by a resampled
        dataset.
      num_classes: Minimum length of the returned frequency vector.

    Returns:
      A float32 array of class counts divided by the number of elements.
    """
    class_ids, _ = zip(*elements)
    counts = np.bincount(np.array(class_ids), minlength=num_classes)
    return counts.astype(np.float32) / len(class_ids)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(initial_known=[True, False])))
  def testDistribution(self, initial_known):
    # Resamples a repeated five-class stream towards a skewed target and
    # compares the class frequencies of 2000 drawn elements to that target,
    # both with and without an explicitly supplied initial distribution.
    num_to_draw = 2000
    target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
    if initial_known:
      initial_dist = [0.2] * 5
    else:
      initial_dist = None

    labels = np.random.randint(5, size=(10000,))  # Uniformly sampled
    labels = math_ops.cast(labels, dtypes.int64)  # needed for Windows build.

    dataset = dataset_ops.Dataset.from_tensor_slices(labels)
    dataset = dataset.shuffle(200, seed=21)
    dataset = dataset.map(lambda c: (c, string_ops.as_string(c)))
    dataset = dataset.repeat()

    next_element = self.getNext(
        dataset.rejection_resample(
            target_dist=target_dist,
            initial_dist=initial_dist,
            class_func=lambda c, _: c,
            seed=27))

    drawn = [self.evaluate(next_element()) for _ in range(num_to_draw)]

    drawn_classes, drawn_elements = zip(*drawn)
    _, drawn_strings = zip(*drawn_elements)
    # Every element carries its own class rendered as a string, so the class
    # reported alongside an element has to agree with that string.
    expected_strings = [compat.as_bytes(str(c)) for c in drawn_classes]
    self.assertAllEqual(expected_strings, drawn_strings)

    num_drawn = len(drawn_classes)
    class_counts = np.array(
        [sum(1 for v in drawn_classes if v == c) for c in range(5)])
    observed_dist = class_counts / num_drawn
    self.assertAllClose(target_dist, observed_dist, atol=1e-2)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         combinations.combine(only_initial_dist=[True, False])))
  def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
    # Covers two targets: one equal to the initial distribution and one that
    # puts all mass on a single class.
    init_dist = [0.5, 0.5]
    if only_initial_dist:
      target_dist = [0.5, 0.5]
    else:
      target_dist = [0.0, 1.0]
    num_classes = len(init_dist)
    # We don't need many samples to test that this works.
    num_samples = 100
    samples = np.random.choice(num_classes, num_samples, p=init_dist)

    dataset = dataset_ops.Dataset.from_tensor_slices(samples)

    # Reshape distribution.
    dataset = dataset.rejection_resample(
        class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist)

    # Only verifies that iteration ends with `OutOfRangeError`; the values of
    # the produced elements are not inspected.
    self._consumeUntilExhausted(dataset)

  @combinations.generate(test_base.default_test_combinations())
  def testRandomClasses(self):
    init_dist = [0.25, 0.25, 0.25, 0.25]
    target_dist = [0.0, 0.0, 0.0, 1.0]
    num_classes = len(init_dist)
    # We don't need many samples to test a dirac-delta target distribution.
    num_samples = 100
    samples = np.random.choice(num_classes, num_samples, p=init_dist)

    dataset = dataset_ops.Dataset.from_tensor_slices(samples)

    # Apply a random mapping that preserves the data distribution.
    def _random_class(_):
      scaled = random_ops.random_uniform([1]) * num_classes
      return math_ops.cast(scaled, dtypes.int32)[0]

    dataset = dataset.map(_random_class)

    # Reshape distribution.
    dataset = dataset.rejection_resample(
        class_func=lambda x: x, target_dist=target_dist, initial_dist=init_dist)

    drawn = self._consumeUntilExhausted(dataset)
    observed_dist = self._observedClassDistribution(drawn, num_classes)

    self.assertAllClose(target_dist, observed_dist, atol=1e-2)

  @combinations.generate(test_base.default_test_combinations())
  def testExhaustion(self):
    # Resamples a finite range by parity and checks the class frequencies of
    # everything produced before the dataset runs out.
    init_dist = [0.5, 0.5]
    target_dist = [0.9, 0.1]
    dataset = dataset_ops.Dataset.range(10000)
    dataset = dataset.rejection_resample(
        class_func=lambda x: x % 2,
        target_dist=target_dist,
        initial_dist=init_dist)

    drawn = self._consumeUntilExhausted(dataset)
    observed_dist = self._observedClassDistribution(drawn, len(init_dist))

    self.assertAllClose(target_dist, observed_dist, atol=1e-2)

  @parameterized.parameters(
      ("float32", "float64"),
      ("float64", "float32"),
      ("float64", "float64"),
      ("float64", None),
  )
  def testOtherDtypes(self, target_dtype, init_dtype):
    # Builds the distributions with the given NumPy dtypes (or no initial
    # distribution at all) and checks that one element can be evaluated.
    target_dist = np.array([0.5, 0.5], dtype=target_dtype)
    init_dist = (
        None if init_dtype is None else np.array([0.5, 0.5], dtype=init_dtype))

    dataset = dataset_ops.Dataset.range(10)
    dataset = dataset.rejection_resample(
        class_func=lambda x: x % 2,
        target_dist=target_dist,
        initial_dist=init_dist)
    next_element = self.getNext(dataset)
    self.evaluate(next_element())


if __name__ == "__main__":
  test.main()