# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.filter()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


def _test_combinations():

  def filter_fn(dataset, predicate):
    return dataset.filter(predicate)

  def legacy_filter_fn(dataset, predicate):
    return dataset.filter_with_legacy_function(predicate)

  filter_combinations = combinations.combine(
      tf_api_version=[1, 2],
      mode=["eager", "graph"],
      apply_filter=combinations.NamedObject("filter_fn", filter_fn))

  legacy_filter_combinations = combinations.combine(
      tf_api_version=1,
      mode=["eager", "graph"],
      apply_filter=combinations.NamedObject("legacy_filter_fn",
                                            legacy_filter_fn))

  return filter_combinations + legacy_filter_combinations


class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(_test_combinations())
  def testFilterDataset(self, apply_filter):
    components = (np.arange(7, dtype=np.int64),
                  np.array([[1, 2, 3]], dtype=np.int64) *
                  np.arange(7, dtype=np.int64)[:, np.newaxis],
                  np.array(37.0, dtype=np.float64) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    def do_test(count, modulus):  # pylint: disable=missing-docstring
      dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
          _map_fn).repeat(count)
      # pylint: disable=g-long-lambda
      dataset = apply_filter(
          dataset,
          lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0))
      # pylint: enable=g-long-lambda
      self.assertEqual(
          [c.shape[1:] for c in components],
          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
      get_next = self.getNext(dataset)
      for _ in range(count):
        for i in [x for x in range(7) if x**2 % modulus == 0]:
          result = self.evaluate(get_next())
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

    do_test(14, 2)
    do_test(4, 18)

    # Test an empty dataset.
    do_test(0, 1)

  @combinations.generate(_test_combinations())
  def testFilterRange(self, apply_filter):
    dataset = dataset_ops.Dataset.range(4)
    dataset = apply_filter(dataset,
                           lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
    self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])

  @combinations.generate(_test_combinations())
  def testFilterDict(self, apply_filter):
    dataset = dataset_ops.Dataset.range(10).map(
        lambda x: {"foo": x * 2, "bar": x**2})
    dataset = apply_filter(dataset, lambda d: math_ops.equal(d["bar"] % 2, 0))
    dataset = dataset.map(lambda d: d["foo"] + d["bar"])
    self.assertDatasetProduces(
        dataset,
        expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])

  @combinations.generate(_test_combinations())
  def testUseStepContainerInFilter(self, apply_filter):
    input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)

    # Define a predicate that returns true for the first element of
    # the sequence and not the second, and uses `tf.map_fn()`.
    def _predicate(xs):
      squared_xs = map_fn.map_fn(lambda x: x * x, xs)
      summed = math_ops.reduce_sum(squared_xs)
      return math_ops.equal(summed, 1 + 4 + 9)

    dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])
    dataset = apply_filter(dataset, _predicate)
    self.assertDatasetProduces(dataset, expected_output=[input_data[0]])

  @combinations.generate(_test_combinations())
  def testSparse(self, apply_filter):

    def _map_fn(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1])), i

    def _filter_fn(_, i):
      return math_ops.equal(i % 2, 0)

    dataset = dataset_ops.Dataset.range(10).map(_map_fn)
    dataset = apply_filter(dataset, _filter_fn)
    dataset = dataset.map(lambda x, i: x)
    self.assertDatasetProduces(
        dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])

  @combinations.generate(_test_combinations())
  def testShortCircuit(self, apply_filter):
    dataset = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(10),
         dataset_ops.Dataset.from_tensors(True).repeat(None)))
    dataset = apply_filter(dataset, lambda x, y: y)
    self.assertDatasetProduces(
        dataset, expected_output=[(i, True) for i in range(10)])

  @combinations.generate(_test_combinations())
  def testParallelFilters(self, apply_filter):
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0))
    next_elements = [self.getNext(dataset) for _ in range(10)]
    self.assertEqual([0 for _ in range(10)],
                     self.evaluate(
                         [next_element() for next_element in next_elements]))


class FilterCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):

  def _build_filter_range_graph(self, div):
    return dataset_ops.Dataset.range(100).filter(
        lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    div = 3
    num_outputs = sum(x % 3 != 2 for x in range(100))
    verify_fn(self, lambda: self._build_filter_range_graph(div), num_outputs)

  def _build_filter_dict_graph(self):
    return dataset_ops.Dataset.range(10).map(lambda x: {
        "foo": x * 2,
        "bar": x**2
    }).filter(lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
        lambda d: d["foo"] + d["bar"])

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testDict(self, verify_fn):
    num_outputs = sum((x**2) % 2 == 0 for x in range(10))
    verify_fn(self, self._build_filter_dict_graph, num_outputs)

  def _build_sparse_filter(self):

    def _map_fn(i):
      return sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i

    def _filter_fn(_, i):
      return math_ops.equal(i % 2, 0)

    return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
        lambda x, i: x)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testSparse(self, verify_fn):
    verify_fn(self, self._build_sparse_filter, num_outputs=5)


if __name__ == "__main__":
  test.main()