• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for `tf.data.Dataset.window()`."""
16from absl.testing import parameterized
17import numpy as np
18
19from tensorflow.python.data.kernel_tests import checkpoint_test_base
20from tensorflow.python.data.kernel_tests import test_base
21from tensorflow.python.data.ops import dataset_ops
22from tensorflow.python.data.util import nest
23from tensorflow.python.eager import context
24from tensorflow.python.framework import combinations
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import errors
27from tensorflow.python.framework import sparse_tensor
28from tensorflow.python.ops import array_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.platform import test
31
32
class WindowTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Functional tests for `tf.data.Dataset.window()`.

  Most tests turn the nested window datasets back into batches with
  `flat_map(lambda x: x.batch(...))` so window contents can be compared
  directly against expected values.
  """

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(
              count=20,
              size=[10, 14, 17],
              shift=[7, 14],
              stride=[1, 2, 6],
              drop_remainder=[True, False]) + combinations.combine(
                  count=[0, 1],
                  size=10,
                  shift=4,
                  stride=1,
                  drop_remainder=[True, False])))
  def testWindowDataset(self, count, size, shift, stride, drop_remainder):
    """Slides a window over a repeated 7-row input and checks every window.

    Args:
      count: Number of times the 7-row input is repeated.
      size: Number of elements per full window.
      shift: Offset (in input elements) between the starts of two windows.
      stride: Step between consecutive elements inside one window.
      drop_remainder: Whether windows holding fewer than `size` elements are
        expected to be dropped.
    """
    rows = np.arange(7)
    components = (rows,
                  np.array([[1, 2, 3]]) * rows[:, np.newaxis],
                  np.array(37.0) * rows)
    total_elements = count * 7

    def _square_all(x, y, z):
      return tuple(math_ops.square(t) for t in (x, y, z))

    def _batch_each(x, y, z):
      # Each argument is the per-component dataset of one window.
      return dataset_ops.Dataset.zip(
          tuple(w.batch(batch_size=size) for w in (x, y, z)))

    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = dataset.map(_square_all)
    dataset = dataset.repeat(count)
    dataset = dataset.window(
        size=size, shift=shift, stride=stride, drop_remainder=drop_remainder)
    dataset = dataset.flat_map(_batch_each)
    get_next = self.getNext(dataset)

    # The batch dimension is unknown statically; the rest mirrors the input.
    expected_shapes = [[None] + list(c.shape[1:]) for c in components]
    actual_shapes = [
        ts.as_list()
        for ts in nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
    ]
    self.assertEqual(expected_shapes, actual_shapes)

    # A full window spans (size - 1) * stride + 1 consecutive input elements.
    window_span = (size - 1) * stride + 1
    num_full_batches = max(0, (total_elements - window_span) // shift + 1)
    for window in range(num_full_batches):
      result = self.evaluate(get_next())
      start = window * shift
      for component, result_component in zip(components, result):
        for j in range(size):
          self.assertAllEqual(component[(start + j * stride) % 7]**2,
                              result_component[j])

    if not drop_remainder:
      # Every window start inside the input still yields a (shorter) window.
      num_partial_batches = (
          total_elements // shift + (total_elements % shift > 0) -
          num_full_batches)
      for partial in range(num_partial_batches):
        result = self.evaluate(get_next())
        start = (num_full_batches + partial) * shift
        remaining = total_elements - start
        num_elements = remaining // stride + ((remaining % stride) > 0)
        for component, result_component in zip(components, result):
          for j in range(num_elements):
            self.assertAllEqual(component[(start + j * stride) % 7]**2,
                                result_component[j])

    # The iterator must keep reporting exhaustion on repeated calls.
    for _ in range(2):
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(count=20, size=0, shift=3, stride=1) +
          combinations.combine(count=20, size=3, shift=0, stride=1) +
          combinations.combine(count=20, size=3, shift=3, stride=0)))
  def testWindowDatasetInvalid(self, count, size, shift, stride):
    """Expects `InvalidArgumentError` when size, shift or stride is zero."""
    with self.assertRaises(errors.InvalidArgumentError):
      # Construction stays inside the context manager: the error may surface
      # while building the dataset or while evaluating its variant tensor.
      ds = dataset_ops.Dataset.range(10).map(lambda x: x)
      ds = ds.repeat(count)
      ds = ds.window(size=size, shift=shift, stride=stride)
      ds = ds.flat_map(lambda x: x.batch(batch_size=size))
      self.evaluate(ds._variant_tensor)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowDifferentNestedStructures(self):
    """Builds iterators over windowed tuple- and dict-structured inputs."""
    for elements in (([1, 2], [3, 4]), {"a": [1, 2]}):
      ds = dataset_ops.Dataset.from_tensor_slices(elements).window(2)
      self.getNext(ds)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparse(self):
    """Windows one-element sparse tensors and batches each window."""

    def _to_sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    dataset = dataset_ops.Dataset.range(10).map(_to_sparse)
    dataset = dataset.window(size=5, shift=3, drop_remainder=True)
    dataset = dataset.flat_map(lambda x: x.batch(batch_size=5))

    num_batches = (10 - 5) // 3 + 1
    expected_output = []
    for batch in range(num_batches):
      first = batch * 3
      expected_output.append(
          sparse_tensor.SparseTensorValue(
              indices=[[row, 0] for row in range(5)],
              values=list(range(first, first + 5)),
              dense_shape=[5, 1]))
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowSparseWithDifferentDenseShapes(self):
    """Windows sparse tensors whose dense shape grows with the element."""

    def _to_sparse(i):
      # Element `i` holds `i` copies of the value `i` in a length-`i` vector.
      indices = array_ops.expand_dims(
          math_ops.range(i, dtype=dtypes.int64), 1)
      values = array_ops.fill([math_ops.cast(i, dtypes.int32)], i)
      return sparse_tensor.SparseTensorValue(
          indices=indices, values=values, dense_shape=[i])

    dataset = dataset_ops.Dataset.range(10).map(_to_sparse)
    dataset = dataset.window(size=5, shift=3, drop_remainder=True)
    dataset = dataset.flat_map(lambda x: x.batch(batch_size=5))

    num_batches = (10 - 5) // 3 + 1
    expected_output = []
    for batch in range(num_batches):
      first = batch * 3
      # Row `j` of a batch comes from element `first + j`.
      cells = [(j, k) for j in range(5) for k in range(first + j)]
      expected_output.append(
          sparse_tensor.SparseTensorValue(
              indices=[[j, k] for j, k in cells],
              values=[first + j for j, _ in cells],
              dense_shape=[5, first + 5 - 1]))
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedWindowSparse(self):
    """Applies `window` twice to sparse elements and checks both levels."""

    def _to_sparse(i):
      return sparse_tensor.SparseTensorValue(
          indices=[[0]], values=(i * [1]), dense_shape=[1])

    dataset = dataset_ops.Dataset.range(10).map(_to_sparse)
    dataset = dataset.window(size=4, shift=2, drop_remainder=True)
    dataset = dataset.flat_map(lambda x: x.batch(batch_size=4))
    dataset = dataset.window(size=3, shift=1, drop_remainder=True)
    dataset = dataset.flat_map(lambda x: x.batch(batch_size=3))

    expected_values = (
        [0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
        [2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
    )
    expected_output = [
        sparse_tensor.SparseTensorValue(
            indices=[[outer, inner, 0]
                     for outer in range(3)
                     for inner in range(4)],
            values=values,
            dense_shape=[3, 4, 1]) for values in expected_values
    ]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowShapeError(self):
    """Batching a window of differently shaped elements must fail."""

    def ragged_rows():
      yield from ([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0])

    dataset = dataset_ops.Dataset.from_generator(
        ragged_rows, dtypes.float32, output_shapes=[None])
    dataset = dataset.window(size=3, shift=1)
    dataset = dataset.flat_map(lambda x: x.batch(batch_size=3))

    expected_error = (
        errors.InvalidArgumentError,
        r"Cannot batch tensors with different shapes in component 0. "
        r"First element had shape \[3\] and element 2 had shape \[4\].")
    self.assertDatasetProduces(dataset, expected_error=expected_error)

  @combinations.generate(test_base.default_test_combinations())
  def testWindowIgnoreErrors(self):
    """Checks the windows produced when some input elements are NaN.

    The NaN inputs go through a `check_numerics` map step; the expected
    windows contain only the finite values.
    """
    input_values = np.float32([1., np.nan, 2., np.nan, 3.])
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values)
    dataset = dataset.map(lambda x: array_ops.check_numerics(x, "message"))
    dataset = dataset.window(size=2, shift=2, stride=2, drop_remainder=True)
    dataset = dataset.flat_map(lambda x: x.batch(batch_size=2))

    expected_output = [np.float32([1., 2.]), np.float32([2., 3.])]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedOutput(self):
    """Iterates the windows of a zipped dataset directly (eager mode only)."""
    if not context.executing_eagerly():
      self.skipTest("self.evaluate() does not work with a dataset")
    source = dataset_ops.Dataset.range(100)
    windows = dataset_ops.Dataset.zip((source, source)).window(10)
    for i, (x, y) in enumerate(windows):
      expected = range(i * 10, (i + 1) * 10)
      self.assertDatasetProduces(x, expected)
      self.assertDatasetProduces(y, expected)

  @combinations.generate(test_base.default_test_combinations())
  def testDropRemainderOutput(self):
    """With `drop_remainder=True`, only the three full windows of 30 remain."""
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.window(30, drop_remainder=True)
    dataset = dataset.flat_map(lambda x: x.batch(30))
    dataset = dataset.batch(4)

    # One outer batch holding the three windows [0, 30), [30, 60), [60, 90).
    expected_output = [[list(range(30 * x, 30 * (x + 1))) for x in range(3)]]
    self.assertDatasetProduces(dataset, expected_output=expected_output)

  @combinations.generate(test_base.default_test_combinations())
  def testName(self):
    """`window` accepts a `name` argument."""
    dataset = dataset_ops.Dataset.from_tensors(42)
    dataset = dataset.window(1, name="window").flat_map(lambda x: x)
    self.assertDatasetProduces(dataset, [42])
256
class WindowCheckpointTest(checkpoint_test_base.CheckpointTestBase,
                           parameterized.TestCase):
  """Checkpointing test for a windowed dataset consumed via `interleave`."""

  def _build_dataset(self):
    """Returns `range(42)` cut into windows of 6 and flattened again.

    The windows are flattened with `interleave` using a cycle length of 2 and
    two parallel calls.
    """
    return dataset_ops.Dataset.range(42).window(6).interleave(
        lambda x: x, cycle_length=2, num_parallel_calls=2)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         checkpoint_test_base.default_test_combinations()))
  def test(self, verify_fn):
    """Runs the supplied checkpoint verification on the 42-element dataset.

    Args:
      verify_fn: Verification callable injected by the checkpoint test
        combinations; called with this test case, the dataset factory and the
        expected number of outputs.
    """
    verify_fn(self, self._build_dataset, num_outputs=42)
285
286
# Run the tests only when this file is executed as a script, not on import.
if __name__ == "__main__":
  test.main()
289