1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Tests for the currently experimental in-graph batch ops.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import threading 22import time 23 24from tensorflow.python.eager import context 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import function 27from tensorflow.python.framework import test_util 28from tensorflow.python.framework.errors import InvalidArgumentError 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import batch_ops 31from tensorflow.python.ops import gen_batch_ops 32from tensorflow.python.ops import script_ops 33from tensorflow.python.platform import test 34 35 36def delayed_plus1(x): 37 """Sleeps for 100ms then returns x+1.""" 38 time.sleep(0.1) 39 return x + 1 40 41 42@test_util.run_all_in_graph_and_eager_modes 43class BatchOpsTest(test.TestCase): 44 """Tests for batch_ops.{un,}batch.""" 45 46 # Test for only non eager mode as batching in eager context as a functionality 47 # is TBD. 48 def testBasicBatch(self): 49 """Tests that a single batched tensor executes together and only once.""" 50 if context.executing_eagerly(): 51 return 52 with self.cached_session() as sess: 53 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 54 batched, index, _ = batch_ops.batch( 55 [inp], num_batch_threads=1, max_batch_size=2, 56 batch_timeout_micros=36000000, grad_timeout_micros=0, 57 batching_queue="") 58 thread_results = [] 59 60 def worker(): 61 thread_results.extend( 62 sess.run([batched, index], feed_dict={inp: [1]})) 63 64 worker_thread = threading.Thread(target=worker) 65 worker_thread.start() 66 main_results = sess.run([batched, index], feed_dict={inp: [2]}) 67 worker_thread.join() 68 69 # At this point either the thread or the main did the batch and the other 70 # should have empty results. 71 if list(thread_results[0][0]): 72 batch_t = thread_results[0][0] 73 index_t = thread_results[1] 74 empty_b = main_results[0][0] 75 empty_m = main_results[1] 76 else: 77 batch_t = main_results[0][0] 78 index_t = main_results[1] 79 empty_b = thread_results[0][0] 80 empty_m = thread_results[1] 81 82 # Check that both the inputs made it out exactly once. 83 self.assertAllEqual(sorted(batch_t), (1, 2)) 84 # Check that we get 2 rows in the index tensor. 85 self.assertEqual(len(index_t), 2) 86 # Check that the other ones are empty. 87 self.assertEqual(len(empty_b), 0) 88 self.assertEqual(len(empty_m), 0) 89 90 def testBatchWithPadding(self): 91 """Test that batching with padding up to an allowed batch size works.""" 92 if context.executing_eagerly(): 93 return 94 with self.cached_session() as sess: 95 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) 96 batched, index, _ = batch_ops.batch( 97 [inp], num_batch_threads=1, max_batch_size=10, 98 batch_timeout_micros=100000, # 100ms 99 allowed_batch_sizes=[5, 10], 100 grad_timeout_micros=0, batching_queue="") 101 thread_results = [] 102 103 def worker(): 104 thread_results.extend( 105 sess.run([batched, index], feed_dict={inp: [1, 3]})) 106 107 worker_thread = threading.Thread(target=worker) 108 worker_thread.start() 109 main_results = sess.run([batched, index], feed_dict={inp: [2, 4]}) 110 worker_thread.join() 111 112 # At this point either the thread or the main did the batch and the other 113 # should have empty results. 114 if list(thread_results[0][0]): 115 batch_t = thread_results[0][0] 116 else: 117 batch_t = main_results[0][0] 118 119 # Check that the batch tensor incorporates the padding. 120 self.assertEqual(len(batch_t), 5) 121 122 def testMultipleBatch(self): 123 """Tests that multiple batched tensors execute together.""" 124 if context.executing_eagerly(): 125 return 126 with self.cached_session() as sess: 127 inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 128 inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 129 batched, _, _ = batch_ops.batch( 130 [inp0, inp1], 131 num_batch_threads=1, 132 max_batch_size=2, 133 batch_timeout_micros=36000000, 134 grad_timeout_micros=0, 135 batching_queue="") 136 thread_results = [] 137 138 def worker(): 139 thread_results.extend( 140 sess.run([batched], feed_dict={inp0: [1], 141 inp1: [2]})) 142 143 worker_thread = threading.Thread(target=worker) 144 worker_thread.start() 145 main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]}) 146 worker_thread.join() 147 148 # At this point either the thread or the main did the batch and the other 149 # should have empty results. 150 if list(thread_results[0][0]): 151 batch_t = thread_results[0] 152 empty_t = main_results[0] 153 else: 154 batch_t = main_results[0] 155 empty_t = thread_results[0] 156 157 # Assert that the tensors were batched together. 158 self.assertAllEqual(sorted(batch_t[0]), [1, 2]) 159 self.assertAllEqual(sorted(batch_t[1]), [2, 3]) 160 self.assertAllEqual(empty_t[0], []) 161 self.assertAllEqual(empty_t[1], []) 162 163 def testIllegalBatchDifferentDim0Sizes(self): 164 """Tests illegally feeding tensors with different dim0 sizes.""" 165 if context.executing_eagerly(): 166 return 167 with self.cached_session() as sess: 168 inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 169 inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2]) 170 batched, index, _ = batch_ops.batch( 171 [inp0, inp1], num_batch_threads=1, max_batch_size=2, 172 batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="") 173 with self.assertRaises(Exception) as raised: 174 _ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]}) 175 self.assertGreater( 176 raised.exception.message.find("must have equal 0th-dimension size"), 177 0) 178 179 def testBasicUnbatch(self): 180 """Tests that batch and unbatch work together.""" 181 if context.executing_eagerly(): 182 return 183 with self.cached_session() as sess: 184 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 185 batched, index, id_t = batch_ops.batch( 186 [inp], num_batch_threads=1, max_batch_size=10, 187 batch_timeout_micros=100000, # 100ms 188 allowed_batch_sizes=[3, 10], 189 grad_timeout_micros=0, batching_queue="") 190 computation = batched[0] + 1 191 result = batch_ops.unbatch(computation, index, id_t, 192 timeout_micros=1000000, shared_name="unbatch") 193 thread_results = [] 194 195 def worker(): 196 thread_results.extend(sess.run([result], feed_dict={inp: [1]})) 197 198 worker_thread = threading.Thread(target=worker) 199 worker_thread.start() 200 main_results = sess.run([result], feed_dict={inp: [2]}) 201 worker_thread.join() 202 self.assertEqual(thread_results[0], [2]) 203 self.assertEqual(main_results[0], [3]) 204 205 def testBasicUnbatchDecorated(self): 206 """Tests that the batch_function decorator works.""" 207 if context.executing_eagerly(): 208 return 209 with self.cached_session() as sess: 210 # TODO(apassos): Removing this line causes test flakiness! Ideally should 211 # be investigated. 212 default_inp = array_ops.placeholder_with_default(2, shape=[]) # pylint: disable=unused-variable 213 214 @batch_ops.batch_function(1, 10, 100000) 215 def computation(in_t): 216 self.assertTrue(in_t.shape is not None) 217 return in_t + 1 218 219 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 220 result = computation(inp) 221 thread_results = [] 222 223 def worker(): 224 thread_results.extend(sess.run([result], feed_dict={inp: [1]})) 225 226 worker_thread = threading.Thread(target=worker) 227 worker_thread.start() 228 main_results = sess.run([result], feed_dict={inp: [2]}) 229 worker_thread.join() 230 self.assertEqual(thread_results[0], [2]) 231 self.assertEqual(main_results[0], [3]) 232 233 def testBatchDecoratedWithCapturedInput(self): 234 """Tests that the batch_function decorator works.""" 235 if context.executing_eagerly(): 236 return 237 with self.cached_session() as sess: 238 captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) 239 captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) 240 241 @batch_ops.batch_function(1, 10, 100000) 242 def computation(in_t): 243 return in_t + captured_inp0 - captured_inp1 244 245 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 246 result = computation(inp) 247 thread_results = [] 248 249 def worker(): 250 thread_results.extend(sess.run([result], feed_dict={inp: [1]})) 251 252 worker_thread = threading.Thread(target=worker) 253 worker_thread.start() 254 main_results = sess.run([result], feed_dict={inp: [2]}) 255 worker_thread.join() 256 self.assertEqual(thread_results[0], [2]) 257 self.assertEqual(main_results[0], [3]) 258 259 def testBatchFunctionOp(self): 260 """Tests that the batch_function op works.""" 261 if context.executing_eagerly(): 262 return 263 with self.cached_session() as sess: 264 265 @function.Defun(dtypes.int32) 266 def computation(in_t): 267 return in_t + 1 268 269 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 270 result = gen_batch_ops.batch_function( 271 [inp], 272 num_batch_threads=1, 273 max_batch_size=10, 274 batch_timeout_micros=100000, 275 Tout=[dtypes.int32], 276 f=computation, 277 captured_tensors=computation.captured_inputs) 278 thread_results = [] 279 280 def worker(): 281 thread_results.extend(sess.run([result], feed_dict={inp: [1]})) 282 283 worker_thread = threading.Thread(target=worker) 284 worker_thread.start() 285 main_results = sess.run([result], feed_dict={inp: [2]}) 286 worker_thread.join() 287 self.assertEqual(thread_results[0], [2]) 288 self.assertEqual(main_results[0], [3]) 289 290 def testBatchFunctionOpWithCapturedInput(self): 291 """Tests that batch_function op works with captured input.""" 292 if context.executing_eagerly(): 293 return 294 with self.cached_session() as sess: 295 captured_inp0 = array_ops.placeholder_with_default(2, shape=[]) 296 captured_inp1 = array_ops.placeholder_with_default(1, shape=[]) 297 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 298 299 @function.Defun(dtypes.int32) 300 def computation(inp): 301 return inp + captured_inp0 - captured_inp1 302 303 result = gen_batch_ops.batch_function( 304 num_batch_threads=1, 305 max_batch_size=10, 306 batch_timeout_micros=100000, # 100ms 307 allowed_batch_sizes=[3, 10], 308 batching_queue="", 309 f=computation, 310 in_tensors=[inp], 311 captured_tensors=computation.captured_inputs, 312 Tout=[o.type for o in computation.definition.signature.output_arg]) 313 314 thread_results = [] 315 316 def worker(): 317 thread_results.extend(sess.run([result], feed_dict={inp: [1]})) 318 319 worker_thread = threading.Thread(target=worker) 320 worker_thread.start() 321 main_results = sess.run([result], feed_dict={inp: [2]}) 322 worker_thread.join() 323 self.assertEqual(thread_results[0], [2]) 324 self.assertEqual(main_results[0], [3]) 325 326 def testBatchFunctionOpWithInputError(self): 327 """Tests that batch_function op works with error in the inputs.""" 328 if context.executing_eagerly(): 329 return 330 with self.cached_session() as sess: 331 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 332 333 @function.Defun(dtypes.int32, dtypes.int32) 334 def computation(in0, in1): 335 return in0 + in1 336 337 result = gen_batch_ops.batch_function( 338 [inp], # computation actually expects 2 inputs. 339 num_batch_threads=1, 340 max_batch_size=10, 341 batch_timeout_micros=100000, # 100ms 342 batching_queue="", 343 f=computation, 344 captured_tensors=computation.captured_inputs, 345 Tout=[o.type for o in computation.definition.signature.output_arg]) 346 347 with self.assertRaisesRegexp(InvalidArgumentError, 348 ".*2 arguments.*but 1.*"): 349 sess.run([result], feed_dict={inp: [2]}) 350 351 def testBasicUnbatchDecoratedWithReshape(self): 352 """Tests that the batch_function decorator works.""" 353 if context.executing_eagerly(): 354 return 355 with self.cached_session() as sess: 356 357 @batch_ops.batch_function(1, 10, 100000) 358 def computation(in_t): 359 return array_ops.reshape(in_t, [-1]) + 1 360 361 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1]) 362 result = computation(inp) 363 thread_results = [] 364 365 def worker(): 366 thread_results.extend(sess.run([result], feed_dict={inp: [[1]]})) 367 368 worker_thread = threading.Thread(target=worker) 369 worker_thread.start() 370 main_results = sess.run([result], feed_dict={inp: [[2]]}) 371 worker_thread.join() 372 self.assertEqual(thread_results[0], [2]) 373 self.assertEqual(main_results[0], [3]) 374 375 def testUnbatchTimeout(self): 376 """Tests that the unbatch timeout works.""" 377 if context.executing_eagerly(): 378 return 379 with self.cached_session() as sess: 380 inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1]) 381 batched, index, id_t = batch_ops.batch( 382 [inp], num_batch_threads=1, max_batch_size=2, 383 batch_timeout_micros=36000000, grad_timeout_micros=0, 384 batching_queue="") 385 computation = batched[0] + 1 386 timeout_micros = 10 387 result = batch_ops.unbatch(computation, index, id_t, timeout_micros, 388 shared_name="shared_unbatch") 389 # Set up a parallel pipeline that delays the computation, but uses the 390 # same unbatch resource object as the non-delayed pipeline. 391 computation_delayed = script_ops.py_func(delayed_plus1, 392 [batched[0]], 393 dtypes.int32) 394 result_delayed = batch_ops.unbatch(computation_delayed, 395 index, 396 id_t, 397 timeout_micros, 398 shared_name="shared_unbatch") 399 400 thread_results = [] 401 def worker(): 402 # A first call using the non-delayed pipeline. The batcher will send an 403 # empty tensor along the non-delayed pipeline. 404 thread_results.extend(sess.run([result], feed_dict={inp: [1]})) 405 worker_thread = threading.Thread(target=worker) 406 worker_thread.start() 407 time.sleep(0.1) # Ensure the thread's call starts first. 408 # A second call using the delayed pipeline. The batcher will send the 409 # batched tensor along the delayed pipeline, thus delaying the arrival of 410 # the batched tensor at the unbatch op, relative to the empty tensor. 411 # 412 # TODO(olston, apassos): Avoid relying on the order in which the batch op 413 # emits the empty tensor versus the batched one. 414 _ = sess.run([result_delayed], feed_dict={inp: [2]}) 415 worker_thread.join() 416 # The thread's call should hit the timeout, and thus get 0 results. 417 self.assertEqual(len(thread_results), 0) 418 419 420if __name__ == "__main__": 421 test.main() 422