# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.filter()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


def _test_combinations():
  """Builds the parameter combinations shared by every test in this file.

  Returns:
    The sum of two `combinations.combine(...)` results. The first applies
    `Dataset.filter` with `tf_api_version` 1 and 2; the second applies
    `Dataset.filter_with_legacy_function` with `tf_api_version` 1 only. Both
    cover the "eager" and "graph" modes, and each passes its filter wrapper
    to the test as the `apply_filter` argument.
  """

  def filter_fn(dataset, predicate):
    return dataset.filter(predicate)

  def legacy_filter_fn(dataset, predicate):
    return dataset.filter_with_legacy_function(predicate)

  filter_combinations = combinations.combine(
      tf_api_version=[1, 2],
      mode=["eager", "graph"],
      apply_filter=combinations.NamedObject("filter_fn", filter_fn))

  legacy_filter_combinations = combinations.combine(
      tf_api_version=1,
      mode=["eager", "graph"],
      apply_filter=combinations.NamedObject("legacy_filter_fn",
                                            legacy_filter_fn))

  return filter_combinations + legacy_filter_combinations


class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests filtering datasets via the `apply_filter` test parameter."""

  @combinations.generate(_test_combinations())
  def testFilterDataset(self, apply_filter):
    """Filters a squared, repeated three-component dataset by a modulus."""
    components = (np.arange(7, dtype=np.int64),
                  np.array([[1, 2, 3]], dtype=np.int64) *
                  np.arange(7, dtype=np.int64)[:, np.newaxis],
                  np.array(37.0, dtype=np.float64) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    def do_test(count, modulus):
      """Checks shapes and elements for `count` repeats and `modulus`."""
      dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
          _map_fn).repeat(count)
      # pylint: disable=g-long-lambda
      dataset = apply_filter(
          dataset,
          lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0))
      # pylint: enable=g-long-lambda
      self.assertEqual(
          [c.shape[1:] for c in components],
          list(dataset_ops.get_legacy_output_shapes(dataset)))
      get_next = self.getNext(dataset)
      # Indices whose squared first component satisfies the predicate. The
      # list does not depend on the repeat iteration, so build it once.
      kept_indices = [i for i in range(7) if i**2 % modulus == 0]
      for _ in range(count):
        for i in kept_indices:
          result = self.evaluate(get_next())
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

    do_test(14, 2)
    do_test(4, 18)

    # Test an empty dataset.
    do_test(0, 1)

  @combinations.generate(_test_combinations())
  def testFilterRange(self, apply_filter):
    """Drops the elements of `range(4)` that are congruent to 2 modulo 3."""
    dataset = dataset_ops.Dataset.range(4)
    dataset = apply_filter(dataset,
                           lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
    self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])

  @combinations.generate(_test_combinations())
  def testFilterDict(self, apply_filter):
    """Filters dictionary-structured elements on one of their entries."""
    dataset = dataset_ops.Dataset.range(10).map(
        lambda x: {"foo": x * 2, "bar": x**2})
    dataset = apply_filter(dataset, lambda d: math_ops.equal(d["bar"] % 2, 0))
    dataset = dataset.map(lambda d: d["foo"] + d["bar"])
    self.assertDatasetProduces(
        dataset,
        expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])

  @combinations.generate(_test_combinations())
  def testUseStepContainerInFilter(self, apply_filter):
    """Filters with a predicate that is built on `map_fn.map_fn`."""
    input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)

    # Define a predicate that returns true for the first element of
    # the sequence and not the second, and uses `tf.map_fn()`.
    def _predicate(xs):
      squared_xs = map_fn.map_fn(lambda x: x * x, xs)
      summed = math_ops.reduce_sum(squared_xs)
      return math_ops.equal(summed, 1 + 4 + 9)

    # Feed the same array the expectation is read from, rather than a
    # duplicated literal, so the input and expected output cannot drift apart.
    dataset = dataset_ops.Dataset.from_tensor_slices(input_data)
    dataset = apply_filter(dataset, _predicate)
    self.assertDatasetProduces(dataset, expected_output=[input_data[0]])

  @combinations.generate(_test_combinations())
  def testSparse(self, apply_filter):
    """Filters (sparse value, index) pairs on the index component."""

    def _map_fn(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1])), i

    def _filter_fn(_, i):
      return math_ops.equal(i % 2, 0)

    dataset = dataset_ops.Dataset.range(10).map(_map_fn)
    dataset = apply_filter(dataset, _filter_fn)
    # Keep only the sparse component; the index was needed just for filtering.
    dataset = dataset.map(lambda x, _: x)
    self.assertDatasetProduces(
        dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])

  @combinations.generate(_test_combinations())
  def testShortCircuit(self, apply_filter):
    """Filters with a predicate that returns an input component unchanged."""
    dataset = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(10),
         dataset_ops.Dataset.from_tensors(True).repeat(None)))
    dataset = apply_filter(dataset, lambda x, y: y)
    self.assertDatasetProduces(
        dataset, expected_output=[(i, True) for i in range(10)])

  @combinations.generate(_test_combinations())
  def testParallelFilters(self, apply_filter):
    """Evaluates ten `getNext` callables over one filtered dataset together."""
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0))
    next_elements = [self.getNext(dataset) for _ in range(10)]
    # Each callable is expected to yield the first surviving element, 0.
    self.assertEqual([0] * 10,
                     self.evaluate(
                         [next_element() for next_element in next_elements]))


if __name__ == "__main__":
  test.main()