# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.filter()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops


30class FilterTestBase(test_base.DatasetTestBase):
31  """Base class for FilterDataset tests."""
32
33  def apply_filter(self, input_dataset, predicate):
34    raise NotImplementedError("FilterTestBase._apply_filter")
35
36  def testFilterDataset(self):
37    components = (
38        np.arange(7, dtype=np.int64),
39        np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
40            7, dtype=np.int64)[:, np.newaxis],
41        np.array(37.0, dtype=np.float64) * np.arange(7)
42    )
43    def _map_fn(x, y, z):
44      return math_ops.square(x), math_ops.square(y), math_ops.square(z)
45
46    def do_test(count, modulus):  # pylint: disable=missing-docstring
47      dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
48          _map_fn).repeat(count)
49      # pylint: disable=g-long-lambda
50      dataset = self.apply_filter(
51          dataset, lambda x, _y, _z: math_ops.equal(
52              math_ops.mod(x, modulus), 0))
53      # pylint: enable=g-long-lambda
54      self.assertEqual(
55          [c.shape[1:] for c in components],
56          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
57      get_next = self.getNext(dataset)
58      for _ in range(count):
59        for i in [x for x in range(7) if x**2 % modulus == 0]:
60          result = self.evaluate(get_next())
61          for component, result_component in zip(components, result):
62            self.assertAllEqual(component[i]**2, result_component)
63      with self.assertRaises(errors.OutOfRangeError):
64        self.evaluate(get_next())
65
66    do_test(14, 2)
67    do_test(4, 18)
68
69    # Test an empty dataset.
70    do_test(0, 1)
71
72  def testFilterRange(self):
73    dataset = dataset_ops.Dataset.range(4)
74    dataset = self.apply_filter(
75        dataset, lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
76    self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])
77
78  def testFilterDict(self):
79    dataset = dataset_ops.Dataset.range(10).map(
80        lambda x: {"foo": x * 2, "bar": x ** 2})
81    dataset = self.apply_filter(
82        dataset, lambda d: math_ops.equal(d["bar"] % 2, 0))
83    dataset = dataset.map(lambda d: d["foo"] + d["bar"])
84    self.assertDatasetProduces(
85        dataset,
86        expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])
87
88  def testUseStepContainerInFilter(self):
89    input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
90
91    # Define a predicate that returns true for the first element of
92    # the sequence and not the second, and uses `tf.map_fn()`.
93    def _predicate(xs):
94      squared_xs = map_fn.map_fn(lambda x: x * x, xs)
95      summed = math_ops.reduce_sum(squared_xs)
96      return math_ops.equal(summed, 1 + 4 + 9)
97
98    dataset = dataset_ops.Dataset.from_tensor_slices(
99        [[1, 2, 3], [4, 5, 6]])
100    dataset = self.apply_filter(dataset, _predicate)
101    self.assertDatasetProduces(dataset, expected_output=[input_data[0]])
102
103  def testSparse(self):
104
105    def _map_fn(i):
106      return sparse_tensor.SparseTensorValue(
107          indices=np.array([[0, 0]]),
108          values=(i * np.array([1])),
109          dense_shape=np.array([1, 1])), i
110
111    def _filter_fn(_, i):
112      return math_ops.equal(i % 2, 0)
113
114    dataset = dataset_ops.Dataset.range(10).map(_map_fn)
115    dataset = self.apply_filter(dataset, _filter_fn)
116    dataset = dataset.map(lambda x, i: x)
117    self.assertDatasetProduces(
118        dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])
119
120  def testShortCircuit(self):
121    dataset = dataset_ops.Dataset.zip(
122        (dataset_ops.Dataset.range(10),
123         dataset_ops.Dataset.from_tensors(True).repeat(None)
124        ))
125    dataset = self.apply_filter(dataset, lambda x, y: y)
126    self.assertDatasetProduces(
127        dataset, expected_output=[(i, True) for i in range(10)])
128
129  def testParallelFilters(self):
130    dataset = dataset_ops.Dataset.range(10)
131    dataset = self.apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0))
132    next_elements = [self.getNext(dataset) for _ in range(10)]
133    self.assertEqual([0 for _ in range(10)],
134                     self.evaluate(
135                         [next_element() for next_element in next_elements]))
136