• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Dataset.filter()`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.python.data.kernel_tests import checkpoint_test_base
24from tensorflow.python.data.kernel_tests import test_base
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.framework import combinations
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.ops import map_fn
30from tensorflow.python.ops import math_ops
31from tensorflow.python.platform import test
32
33
def _test_combinations():
  """Returns the parameter combinations used by the `FilterTest` cases.

  Each combination carries an `apply_filter` callable taking
  `(dataset, predicate)`:

  * `filter_fn` calls `dataset.filter()` and is combined with
    `tf_api_version` 1 and 2.
  * `legacy_filter_fn` calls `dataset.filter_with_legacy_function()` and is
    combined with `tf_api_version` 1 only.

  Both are crossed with the "eager" and "graph" modes.
  """

  def _apply_filter(ds, pred):
    return ds.filter(pred)

  def _apply_legacy_filter(ds, pred):
    return ds.filter_with_legacy_function(pred)

  # The string labels are what `NamedObject` is constructed with; keep them
  # stable even though the local helper names differ.
  named_filter = combinations.NamedObject("filter_fn", _apply_filter)
  named_legacy_filter = combinations.NamedObject("legacy_filter_fn",
                                                 _apply_legacy_filter)

  current = combinations.combine(
      tf_api_version=[1, 2],
      mode=["eager", "graph"],
      apply_filter=named_filter)
  legacy = combinations.combine(
      tf_api_version=1,
      mode=["eager", "graph"],
      apply_filter=named_legacy_filter)

  return current + legacy
54
55
class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for filtering datasets with a predicate.

  Every test receives `apply_filter`, a callable `(dataset, predicate)`
  supplied by `_test_combinations()`, so the same test body is run against
  each filtering entry point produced there.
  """

  @combinations.generate(_test_combinations())
  def testFilterDataset(self, apply_filter):
    """Filters a squared, repeated tuple dataset by divisibility."""
    components = (np.arange(7, dtype=np.int64),
                  np.array([[1, 2, 3]], dtype=np.int64) *
                  np.arange(7, dtype=np.int64)[:, np.newaxis],
                  np.array(37.0, dtype=np.float64) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    def do_test(count, modulus):
      """Checks `count` repetitions filtered on `x % modulus == 0`.

      Args:
        count: Number of times the squared dataset is repeated.
        modulus: Elements are kept when their (already squared) first
          component is divisible by this value.
      """

      def _keep(x, _y, _z):
        # `x` is the first component after `_map_fn` has squared it.
        return math_ops.equal(math_ops.mod(x, modulus), 0)

      dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
          _map_fn).repeat(count)
      dataset = apply_filter(dataset, _keep)

      # Filtering must not change the per-element shapes.
      self.assertEqual(
          [c.shape[1:] for c in components],
          list(dataset_ops.get_legacy_output_shapes(dataset)))

      # Row indices whose squared first component passes the predicate.
      kept_indices = [i for i in range(7) if i**2 % modulus == 0]
      get_next = self.getNext(dataset)
      for _ in range(count):
        for i in kept_indices:
          result = self.evaluate(get_next())
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

    do_test(14, 2)
    do_test(4, 18)

    # Test an empty dataset.
    do_test(0, 1)

  @combinations.generate(_test_combinations())
  def testFilterRange(self, apply_filter):
    """Drops the single value of `range(4)` whose remainder mod 3 is 2."""
    dataset = dataset_ops.Dataset.range(4)
    dataset = apply_filter(dataset,
                           lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
    self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])

  @combinations.generate(_test_combinations())
  def testFilterDict(self, apply_filter):
    """Filters dict-structured elements on one of their entries."""
    dataset = dataset_ops.Dataset.range(10).map(
        lambda x: {"foo": x * 2, "bar": x**2})
    dataset = apply_filter(dataset, lambda d: math_ops.equal(d["bar"] % 2, 0))
    dataset = dataset.map(lambda d: d["foo"] + d["bar"])
    self.assertDatasetProduces(
        dataset,
        expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])

  @combinations.generate(_test_combinations())
  def testUseStepContainerInFilter(self, apply_filter):
    """Runs a predicate that uses `tf.map_fn()` inside the filter."""
    input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)

    # Define a predicate that returns true for the first element of
    # the sequence and not the second, and uses `tf.map_fn()`.
    def _predicate(xs):
      squared_xs = map_fn.map_fn(lambda x: x * x, xs)
      summed = math_ops.reduce_sum(squared_xs)
      return math_ops.equal(summed, 1 + 4 + 9)

    # Build the dataset from `input_data` itself so that the input and the
    # expected output below cannot drift apart.
    dataset = dataset_ops.Dataset.from_tensor_slices(input_data)
    dataset = apply_filter(dataset, _predicate)
    self.assertDatasetProduces(dataset, expected_output=[input_data[0]])

  @combinations.generate(_test_combinations())
  def testSparse(self, apply_filter):
    """Filters elements that pair a sparse tensor with a dense scalar."""

    def _map_fn(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1])), i

    def _filter_fn(_, i):
      return math_ops.equal(i % 2, 0)

    dataset = dataset_ops.Dataset.range(10).map(_map_fn)
    dataset = apply_filter(dataset, _filter_fn)
    # Keep only the sparse component for comparison.
    dataset = dataset.map(lambda x, i: x)
    self.assertDatasetProduces(
        dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])

  @combinations.generate(_test_combinations())
  def testShortCircuit(self, apply_filter):
    """Uses a predicate that returns one of its inputs unchanged."""
    dataset = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(10),
         dataset_ops.Dataset.from_tensors(True).repeat(None)))
    dataset = apply_filter(dataset, lambda x, y: y)
    self.assertDatasetProduces(
        dataset, expected_output=[(i, True) for i in range(10)])

  @combinations.generate(_test_combinations())
  def testParallelFilters(self, apply_filter):
    """Checks that ten iterators over one filtered dataset each yield 0."""
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0))
    next_elements = [self.getNext(dataset) for _ in range(10)]
    self.assertEqual([0] * 10,
                     self.evaluate(
                         [next_element() for next_element in next_elements]))
161
162
class FilterCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
  """Checkpoint tests for datasets that contain a `filter()` step.

  Each test passes a dataset-building callable and the expected number of
  outputs to the `verify_fn` supplied by the checkpoint test combinations.
  """

  def _build_filter_range_graph(self, div):
    """Returns `range(100)` without the values whose remainder mod `div` is 2."""
    return dataset_ops.Dataset.range(100).filter(
        lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    """Verifies the filtered-range dataset."""
    div = 3
    # Derive the expected count from `div` so it always matches the predicate
    # used by `_build_filter_range_graph`.
    num_outputs = sum(x % div != 2 for x in range(100))
    verify_fn(self, lambda: self._build_filter_range_graph(div), num_outputs)

  def _build_filter_dict_graph(self):
    """Returns a map -> filter -> map pipeline over dict-structured elements."""
    return dataset_ops.Dataset.range(10).map(lambda x: {
        "foo": x * 2,
        "bar": x**2
    }).filter(lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
        lambda d: d["foo"] + d["bar"])

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testDict(self, verify_fn):
    """Verifies the dict-structured filter pipeline."""
    num_outputs = sum((x**2) % 2 == 0 for x in range(10))
    verify_fn(self, self._build_filter_dict_graph, num_outputs)

  def _build_sparse_filter(self):
    """Returns a pipeline that filters (sparse tensor, scalar) elements."""

    def _map_fn(i):
      return sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i

    def _filter_fn(_, i):
      return math_ops.equal(i % 2, 0)

    return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
        lambda x, i: x)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def testSparse(self, verify_fn):
    """Verifies the sparse filter pipeline."""
    # `_filter_fn` keeps the even values of `range(10)`, i.e. 5 elements.
    verify_fn(self, self._build_sparse_filter, num_outputs=5)
209
210
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()
213