# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.training.bucket."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

import numpy as np
from tensorflow.contrib.training.python.training import bucket_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner_impl


def _which_bucket(bucket_edges, v):
  """Identify which bucket v falls into.

  Args:
    bucket_edges: int array, bucket edges
    v: int scalar, index
  Returns:
    int scalar, the bucket.
    If v < bucket_edges[0], return 0.
    If bucket_edges[0] <= v < bucket_edges[1], return 1.
    ...
    If bucket_edges[-2] <= v < bucket_edges[-1], return len(bucket_edges) - 1.
    If v >= bucket_edges[-1] (or no interval matches), return
    len(bucket_edges) + 1.
  """
  v = np.asarray(v)
  full = [0] + bucket_edges
  found = np.where(np.logical_and(v >= full[:-1], v < full[1:]))[0]
  if not found.size:
    return len(full)
  return found[0]


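# Hand-checked example of the mapping above: with bucket_edges = [3, 4, 5, 10],
# _which_bucket returns 0 for v=2, 1 for v=3, 3 for v=7, and
# len(bucket_edges) + 1 == 5 for v=12, where no interval matches.
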
class BucketTest(test.TestCase):

  def setUp(self):
    ops.reset_default_graph()

    self.scalar_int_feed = array_ops.placeholder(dtypes_lib.int32, ())
    self.unk_int64_feed = array_ops.placeholder(dtypes_lib.int64, (None,))
    self.vec3_str_feed = array_ops.placeholder(dtypes_lib.string, (3,))
    self.sparse_c = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=[1.0],
        dense_shape=[1])

    self._coord = coordinator.Coordinator()
    # Make capacity very large so we can feed all the inputs in the
    # main thread without blocking.
    input_queue = data_flow_ops.PaddingFIFOQueue(
        5000,
        dtypes=[dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.string],
        shapes=[(), (None,), (3,)])

    self._input_enqueue_op = input_queue.enqueue(
        (self.scalar_int_feed, self.unk_int64_feed, self.vec3_str_feed))
    self.scalar_int, self.unk_int64, self.vec3_str = input_queue.dequeue()
    self._threads = None
    self._close_op = input_queue.close()
    self._sess = None

  def enqueue_inputs(self, sess, feed_dict):
    sess.run(self._input_enqueue_op, feed_dict=feed_dict)

  def start_queue_runners(self, sess):
    # Store the session so the input queue can be closed in tearDown.
    if self._sess is None:
      self._sess = sess
    self._threads = queue_runner_impl.start_queue_runners(coord=self._coord)

  def tearDown(self):
    if self._sess is not None:
      self._sess.run(self._close_op)
    self._coord.request_stop()
    self._coord.join(self._threads)

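  # Note: bucket_ops.bucket returns a (which_bucket, bucket_tensors) pair, so
  # in the tests below bucketed_values[0] is the id of the bucket a minibatch
  # was dequeued from and bucketed_values[1] is its list of batched outputs.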
  def testSingleBucket(self):
    bucketed_dynamic = bucket_ops.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str, self.sparse_c],
        which_bucket=constant_op.constant(0),
        num_buckets=2,
        batch_size=32,
        num_threads=10,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
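    # (The batched SparseTensor output carries no static dense shape, hence
    # the [None, None] entry below.)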
    self.assertAllEqual(
        [[32], [32, None], [32, 3], [None, None]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(32):
        self.enqueue_inputs(sess, {
            self.scalar_int_feed: v,
            self.unk_int64_feed: v * [v],
            self.vec3_str_feed: 3 * [str(v)]
        })
      self.start_queue_runners(sess)

      # Get a single minibatch
      bucketed_values = sess.run(bucketed_dynamic)

      # (which_bucket, bucket_tensors).
      self.assertEqual(2, len(bucketed_values))

      # Count number of bucket_tensors.
      self.assertEqual(4, len(bucketed_values[1]))

      # Ensure bucket 0 was used for all minibatch entries.
      self.assertAllEqual(0, bucketed_values[0])

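      # Each enqueued v contributed scalar v, a vector of v copies of v
      # (zero-padded here up to the longest length, 31), and three copies of
      # str(v).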
      expected_scalar_int = np.arange(32)
      expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
      for i in range(32):
        expected_unk_int64[i, :i] = i
      expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T

      # Must re-sort the output because num_threads > 1 makes the insertion
      # order nondeterministic.
      resort = np.argsort(bucketed_values[1][0])
      self.assertAllEqual(expected_scalar_int, bucketed_values[1][0][resort])
      self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort])
      self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort])

  def testBatchSizePerBucket(self):
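    # Route values below 5 to bucket 0 and everything else to bucket 1.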
    which_bucket = control_flow_ops.cond(self.scalar_int < 5,
                                         lambda: constant_op.constant(0),
                                         lambda: constant_op.constant(1))
    batch_sizes = [5, 10]
    bucketed_dynamic = bucket_ops.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str, self.sparse_c],
        which_bucket=which_bucket,
        num_buckets=2,
        batch_size=batch_sizes,
        num_threads=1,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[None], [None, None], [None, 3], [None, None]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(15):
        self.enqueue_inputs(sess, {
            self.scalar_int_feed: v,
            self.unk_int64_feed: v * [v],
            self.vec3_str_feed: 3 * [str(v)]
        })
      self.start_queue_runners(sess)

      # Get two minibatches (one with small values, one with large).
      bucketed_values_0 = sess.run(bucketed_dynamic)
      bucketed_values_1 = sess.run(bucketed_dynamic)

      # Figure out which output has the small values (bucket 0).
      if bucketed_values_0[0] == 0:
        bucketed_values_small, bucketed_values_large = (bucketed_values_0,
                                                        bucketed_values_1)
      else:
        bucketed_values_small, bucketed_values_large = (bucketed_values_1,
                                                        bucketed_values_0)

      # Ensure the small values went to bucket 0 and the large values to
      # bucket 1.
      self.assertAllEqual(0, bucketed_values_small[0])
      self.assertAllEqual(1, bucketed_values_large[0])

      # Check that the batch sizes differ per bucket.
      self.assertEqual(5, len(bucketed_values_small[1][0]))
      self.assertEqual(10, len(bucketed_values_large[1][0]))

  def testEvenOddBuckets(self):
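    # Send even values to bucket 0 and odd values to bucket 1.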
    which_bucket = (self.scalar_int % 2)
    bucketed_dynamic = bucket_ops.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str, self.sparse_c],
        which_bucket=which_bucket,
        num_buckets=2,
        batch_size=32,
        num_threads=10,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[32], [32, None], [32, 3], [None, None]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(64):
        self.enqueue_inputs(sess, {
            self.scalar_int_feed: v,
            self.unk_int64_feed: v * [v],
            self.vec3_str_feed: 3 * [str(v)]
        })
      self.start_queue_runners(sess)

      # Get two minibatches (one containing even values, one containing odds)
      bucketed_values_0 = sess.run(bucketed_dynamic)
      bucketed_values_1 = sess.run(bucketed_dynamic)

      # (which_bucket, bucket_tensors).
      self.assertEqual(2, len(bucketed_values_0))
      self.assertEqual(2, len(bucketed_values_1))

      # Count number of bucket_tensors.
      self.assertEqual(4, len(bucketed_values_0[1]))
      self.assertEqual(4, len(bucketed_values_1[1]))

      # Figure out which output has the even values (there's
      # randomness due to the multithreaded nature of bucketing)
      if bucketed_values_0[0] % 2 == 1:
        bucketed_values_even, bucketed_values_odd = (bucketed_values_1,
                                                     bucketed_values_0)
      else:
        bucketed_values_even, bucketed_values_odd = (bucketed_values_0,
                                                     bucketed_values_1)

      # Ensure the even values went to bucket 0 and the odd values to bucket 1.
      self.assertAllEqual(0, bucketed_values_even[0])
      self.assertAllEqual(1, bucketed_values_odd[0])

      # Test the first bucket output, the evens starting at 0.
      expected_scalar_int = np.arange(0, 32 * 2, 2)
      expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2 * i] = 2 * i
      expected_vec3_str = np.vstack(3 *
                                    [np.arange(0, 32 * 2, 2).astype(bytes)]).T

      # Must re-sort the output because num_threads > 1 makes the insertion
      # order nondeterministic.
      resort = np.argsort(bucketed_values_even[1][0])
      self.assertAllEqual(expected_scalar_int,
                          bucketed_values_even[1][0][resort])
      self.assertAllEqual(expected_unk_int64,
                          bucketed_values_even[1][1][resort])
      self.assertAllEqual(expected_vec3_str, bucketed_values_even[1][2][resort])

      # Test the second bucket output, the odds starting at 1.
      expected_scalar_int = np.arange(1, 32 * 2 + 1, 2)
      expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
      expected_vec3_str = np.vstack(
          3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T

      # Must re-sort the output because num_threads > 1 makes the insertion
      # order nondeterministic.
      resort = np.argsort(bucketed_values_odd[1][0])
      self.assertAllEqual(expected_scalar_int,
                          bucketed_values_odd[1][0][resort])
      self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1][1][resort])
      self.assertAllEqual(expected_vec3_str, bucketed_values_odd[1][2][resort])

  def testEvenOddBucketsFilterOutAllOdd(self):
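    # Send even values to bucket 0 and odd values to bucket 1, but only keep
    # inputs destined for bucket 0; odd values are filtered out entirely.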
    which_bucket = (self.scalar_int % 2)
    keep_input = math_ops.equal(which_bucket, 0)
    bucketed_dynamic = bucket_ops.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
        which_bucket=which_bucket,
        num_buckets=2,
        batch_size=32,
        num_threads=10,
        keep_input=keep_input,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[32], [32, None], [32, 3]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(128):
        self.enqueue_inputs(sess, {
            self.scalar_int_feed: v,
            self.unk_int64_feed: v * [v],
            self.vec3_str_feed: 3 * [str(v)]
        })
      self.start_queue_runners(sess)

      # Get two minibatches ([0, 2, ...] and [64, 66, ...])
      bucketed_values_even0 = sess.run(bucketed_dynamic)
      bucketed_values_even1 = sess.run(bucketed_dynamic)

      # Ensure that bucket 1 was completely filtered out
      self.assertAllEqual(0, bucketed_values_even0[0])
      self.assertAllEqual(0, bucketed_values_even1[0])

      # Merge their output for sorting and comparison
      bucketed_values_all_elem0 = np.concatenate((bucketed_values_even0[1][0],
                                                  bucketed_values_even1[1][0]))

      self.assertAllEqual(
          np.arange(0, 128, 2), sorted(bucketed_values_all_elem0))

  def testFailOnWrongBucketCapacities(self):
    with self.assertRaisesRegexp(ValueError, r"must have exactly num_buckets"):
      bucket_ops.bucket(  # 2 buckets and 3 capacities raises ValueError.
          tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
          which_bucket=constant_op.constant(0), num_buckets=2,
          batch_size=32, bucket_capacities=[3, 4, 5])


class BucketBySequenceLengthTest(test.TestCase):

  def _testBucketBySequenceLength(self,
                                  allow_small_batch,
                                  bucket_capacities=None,
                                  drain_entire_queue=True):
    ops.reset_default_graph()

    # All entries within an input tuple must encode the same length; the
    # input reader takes input_length from the first tuple entry.
    data_len = 4
    labels_len = 3
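    # Each pair maps a length to (data, labels) whose values all encode that
    # same length, so rows can later be cross-checked by value.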
    input_pairs = [(length, ([np.int64(length)] * data_len,
                             [str(length).encode("ascii")] * labels_len))
                   for length in (1, 3, 4, 5, 6, 10)]

    lengths = array_ops.placeholder(dtypes_lib.int32, ())
    data = array_ops.placeholder(dtypes_lib.int64, (data_len,))
    labels = array_ops.placeholder(dtypes_lib.string, (labels_len,))

    batch_size = 8
    bucket_boundaries = [3, 4, 5, 10]
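    # These boundaries should induce len(bucket_boundaries) + 1 length
    # buckets: [0, 3), [3, 4), [4, 5), [5, 10), and [10, inf).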
    num_pairs_to_enqueue = 50 * batch_size + 100

    # Make capacity very large so we can feed all the inputs in the
    # main thread without blocking.
    input_queue = data_flow_ops.FIFOQueue(
        5000, (dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.string), (
            (), (data_len,), (labels_len,)))
    input_enqueue_op = input_queue.enqueue((lengths, data, labels))
    lengths_t, data_t, labels_t = input_queue.dequeue()
    close_input_op = input_queue.close()

    (out_lengths_t, data_and_labels_t) = (bucket_ops.bucket_by_sequence_length(
        input_length=lengths_t,
        tensors=[data_t, labels_t],
        batch_size=batch_size,
        bucket_boundaries=bucket_boundaries,
        bucket_capacities=bucket_capacities,
        allow_smaller_final_batch=allow_small_batch,
        num_threads=10))

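    # With allow_smaller_final_batch=True the leading (batch) dimension cannot
    # be inferred statically, so shape inference should report None for it.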
    expected_batch_size = None if allow_small_batch else batch_size
    self.assertEqual(out_lengths_t.get_shape().as_list(), [expected_batch_size])
    self.assertEqual(data_and_labels_t[0].get_shape().as_list(),
                     [expected_batch_size, data_len])
    self.assertEqual(data_and_labels_t[1].get_shape().as_list(),
                     [expected_batch_size, labels_len])

    def _read_test(sess):
      num_pairs_dequeued = 0
      try:
        while drain_entire_queue or num_pairs_dequeued < 40 * batch_size:
          (out_lengths, (data, labels)) = sess.run(
              (out_lengths_t, data_and_labels_t))
          num_pairs_dequeued += out_lengths.shape[0]
          if allow_small_batch:
            self.assertEqual(data_len, data.shape[1])
            self.assertEqual(labels_len, labels.shape[1])
            self.assertGreaterEqual(batch_size, out_lengths.shape[0])
            self.assertGreaterEqual(batch_size, data.shape[0])
            self.assertGreaterEqual(batch_size, labels.shape[0])
          else:
            self.assertEqual((batch_size, data_len), data.shape)
            self.assertEqual((batch_size, labels_len), labels.shape)
            self.assertEqual((batch_size,), out_lengths.shape)
          for (lr, dr, tr) in zip(out_lengths, data, labels):
            # Make sure length matches data (here it's the same value).
            self.assertEqual(dr[0], lr)
            # Make sure data & labels match.
            self.assertEqual(dr[0], int(tr[0].decode("ascii")))
            # Make sure for each row, data came from the same bucket.
            self.assertEqual(
                _which_bucket(bucket_boundaries, dr[0]),
                _which_bucket(bucket_boundaries, dr[1]))
      except errors.OutOfRangeError:
        if allow_small_batch:
          self.assertEqual(num_pairs_to_enqueue, num_pairs_dequeued)
        else:
          # At most batch_size - 1 pairs may be left over in each bucket's
          # queue, since only full batches are dequeued here.
          num_buckets = len(bucket_boundaries) + 2
          self.assertLessEqual(
              num_pairs_to_enqueue - (batch_size - 1) * num_buckets,
              num_pairs_dequeued)

    with self.test_session() as sess:
      coord = coordinator.Coordinator()

      # Feed the inputs, then close the input thread.
      for _ in range(num_pairs_to_enqueue):
        which = random.randint(0, len(input_pairs) - 1)
        length, pair = input_pairs[which]
        sess.run(input_enqueue_op,
                 feed_dict={lengths: length,
                            data: pair[0],
                            labels: pair[1]})
      sess.run(close_input_op)

      # Start the queue runners
      threads = queue_runner_impl.start_queue_runners(coord=coord)
      # Read off the top of the bucket and ensure correctness of output
      _read_test(sess)
      coord.request_stop()
      coord.join(threads)

  def testBucketBySequenceLength(self):
    self._testBucketBySequenceLength(allow_small_batch=False)

  def testBucketBySequenceLengthAllow(self):
    self._testBucketBySequenceLength(allow_small_batch=True)

  def testBucketBySequenceLengthBucketCapacities(self):
    # Above, bucket_boundaries = [3, 4, 5, 10], so we need 5 capacities.
    with self.assertRaisesRegexp(ValueError, r"must have exactly num_buckets"):
      self._testBucketBySequenceLength(allow_small_batch=False,
                                       bucket_capacities=[32, 32, 32, 32])
    # Test with different capacities.
    capacities = [48, 40, 32, 24, 16]
    self._testBucketBySequenceLength(allow_small_batch=True,
                                     bucket_capacities=capacities)

  def testBucketBySequenceLengthShutdown(self):
    self._testBucketBySequenceLength(allow_small_batch=True,
                                     drain_entire_queue=False)


if __name__ == "__main__":
  test.main()