# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.filter()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops


class FilterTestBase(test_base.DatasetTestBase):
  """Base class for FilterDataset tests."""

  def apply_filter(self, input_dataset, predicate):
    raise NotImplementedError("FilterTestBase.apply_filter")

  def testFilterDataset(self):
    components = (
        np.arange(7, dtype=np.int64),
        np.array([[1, 2, 3]], dtype=np.int64) * np.arange(
            7, dtype=np.int64)[:, np.newaxis],
        np.array(37.0, dtype=np.float64) * np.arange(7)
    )

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    def do_test(count, modulus):  # pylint: disable=missing-docstring
      dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
          _map_fn).repeat(count)
      # pylint: disable=g-long-lambda
      dataset = self.apply_filter(
          dataset, lambda x, _y, _z: math_ops.equal(
              math_ops.mod(x, modulus), 0))
      # pylint: enable=g-long-lambda
      self.assertEqual(
          [c.shape[1:] for c in components],
          [shape for shape in dataset_ops.get_legacy_output_shapes(dataset)])
      get_next = self.getNext(dataset)
      for _ in range(count):
        for i in [x for x in range(7) if x**2 % modulus == 0]:
          result = self.evaluate(get_next())
          for component, result_component in zip(components, result):
            self.assertAllEqual(component[i]**2, result_component)
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

    do_test(14, 2)
    do_test(4, 18)

    # Test an empty dataset.
    do_test(0, 1)

  def testFilterRange(self):
    dataset = dataset_ops.Dataset.range(4)
    dataset = self.apply_filter(
        dataset, lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2))
    self.assertDatasetProduces(dataset, expected_output=[0, 1, 3])

  def testFilterDict(self):
    dataset = dataset_ops.Dataset.range(10).map(
        lambda x: {"foo": x * 2, "bar": x**2})
    dataset = self.apply_filter(
        dataset, lambda d: math_ops.equal(d["bar"] % 2, 0))
    dataset = dataset.map(lambda d: d["foo"] + d["bar"])
    self.assertDatasetProduces(
        dataset,
        expected_output=[(i * 2 + i**2) for i in range(10) if not (i**2) % 2])

  def testUseStepContainerInFilter(self):
    input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)

    # Define a predicate that returns true for the first element of
    # the sequence and not the second, and uses `tf.map_fn()`.
    def _predicate(xs):
      squared_xs = map_fn.map_fn(lambda x: x * x, xs)
      summed = math_ops.reduce_sum(squared_xs)
      return math_ops.equal(summed, 1 + 4 + 9)

    dataset = dataset_ops.Dataset.from_tensor_slices(
        [[1, 2, 3], [4, 5, 6]])
    dataset = self.apply_filter(dataset, _predicate)
    self.assertDatasetProduces(dataset, expected_output=[input_data[0]])

  def testSparse(self):
    # Filter on the dense index component and keep only the sparse component.

    def _map_fn(i):
      return sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0]]),
          values=(i * np.array([1])),
          dense_shape=np.array([1, 1])), i

    def _filter_fn(_, i):
      return math_ops.equal(i % 2, 0)

    dataset = dataset_ops.Dataset.range(10).map(_map_fn)
    dataset = self.apply_filter(dataset, _filter_fn)
    dataset = dataset.map(lambda x, i: x)
    self.assertDatasetProduces(
        dataset, expected_output=[_map_fn(i * 2)[0] for i in range(5)])

  def testShortCircuit(self):
    dataset = dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.range(10),
         dataset_ops.Dataset.from_tensors(True).repeat(None)))
    # The predicate returns one of its arguments unchanged, exercising the
    # short-circuit path for predicates that simply forward an input component.
    dataset = self.apply_filter(dataset, lambda x, y: y)
    self.assertDatasetProduces(
        dataset, expected_output=[(i, True) for i in range(10)])

  def testParallelFilters(self):
    # Create ten independent iterators over the same filtered dataset; each
    # should produce the first element that passes the filter (i.e. 0).
    dataset = dataset_ops.Dataset.range(10)
    dataset = self.apply_filter(dataset, lambda x: math_ops.equal(x % 2, 0))
    next_elements = [self.getNext(dataset) for _ in range(10)]
    self.assertEqual([0 for _ in range(10)],
                     self.evaluate(
                         [next_element() for next_element in next_elements]))
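

# A minimal sketch of how a concrete test case is expected to plug into
# `FilterTestBase`: a subclass overrides `apply_filter` to select the filtering
# API under test. The class name `FilterTest` below is illustrative (an
# assumption, not part of this base module); it exercises the core
# `tf.data.Dataset.filter()` transformation named in the module docstring.
class FilterTest(FilterTestBase):

  def apply_filter(self, input_dataset, predicate):
    # Apply the built-in `Dataset.filter()` transformation directly.
    return input_dataset.filter(predicate)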