# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset.window()`."""
from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.kernel_tests import checkpoint_test_base
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Functional tests for `Dataset.window()` over dense and sparse inputs."""

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              count=20,
              size=[10, 14, 17],
              shift=[7, 14],
              stride=[1, 2, 6],
              drop_remainder=[True, False]) + combinations.combine(
                  count=[0, 1],
                  size=10,
                  shift=4,
                  stride=1,
                  drop_remainder=[True, False])))
  def testWindowDataset(self, count, size, shift, stride, drop_remainder):
    """Tests a dataset that slides a window over its input elements.

    Args:
      count: Number of times the 7-slice input is repeated.
      size: Window size passed to `window()` and used as the batch size.
      shift: Shift between the starts of consecutive windows.
      stride: Stride between input elements inside one window.
      drop_remainder: Whether windows with fewer than `size` elements are
        dropped.
    """
    # Every component has 7 slices along axis 0, so after `repeat(count)` the
    # input stream holds `count * 7` elements, and the element at stream
    # position p equals `component[p % 7] ** 2` (the map squares each slice).
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))

    def _map_fn(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    def _flat_map_fn(x, y, z):
      # Collapse each window (one nested dataset per component) into a batch.
      return dataset_ops.Dataset.zip((x.batch(batch_size=size),
                                      y.batch(batch_size=size),
                                      z.batch(batch_size=size)))

    dataset = dataset_ops.Dataset.from_tensor_slices(components).map(
        _map_fn).repeat(count).window(
            size=size,
            shift=shift,
            stride=stride,
            drop_remainder=drop_remainder).flat_map(_flat_map_fn)
    get_next = self.getNext(dataset)

    # The batch dimension is expected to be unknown; the rest match the input.
    self.assertEqual([[None] + list(c.shape[1:]) for c in components],
                     [ts.as_list() for ts in nest.flatten(
                         dataset_ops.get_legacy_output_shapes(dataset))])

    num_inputs = count * 7
    # A full window spans `(size - 1) * stride + 1` consecutive input elements.
    num_full_batches = max(
        0, (num_inputs - ((size - 1) * stride + 1)) // shift + 1)
    for i in range(num_full_batches):
      result = self.evaluate(get_next())
      window_start = i * shift
      for component, result_component in zip(components, result):
        for j in range(size):
          self.assertAllEqual(component[(window_start + j * stride) % 7]**2,
                              result_component[j])
    if not drop_remainder:
      # Ceil-divide the input length by `shift` to count every window start,
      # then discount the full windows already consumed above.
      num_partial_batches = num_inputs // shift + (
          num_inputs % shift > 0) - num_full_batches
      for i in range(num_partial_batches):
        result = self.evaluate(get_next())
        window_start = (num_full_batches + i) * shift
        remaining = num_inputs - window_start
        # Ceil-divide: number of strided elements that still fit in the input.
        num_elements = remaining // stride + ((remaining % stride) > 0)
        for component, result_component in zip(components, result):
          for j in range(num_elements):
            self.assertAllEqual(
                component[(window_start + j * stride) % 7]**2,
                result_component[j])
    # The iterator is expected to keep raising once it is exhausted.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(count=20, size=0, shift=3, stride=1) +
          combinations.combine(count=20, size=3, shift=0, stride=1) +
          combinations.combine(count=20, size=3, shift=3, stride=0)))
  def testWindowDatasetInvalid(self, count, size, shift, stride):
    """Expects `InvalidArgumentError` when size, shift or stride is zero."""
    with self.assertRaises(errors.InvalidArgumentError):
      ds = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(count).window(
          size=size, shift=shift,
          stride=stride).flat_map(lambda x: x.batch(batch_size=size))
      self.evaluate(ds._variant_tensor)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowDifferentNestedStructures(self):
    """Builds windows over tuple- and dict-structured elements."""
    ds = dataset_ops.Dataset.from_tensor_slices(([1, 2], [3, 4])).window(2)
    self.getNext(ds)
    ds = dataset_ops.Dataset.from_tensor_slices({"a": [1, 2]}).window(2)
    self.getNext(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparse(self):
    """Windows and batches sparse elements that share dense shape `[1]`."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
        size=5, shift=3,
        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5))

    # Full windows of 5 over 10 elements with shift 3.
    num_batches = (10 - 5) // 3 + 1
    expected_output = [
        sparse_tensor.SparseTensorValue(
            indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
            values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
            dense_shape=[5, 1]) for i in range(num_batches)
    ]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparseWithDifferentDenseShapes(self):
    """Windows and batches sparse elements whose dense shape is `[i]`."""

    def _sparse(i):
      # Element i holds i entries, each with value i, at indices 0..i-1.
      return sparse_tensor.SparseTensorValue(
          indices=array_ops.expand_dims(
              math_ops.range(i, dtype=dtypes.int64), 1),
          values=array_ops.fill([math_ops.cast(i, dtypes.int32)], i),
          dense_shape=[i])

    dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
        size=5, shift=3,
        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=5))

    expected_output = []
    num_batches = (10 - 5) // 3 + 1
    for i in range(num_batches):
      expected_indices = []
      expected_values = []
      for j in range(5):
        # Row j of window i is input element `i * 3 + j`.
        for k in range(i * 3 + j):
          expected_indices.append([j, k])
          expected_values.append(i * 3 + j)
      # The second dimension is the size of the last element in the window,
      # `i * 3 + 4`, which is the largest one.
      expected_output.append(
          sparse_tensor.SparseTensorValue(
              indices=expected_indices,
              values=expected_values,
              dense_shape=[5, i * 3 + 5 - 1]))
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedWindowSparse(self):
    """Applies `window()` + `batch()` twice to sparse elements."""

    def _sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    dataset = dataset_ops.Dataset.range(10).map(_sparse).window(
        size=4, shift=2,
        drop_remainder=True).flat_map(lambda x: x.batch(batch_size=4)).window(
            size=3, shift=1,
            drop_remainder=True).flat_map(lambda x: x.batch(batch_size=3))

    # Each output stacks 3 consecutive inner batches of 4 elements.
    expected_output = [
        sparse_tensor.SparseTensorValue(
            indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
                     [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
                     [2, 2, 0], [2, 3, 0]],
            values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
            dense_shape=[3, 4, 1]),
        sparse_tensor.SparseTensorValue(
            indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [1, 0, 0],
                     [1, 1, 0], [1, 2, 0], [1, 3, 0], [2, 0, 0], [2, 1, 0],
                     [2, 2, 0], [2, 3, 0]],
            values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
            dense_shape=[3, 4, 1])
    ]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowShapeError(self):
    """Expects an error when batching a window of differently shaped rows."""

    def generator():
      yield [1.0, 2.0, 3.0]
      yield [4.0, 5.0, 6.0]
      # Length 4 instead of 3: cannot be batched with the two rows above.
      yield [7.0, 8.0, 9.0, 10.0]

    dataset = dataset_ops.Dataset.from_generator(
        generator, dtypes.float32, output_shapes=[None]).window(
            size=3, shift=1).flat_map(lambda x: x.batch(batch_size=3))
    self.assertDatasetProduces(
        dataset,
        expected_error=(
            errors.InvalidArgumentError,
            r"Cannot batch tensors with different shapes in component 0. "
            r"First element had shape \[3\] and element 2 had shape \[4\]."))

  @combinations.generate(test_base.default_test_combinations())
  def testWindowIgnoreErrors(self):
    """Expects strided windows to be produced despite NaN inputs in between."""
    # NaNs sit at the odd positions; with stride=2 and shift=2 the windows
    # only contain the even positions (0, 2) and (2, 4).
    input_values = np.float32([1., np.nan, 2., np.nan, 3.])
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
        lambda x: array_ops.check_numerics(x, "message")).window(
            size=2, shift=2, stride=2,
            drop_remainder=True).flat_map(lambda x: x.batch(batch_size=2))
    self.assertDatasetProduces(
        dataset, expected_output=[np.float32([1., 2.]),
                                  np.float32([2., 3.])])

  @combinations.generate(test_base.default_test_combinations())
  def testNestedOutput(self):
    """Iterates windows of a zipped dataset directly (eager mode only)."""
    if not context.executing_eagerly():
      self.skipTest("self.evaluate() does not work with a dataset")
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset_ops.Dataset.zip((dataset, dataset)).window(10)
    for i, nested_dataset in enumerate(dataset):
      # Each window is a pair of nested datasets, one per zipped component.
      x, y = nested_dataset
      self.assertDatasetProduces(x, range(i*10, (i+1)*10))
      self.assertDatasetProduces(y, range(i*10, (i+1)*10))

  @combinations.generate(test_base.default_test_combinations())
  def testDropRemainderOutput(self):
    """Expects only the three full windows of 30 out of 100 elements."""
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.window(30, drop_remainder=True)
    dataset = dataset.flat_map(lambda x: x.batch(30))
    dataset = dataset.batch(4)

    # A single outer batch holding the 3 full windows; the last 10 elements
    # are not expected to appear.
    self.assertDatasetProduces(
        dataset,
        expected_output=[[[y + 30 * x for y in range(30)] for x in range(3)]])

  @combinations.generate(test_base.default_test_combinations())
  def testName(self):
    """Passes the `name` argument to `window()` and checks the output."""
    dataset = dataset_ops.Dataset.from_tensors(42).window(
        1, name="window").flat_map(lambda x: x)
    self.assertDatasetProduces(dataset, [42])


# NOTE: the original file defined this class twice with identical bodies; the
# second definition only rebound the name, so a single definition is kept.
class WindowCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
  """Checkpoint tests for a windowed dataset."""

  def _build_dataset(self):
    """Returns `range(42)` split into windows of 6 and interleaved back."""
    dataset = dataset_ops.Dataset.range(42).window(6).interleave(
        lambda x: x, cycle_length=2, num_parallel_calls=2)
    return dataset

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    """Runs the supplied `verify_fn` on the dataset, expecting 42 outputs."""
    verify_fn(self, self._build_dataset, num_outputs=42)


if __name__ == "__main__":
  test.main()