# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.filter()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


def _test_combinations():
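  """Returns test combinations for `filter()` and its legacy counterpart."""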

  def filter_fn(dataset, predicate):
    return dataset.filter(predicate)

  def legacy_filter_fn(dataset, predicate):
    return dataset.filter_with_legacy_function(predicate)

  filter_combinations = combinations.combine(
      tf_api_version=[1, 2],
      mode=["eager", "graph"],
      apply_filter=combinations.NamedObject("filter_fn", filter_fn))

  legacy_filter_combinations = combinations.combine(
      tf_api_version=1,
      mode=["eager", "graph"],
      apply_filter=combinations.NamedObject("legacy_filter_fn",
                                            legacy_filter_fn))

  return filter_combinations + legacy_filter_combinations


class FilterTest(test_base.DatasetTestBase, parameterized.TestCase):

  @combinations.generate(_test_combinations())
  def testFilterDataset(self, apply_filter):
    components = (np.arange(7, dtype=np.int64),
                  np.array([[1, 2, 3]], dtype=np.int64) *
                  np.arange(7, dtype=np.int64)[:, np.newaxis],
                  np.array(37.0, dtype=np.float64) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

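    # Squares every component, repeats the dataset `count` times, and keeps
    # only elements whose (already squared) first component is divisible by
    # `modulus`.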
    def do_test(count, modulus):  # pylint: disable=missing-docstring
      dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
          _map_fn).repeat(count)
      # pylint: disable=g-long-lambda
      dataset = apply_filter(
          dataset,
          lambda x, _y, _z: math_ops.equal(math_ops.mod(x, modulus), 0))
      # pylint: enable=g-long-lambda
      self.assertEqual(
          [c.shape[1:] for c in components],
          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
      get_next = self.getNext(dataset)
      for _ in range(count):
        for i in [x for x in range(7) if x**2 % modulus == 0]:
          result = self.evaluate(get_next())
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

    do_test(14, 2)
    do_test(4, 18)

    # Test an empty dataset.
    do_test(0, 1)

  @combinations.generate(_test_combinations())
  def testFilterRange(self, apply_filter):
    dataset = dataset_ops.Dataset.range(4)
    dataset = apply_filter(dataset,
                           lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
    self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])

  @combinations.generate(_test_combinations())
  def testFilterDict(self, apply_filter):
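    # Elements are dicts; keep only those whose "bar" value (x**2) is even,
    # then map each kept element to foo + bar.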
    dataset = dataset_ops.Dataset.range(10).map(
        lambda x: {"foo": x * 2, "bar": x**2})
    dataset = apply_filter(dataset, lambda d: math_ops.equal(d["bar"] % 2, 0))
    dataset = dataset.map(lambda d: d["foo"] + d["bar"])
    self.assertDatasetProduces(
        dataset,
        expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])

  @combinations.generate(_test_combinations())
  def testUseStepContainerInFilter(self, apply_filter):
    input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)

    # Define a predicate that returns true for the first element of
    # the sequence and not the second, and uses `tf.map_fn()`.
    def _predicate(xs):
      squared_xs = map_fn.map_fn(lambda x: x * x, xs)
      summed = math_ops.reduce_sum(squared_xs)
      return math_ops.equal(summed, 1 + 4 + 9)

    dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]])
    dataset = apply_filter(dataset, _predicate)
    self.assertDatasetProduces(dataset, expected_output=[input_data[0]])

  @combinations.generate(_test_combinations())
  def testSparse(self, apply_filter):
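    # Elements are (SparseTensorValue, index) pairs; keep those with an even
    # index and check that the sparse component passes through unchanged.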

    def _map_fn(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1])), i

    def _filter_fn(_, i):
      return math_ops.equal(i % 2, 0)

    dataset = dataset_ops.Dataset.range(10).map(_map_fn)
    dataset = apply_filter(dataset, _filter_fn)
    dataset = dataset.map(lambda x, i: x)
    self.assertDatasetProduces(
        dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])

  @combinations.generate(_test_combinations())
  def testShortCircuit(self, apply_filter):
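    # The predicate returns its second (boolean) argument unchanged, which
    # exercises the short-circuit path for predicates that simply forward an
    # input tensor.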
    dataset = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(10),
         dataset_ops.Dataset.from_tensors(True).repeat(None)))
    dataset = apply_filter(dataset, lambda x, y: y)
    self.assertDatasetProduces(
        dataset, expected_output=[(i, True) for i in range(10)])

  @combinations.generate(_test_combinations())
  def testParallelFilters(self, apply_filter):
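    # Create several independent iterators over the same filtered dataset and
    # evaluate their first elements together; each should produce 0.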
    dataset = dataset_ops.Dataset.range(10)
    dataset = apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0))
    next_elements = [self.getNext(dataset) for _ in range(10)]
    self.assertEqual([0 for _ in range(10)],
                     self.evaluate(
                         [next_element() for next_element in next_elements]))


if __name__ == "__main__":
  test.main()