# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test

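# A minimal usage sketch of the API under test, for orientation only (it
# mirrors the calls made in the tests below and is not executed here):
#
#   dataset = dataset_ops.Dataset.range(10).apply(
#       grouping.group_by_window(
#           key_func=lambda x: x % 2,
#           reduce_func=lambda key, window: window.batch(5),
#           window_size=5))
#
# Each element is mapped to an int64 key, up to `window_size` elements with
# the same key are buffered into a window, and `reduce_func` maps each window
# to a dataset whose elements are emitted in place of the window.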
class GroupByWindowTest(test.TestCase):

  def testSimple(self):
    components = np.random.randint(100, size=(200,)).astype(np.int64)
    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
        .apply(
            grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
                                     4)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      counts = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          result = sess.run(get_next)
          self.assertTrue(
              all(x % 2 == 0 for x in result) or
              all(x % 2 == 1 for x in result))
          counts.append(result.shape[0])

      self.assertEqual(len(components), sum(counts))
      num_full_batches = len([c for c in counts if c == 4])
      self.assertGreaterEqual(num_full_batches, 23)
      self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))

  def testImmediateOutput(self):
    components = np.array(
        [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
            grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
                                     4)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      # The input is infinite, so this test demonstrates that:
      # 1. We produce output without having to consume the entire input,
      # 2. Different buckets can produce output at different rates, and
      # 3. For deterministic input, the output is deterministic.
      for _ in range(3):
        self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
        self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
        self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
        self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))

  def testSmallGroups(self):
    components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
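    # `components` holds seven 0s and five 1s, so with a window size of 4 each
    # key yields one full batch and one smaller leftover batch.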
    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).apply(
            grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
                                     4)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
      self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
      # The small outputs at the end are deterministically produced in key
      # order.
      self.assertAllEqual([0, 0, 0], sess.run(get_next))
      self.assertAllEqual([1], sess.run(get_next))

  def testReduceFuncError(self):
    components = np.random.randint(100, size=(200,)).astype(np.int64)

    def reduce_func(_, xs):
      # Introduce an incorrect padded shape (the second component evaluates to
      # [-5]) that cannot (currently) be detected at graph construction time.
      return xs.padded_batch(
          4,
          padded_shapes=(tensor_shape.TensorShape([]),
                         constant_op.constant([5], dtype=dtypes.int64) * -1))

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
            grouping.group_by_window(lambda x, _: x % 2, reduce_func,
                                     32)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)

  def testConsumeWindowDatasetMoreThanOnce(self):
    components = np.random.randint(50, size=(200,)).astype(np.int64)

    def reduce_func(key, window):
      # Apply two different kinds of padding to the input: tight
      # padding, and quantized (to a multiple of 10) padding.
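      # For example, elements in key group 1 have lengths in [10, 20), so the
      # second component below is padded to (1 + 1) * 10 = 20 columns.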
      return dataset_ops.Dataset.zip((
          window.padded_batch(
              4, padded_shapes=tensor_shape.TensorShape([None])),
          window.padded_batch(
              4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
      ))

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
        .apply(grouping.group_by_window(
            lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
            reduce_func, 4))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      counts = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          tight_result, multiple_of_10_result = sess.run(get_next)
          self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
          self.assertAllEqual(tight_result,
                              multiple_of_10_result[:, :tight_result.shape[1]])
          counts.append(tight_result.shape[0])
      self.assertEqual(len(components), sum(counts))


class GroupByWindowSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_dataset(self, components):
    return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
        grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))

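  # The dataset above repeats indefinitely, so the checks below stop after 12
  # outputs and pass verify_exhausted=False.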
  def testCoreGroupByWindow(self):
    components = np.array(
        [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
    self.verify_unused_iterator(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    self.verify_init_before_restore(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    self.verify_multiple_breaks(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    self.verify_reset_restored_iterator(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    self.verify_restore_in_empty_graph(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
    self.verify_restore_in_modified_graph(
        lambda: self._build_dataset(components),
        lambda: self._build_dataset(diff_components),
        12,
        verify_exhausted=False)


# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though they should be made to
# use a different batch size per key.
class BucketTest(test.TestCase):

  def _dynamicPad(self, bucket, window, window_size):
    # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
    # generic form of padded_batch that pads every component
    # dynamically and does not rely on static shape information about
    # the arguments.
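    # For now, zip the scalar bucket key with a padded batch of the window,
    # relying on the (scalar, unknown-length vector, 3-vector) structure
    # produced by the map functions in the tests below.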
    return dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.from_tensors(bucket),
         window.padded_batch(
             32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
                 [None]), tensor_shape.TensorShape([3])))))

  def testSingleBucket(self):

    def _map_fn(v):
      return (v, array_ops.fill([v], v),
              array_ops.fill([3], string_ops.as_string(v)))

    input_dataset = (
        dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))

    bucketed_dataset = input_dataset.apply(
        grouping.group_by_window(
            lambda x, y, z: 0,
            lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))

    iterator = bucketed_dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)

      which_bucket, bucketed_values = sess.run(get_next)

      self.assertEqual(0, which_bucket)

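      # Row i of the padded int64 component should be the value i repeated i
      # times, followed by zero padding out to the longest row (length 31).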
      expected_scalar_int = np.arange(32, dtype=np.int64)
      expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
      for i in range(32):
        expected_unk_int64[i, :i] = i
      expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T

      self.assertAllEqual(expected_scalar_int, bucketed_values[0])
      self.assertAllEqual(expected_unk_int64, bucketed_values[1])
      self.assertAllEqual(expected_vec3_str, bucketed_values[2])

  def testEvenOddBuckets(self):

    def _map_fn(v):
      return (v, array_ops.fill([v], v),
              array_ops.fill([3], string_ops.as_string(v)))

    input_dataset = (
        dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))

    bucketed_dataset = input_dataset.apply(
        grouping.group_by_window(
            lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
            lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))

    iterator = bucketed_dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)

      # Get two minibatches (one containing even values, one containing odds)
      which_bucket_even, bucketed_values_even = sess.run(get_next)
      which_bucket_odd, bucketed_values_odd = sess.run(get_next)

      # Count number of bucket_tensors.
      self.assertEqual(3, len(bucketed_values_even))
      self.assertEqual(3, len(bucketed_values_odd))

      # Ensure the expected bucket was used for each minibatch: 0 for the
      # evens, 1 for the odds.
      self.assertAllEqual(0, which_bucket_even)
      self.assertAllEqual(1, which_bucket_odd)

      # Test the first bucket outputted, the evens starting at 0
      expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
      expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2 * i] = 2 * i
      expected_vec3_str = np.vstack(
          3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T

      self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
      self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
      self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])

      # Test the second bucket outputted, the odds starting at 1
      expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
      expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
      expected_vec3_str = np.vstack(
          3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T

      self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
      self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
      self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])

  def testEvenOddBucketsFilterOutAllOdd(self):

    def _map_fn(v):
      return {
          "x": v,
          "y": array_ops.fill([v], v),
          "z": array_ops.fill([3], string_ops.as_string(v))
      }

    def _dynamic_pad_fn(bucket, window, _):
      return dataset_ops.Dataset.zip(
          (dataset_ops.Dataset.from_tensors(bucket),
           window.padded_batch(
               32, {
                   "x": tensor_shape.TensorShape([]),
                   "y": tensor_shape.TensorShape([None]),
                   "z": tensor_shape.TensorShape([3])
               })))

    input_dataset = (
        dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
        .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))

    bucketed_dataset = input_dataset.apply(
        grouping.group_by_window(
            lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
            lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))

    iterator = bucketed_dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)

      # Get two minibatches ([0, 2, ...] and [64, 66, ...])
      which_bucket0, bucketed_values_even0 = sess.run(get_next)
      which_bucket1, bucketed_values_even1 = sess.run(get_next)

      # Ensure that bucket 1 was completely filtered out
      self.assertAllEqual(0, which_bucket0)
      self.assertAllEqual(0, which_bucket1)
      self.assertAllEqual(
          np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
      self.assertAllEqual(
          np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])

  def testDynamicWindowSize(self):
    components = np.arange(100).astype(np.int64)

    # Key fn: even/odd.
    # Reduce fn: batch the whole window (the batch size of 20 exceeds both
    # window sizes, so each window yields a single batch).
    # Window size fn: even=5, odd=10.

    def window_size_func(key):
      window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
      return window_sizes[key]

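    # Passing window_size=None below defers the choice of per-key window size
    # to `window_size_func`.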
    dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
        grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
                                 None, window_size_func))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.OutOfRangeError):
        batches = 0
        while True:
          result = sess.run(get_next)
          is_even = all(x % 2 == 0 for x in result)
          is_odd = all(x % 2 == 1 for x in result)
          self.assertTrue(is_even or is_odd)
          expected_batch_size = 5 if is_even else 10
          self.assertEqual(expected_batch_size, result.shape[0])
          batches += 1

      self.assertEqual(batches, 15)


if __name__ == "__main__":
  test.main()