# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.training.bucket."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

import numpy as np
from tensorflow.contrib.training.python.training import bucket_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner_impl


def _which_bucket(bucket_edges, v):
  """Identify which bucket v falls into.

  Args:
    bucket_edges: int array, bucket edges
    v: int scalar, index
  Returns:
    int scalar, the bucket.
    If v < bucket_edges[0], return 0.
    If bucket_edges[0] <= v < bucket_edges[1], return 1.
    ...
    If bucket_edges[-2] <= v < bucket_edges[-1], return len(bucket_edges) - 1.
    If v >= bucket_edges[-1], return len(bucket_edges) + 1.
    (len(bucket_edges) itself is never returned.)
  """
  v = np.asarray(v)
  full = [0] + bucket_edges
  found = np.where(np.logical_and(v >= full[:-1], v < full[1:]))[0]
  if not found.size:
    return len(full)
  return found[0]
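
# Example, for illustration only (not part of the original test): with
# bucket_edges = [3, 4, 5, 10] (the same boundaries used below), the helper
# maps values as follows:
#   _which_bucket([3, 4, 5, 10], 2)  -> 0   (v < 3)
#   _which_bucket([3, 4, 5, 10], 3)  -> 1   (3 <= v < 4)
#   _which_bucket([3, 4, 5, 10], 6)  -> 3   (5 <= v < 10)
#   _which_bucket([3, 4, 5, 10], 12) -> 5   (v >= 10)
# The assertions below only compare bucket indices for equality, so only the
# grouping matters, not the exact index values.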


class BucketTest(test.TestCase):

  def setUp(self):
    ops.reset_default_graph()

    self.scalar_int_feed = array_ops.placeholder(dtypes_lib.int32, ())
    self.unk_int64_feed = array_ops.placeholder(dtypes_lib.int64, (None,))
    self.vec3_str_feed = array_ops.placeholder(dtypes_lib.string, (3,))
    self.sparse_c = sparse_tensor.SparseTensor(
        indices=[[0]],
        values=[1.0],
        dense_shape=[1])

    self._coord = coordinator.Coordinator()
    # Make capacity very large so we can feed all the inputs in the
    # main thread without blocking
    input_queue = data_flow_ops.PaddingFIFOQueue(
        5000,
        dtypes=[dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.string],
        shapes=[(), (None,), (3,)])

    self._input_enqueue_op = input_queue.enqueue(
        (self.scalar_int_feed, self.unk_int64_feed, self.vec3_str_feed))
    self.scalar_int, self.unk_int64, self.vec3_str = input_queue.dequeue()
    self._threads = None
    self._close_op = input_queue.close()
    self._sess = None

  def enqueue_inputs(self, sess, feed_dict):
    sess.run(self._input_enqueue_op, feed_dict=feed_dict)

  def start_queue_runners(self, sess):
    # Store the session so we can close the input queue later.
    if self._sess is None:
      self._sess = sess
    self._threads = queue_runner_impl.start_queue_runners(coord=self._coord)

  def tearDown(self):
    if self._sess is not None:
      self._sess.run(self._close_op)
    self._coord.request_stop()
    self._coord.join(self._threads)

  def testSingleBucket(self):
    bucketed_dynamic = bucket_ops.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str, self.sparse_c],
        which_bucket=constant_op.constant(0),
        num_buckets=2,
        batch_size=32,
        num_threads=10,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[32], [32, None], [32, 3], [None, None]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(32):
        self.enqueue_inputs(sess, {
            self.scalar_int_feed: v,
            self.unk_int64_feed: v * [v],
            self.vec3_str_feed: 3 * [str(v)]
        })
      self.start_queue_runners(sess)

      # Get a single minibatch
      bucketed_values = sess.run(bucketed_dynamic)

      # (which_bucket, bucket_tensors).
      self.assertEqual(2, len(bucketed_values))

      # Count number of bucket_tensors.
      self.assertEqual(4, len(bucketed_values[1]))

      # Ensure bucket 0 was used for all minibatch entries.
      self.assertAllEqual(0, bucketed_values[0])

      expected_scalar_int = np.arange(32)
      expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
      for i in range(32):
        expected_unk_int64[i, :i] = i
      expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T

      # Must re-sort the output because num_threads > 1 leads to
      # sometimes-inconsistent insertion order.
      resort = np.argsort(bucketed_values[1][0])
      self.assertAllEqual(expected_scalar_int, bucketed_values[1][0][resort])
      self.assertAllEqual(expected_unk_int64, bucketed_values[1][1][resort])
      self.assertAllEqual(expected_vec3_str, bucketed_values[1][2][resort])

  def testBatchSizePerBucket(self):
    which_bucket = control_flow_ops.cond(self.scalar_int < 5,
                                         lambda: constant_op.constant(0),
                                         lambda: constant_op.constant(1))
    batch_sizes = [5, 10]
    bucketed_dynamic = bucket_ops.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str, self.sparse_c],
        which_bucket=which_bucket,
        num_buckets=2,
        batch_size=batch_sizes,
        num_threads=1,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[None], [None, None], [None, 3], [None, None]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(15):
        self.enqueue_inputs(sess, {
            self.scalar_int_feed: v,
            self.unk_int64_feed: v * [v],
            self.vec3_str_feed: 3 * [str(v)]
        })
      self.start_queue_runners(sess)

      # Get two minibatches (one with small values, one with large).
      bucketed_values_0 = sess.run(bucketed_dynamic)
      bucketed_values_1 = sess.run(bucketed_dynamic)

      # Figure out which output has the small values
      if bucketed_values_0[0] < 5:
        bucketed_values_large, bucketed_values_small = (bucketed_values_1,
                                                        bucketed_values_0)
      else:
        bucketed_values_small, bucketed_values_large = (bucketed_values_0,
                                                        bucketed_values_1)

      # Ensure bucket 0 was used for all entries of the small minibatch and
      # bucket 1 for the large one.
      self.assertAllEqual(0, bucketed_values_small[0])
      self.assertAllEqual(1, bucketed_values_large[0])

      # Check that the batch sizes differ per bucket
      self.assertEqual(5, len(bucketed_values_small[1][0]))
      self.assertEqual(10, len(bucketed_values_large[1][0]))

  def testEvenOddBuckets(self):
    which_bucket = (self.scalar_int % 2)
    bucketed_dynamic = bucket_ops.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str, self.sparse_c],
        which_bucket=which_bucket,
        num_buckets=2,
        batch_size=32,
        num_threads=10,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[32], [32, None], [32, 3], [None, None]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(64):
        self.enqueue_inputs(sess, {
            self.scalar_int_feed: v,
            self.unk_int64_feed: v * [v],
            self.vec3_str_feed: 3 * [str(v)]
        })
      self.start_queue_runners(sess)

      # Get two minibatches (one containing even values, one containing odds)
      bucketed_values_0 = sess.run(bucketed_dynamic)
      bucketed_values_1 = sess.run(bucketed_dynamic)

      # (which_bucket, bucket_tensors).
      self.assertEqual(2, len(bucketed_values_0))
      self.assertEqual(2, len(bucketed_values_1))

      # Count number of bucket_tensors.
      self.assertEqual(4, len(bucketed_values_0[1]))
      self.assertEqual(4, len(bucketed_values_1[1]))

      # Figure out which output has the even values (there's
      # randomness due to the multithreaded nature of bucketing)
      if bucketed_values_0[0] % 2 == 1:
        bucketed_values_even, bucketed_values_odd = (bucketed_values_1,
                                                     bucketed_values_0)
      else:
        bucketed_values_even, bucketed_values_odd = (bucketed_values_0,
                                                     bucketed_values_1)

      # Ensure bucket 0 was used for all entries of the even minibatch and
      # bucket 1 for the odd one.
      self.assertAllEqual(0, bucketed_values_even[0])
      self.assertAllEqual(1, bucketed_values_odd[0])

      # Test the first bucket outputted, the evens starting at 0
      expected_scalar_int = np.arange(0, 32 * 2, 2)
      expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2 * i] = 2 * i
      expected_vec3_str = np.vstack(3 *
                                    [np.arange(0, 32 * 2, 2).astype(bytes)]).T

      # Must re-sort the output because num_threads > 1 leads to
      # sometimes-inconsistent insertion order.
      resort = np.argsort(bucketed_values_even[1][0])
      self.assertAllEqual(expected_scalar_int,
                          bucketed_values_even[1][0][resort])
      self.assertAllEqual(expected_unk_int64,
                          bucketed_values_even[1][1][resort])
      self.assertAllEqual(expected_vec3_str, bucketed_values_even[1][2][resort])

      # Test the second bucket outputted, the odds starting at 1
      expected_scalar_int = np.arange(1, 32 * 2 + 1, 2)
      expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
      expected_vec3_str = np.vstack(
          3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T

      # Must re-sort the output because num_threads > 1 leads to
      # sometimes-inconsistent insertion order.
      resort = np.argsort(bucketed_values_odd[1][0])
      self.assertAllEqual(expected_scalar_int,
                          bucketed_values_odd[1][0][resort])
      self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1][1][resort])
      self.assertAllEqual(expected_vec3_str, bucketed_values_odd[1][2][resort])

  def testEvenOddBucketsFilterOutAllOdd(self):
    which_bucket = (self.scalar_int % 2)
    keep_input = math_ops.equal(which_bucket, 0)
    bucketed_dynamic = bucket_ops.bucket(
        tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
        which_bucket=which_bucket,
        num_buckets=2,
        batch_size=32,
        num_threads=10,
        keep_input=keep_input,
        dynamic_pad=True)
    # Check shape inference on bucketing outputs
    self.assertAllEqual(
        [[32], [32, None], [32, 3]],
        [out.get_shape().as_list() for out in bucketed_dynamic[1]])
    with self.test_session() as sess:
      for v in range(128):
        self.enqueue_inputs(sess, {
            self.scalar_int_feed: v,
            self.unk_int64_feed: v * [v],
            self.vec3_str_feed: 3 * [str(v)]
        })
      self.start_queue_runners(sess)

      # Get two minibatches ([0, 2, ...] and [64, 66, ...])
      bucketed_values_even0 = sess.run(bucketed_dynamic)
      bucketed_values_even1 = sess.run(bucketed_dynamic)

      # Ensure that bucket 1 was completely filtered out
      self.assertAllEqual(0, bucketed_values_even0[0])
      self.assertAllEqual(0, bucketed_values_even1[0])

      # Merge their output for sorting and comparison
      bucketed_values_all_elem0 = np.concatenate((bucketed_values_even0[1][0],
                                                  bucketed_values_even1[1][0]))

      self.assertAllEqual(
          np.arange(0, 128, 2), sorted(bucketed_values_all_elem0))

  def testFailOnWrongBucketCapacities(self):
    with self.assertRaisesRegexp(ValueError, r"must have exactly num_buckets"):
      bucket_ops.bucket(  # 2 buckets and 3 capacities raises ValueError.
          tensors=[self.scalar_int, self.unk_int64, self.vec3_str],
          which_bucket=constant_op.constant(0), num_buckets=2,
          batch_size=32, bucket_capacities=[3, 4, 5])


class BucketBySequenceLengthTest(test.TestCase):

  def _testBucketBySequenceLength(self,
                                  allow_small_batch,
                                  bucket_capacities=None,
                                  drain_entire_queue=True):
    ops.reset_default_graph()

    # All inputs must be identical lengths across tuple index.
    # The input reader will get input_length from the first tuple
    # entry.
    data_len = 4
    labels_len = 3
    input_pairs = [(length, ([np.int64(length)] * data_len,
                             [str(length).encode("ascii")] * labels_len))
                   for length in (1, 3, 4, 5, 6, 10)]

    lengths = array_ops.placeholder(dtypes_lib.int32, ())
    data = array_ops.placeholder(dtypes_lib.int64, (data_len,))
    labels = array_ops.placeholder(dtypes_lib.string, (labels_len,))

    batch_size = 8
    bucket_boundaries = [3, 4, 5, 10]
    num_pairs_to_enqueue = 50 * batch_size + 100

    # Make capacity very large so we can feed all the inputs in the
    # main thread without blocking
    input_queue = data_flow_ops.FIFOQueue(
        5000, (dtypes_lib.int32, dtypes_lib.int64, dtypes_lib.string), (
            (), (data_len,), (labels_len,)))
    input_enqueue_op = input_queue.enqueue((lengths, data, labels))
    lengths_t, data_t, labels_t = input_queue.dequeue()
    close_input_op = input_queue.close()

    (out_lengths_t, data_and_labels_t) = (bucket_ops.bucket_by_sequence_length(
        input_length=lengths_t,
        tensors=[data_t, labels_t],
        batch_size=batch_size,
        bucket_boundaries=bucket_boundaries,
        bucket_capacities=bucket_capacities,
        allow_smaller_final_batch=allow_small_batch,
        num_threads=10))

    expected_batch_size = None if allow_small_batch else batch_size
    self.assertEqual(out_lengths_t.get_shape().as_list(), [expected_batch_size])
    self.assertEqual(data_and_labels_t[0].get_shape().as_list(),
                     [expected_batch_size, data_len])
    self.assertEqual(data_and_labels_t[1].get_shape().as_list(),
                     [expected_batch_size, labels_len])

    def _read_test(sess):
      num_pairs_dequeued = 0
      try:
        while drain_entire_queue or num_pairs_dequeued < 40 * batch_size:
          (out_lengths, (data, labels)) = sess.run(
              (out_lengths_t, data_and_labels_t))
          num_pairs_dequeued += out_lengths.shape[0]
          if allow_small_batch:
            self.assertEqual(data_len, data.shape[1])
            self.assertEqual(labels_len, labels.shape[1])
            self.assertGreaterEqual(batch_size, out_lengths.shape[0])
            self.assertGreaterEqual(batch_size, data.shape[0])
            self.assertGreaterEqual(batch_size, labels.shape[0])
          else:
            self.assertEqual((batch_size, data_len), data.shape)
            self.assertEqual((batch_size, labels_len), labels.shape)
            self.assertEqual((batch_size,), out_lengths.shape)
          for (lr, dr, tr) in zip(out_lengths, data, labels):
            # Make sure length matches data (here it's the same value).
            self.assertEqual(dr[0], lr)
            # Make sure data & labels match.
            self.assertEqual(dr[0], int(tr[0].decode("ascii")))
            # Make sure for each row, data came from the same bucket.
            self.assertEqual(
                _which_bucket(bucket_boundaries, dr[0]),
                _which_bucket(bucket_boundaries, dr[1]))
      except errors.OutOfRangeError:
        if allow_small_batch:
          self.assertEqual(num_pairs_to_enqueue, num_pairs_dequeued)
        else:
          # Maximum left over in the queues should be at most one less than the
          # batch_size, for every bucket.
          num_buckets = len(bucket_boundaries) + 2
          self.assertLessEqual(
              num_pairs_to_enqueue - (batch_size - 1) * num_buckets,
              num_pairs_dequeued)
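          # For illustration (arithmetic only, derived from the values above):
          # batch_size = 8 and bucket_boundaries = [3, 4, 5, 10], so
          # num_buckets = 6 and num_pairs_to_enqueue = 50 * 8 + 100 = 500;
          # the assertion then requires at least 500 - 7 * 6 = 458 dequeued
          # pairs.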

    with self.test_session() as sess:
      coord = coordinator.Coordinator()

      # Feed the inputs, then close the input queue.
      for _ in range(num_pairs_to_enqueue):
        which = random.randint(0, len(input_pairs) - 1)
        length, pair = input_pairs[which]
        sess.run(input_enqueue_op,
                 feed_dict={lengths: length,
                            data: pair[0],
                            labels: pair[1]})
      sess.run(close_input_op)

      # Start the queue runners
      threads = queue_runner_impl.start_queue_runners(coord=coord)
      # Read off the top of the bucket and ensure correctness of output
      _read_test(sess)
      coord.request_stop()
      coord.join(threads)

  def testBucketBySequenceLength(self):
    self._testBucketBySequenceLength(allow_small_batch=False)

  def testBucketBySequenceLengthAllow(self):
    self._testBucketBySequenceLength(allow_small_batch=True)

  def testBucketBySequenceLengthBucketCapacities(self):
    # Above, bucket_boundaries = [3, 4, 5, 10], so we need 5 capacities.
    with self.assertRaisesRegexp(ValueError, r"must have exactly num_buckets"):
      self._testBucketBySequenceLength(allow_small_batch=False,
                                       bucket_capacities=[32, 32, 32, 32])
    # Test with different capacities.
    capacities = [48, 40, 32, 24, 16]
    self._testBucketBySequenceLength(allow_small_batch=True,
                                     bucket_capacities=capacities)

  def testBucketBySequenceLengthShutdown(self):
    self._testBucketBySequenceLength(allow_small_batch=True,
                                     drain_entire_queue=False)


if __name__ == "__main__":
  test.main()