# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for the currently experimental in-graph batch ops."""
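
# A minimal sketch of the round trip these tests exercise (argument values
# here are illustrative, not prescriptive): batch() groups tensors fed by
# concurrent session.run calls into a single batched tensor, and unbatch()
# routes each result row back to the call that fed it.
#
#   batched, index, id_t = batch_ops.batch(
#       [inp], num_batch_threads=1, max_batch_size=2,
#       batch_timeout_micros=1000000, grad_timeout_micros=0,
#       batching_queue="")
#   result = batch_ops.unbatch(batched[0] + 1, index, id_t,
#                              timeout_micros=1000000, shared_name="unbatch")
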
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import time

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import test_util
from tensorflow.python.framework.errors import InvalidArgumentError
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import batch_ops
from tensorflow.python.ops import gen_batch_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test


def delayed_plus1(x):
  """Sleeps for 100ms then returns x+1."""
  time.sleep(0.1)
  return x + 1


@test_util.run_all_in_graph_and_eager_modes
class BatchOpsTest(test.TestCase):
  """Tests for batch_ops.{un,}batch."""

  # Tests run only in non-eager mode, since batching as an eager-mode
  # feature is still TBD.
  def testBasicBatch(self):
    """Tests that a single batched tensor executes together and only once."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, _ = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=36000000, grad_timeout_micros=0,
          batching_queue="")
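      # batch() returns the batched tensor, an index tensor with one row per
      # originating invocation (used by unbatch to route results back), and a
      # batch id. The 36-second timeout ensures both feeds below join the
      # same batch rather than timing out separately.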
      thread_results = []

      def worker():
        thread_results.extend(
            sess.run([batched, index], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([batched, index], feed_dict={inp: [2]})
      worker_thread.join()

      # At this point either the worker thread or the main thread received
      # the batch; the other should have empty results.
      if list(thread_results[0][0]):
        batch_t = thread_results[0][0]
        index_t = thread_results[1]
        empty_b = main_results[0][0]
        empty_m = main_results[1]
      else:
        batch_t = main_results[0][0]
        index_t = main_results[1]
        empty_b = thread_results[0][0]
        empty_m = thread_results[1]

      # Check that both inputs made it out exactly once.
      self.assertAllEqual(sorted(batch_t), (1, 2))
      # Check that we get 2 rows in the index tensor.
      self.assertEqual(len(index_t), 2)
      # Check that the other results are empty.
      self.assertEqual(len(empty_b), 0)
      self.assertEqual(len(empty_m), 0)

  def testBatchWithPadding(self):
    """Tests that batching with padding up to an allowed batch size works."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
      batched, index, _ = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[5, 10],
          grad_timeout_micros=0, batching_queue="")
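      # Two concurrent feeds of two elements each give 4 rows, which the op
      # pads up to 5, the smallest allowed batch size.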
      thread_results = []

      def worker():
        thread_results.extend(
            sess.run([batched, index], feed_dict={inp: [1, 3]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([batched, index], feed_dict={inp: [2, 4]})
      worker_thread.join()

      # At this point either the worker thread or the main thread received
      # the batch; the other should have empty results.
      if list(thread_results[0][0]):
        batch_t = thread_results[0][0]
      else:
        batch_t = main_results[0][0]

      # Check that the batch tensor incorporates the padding.
      self.assertEqual(len(batch_t), 5)

  def testMultipleBatch(self):
    """Tests that multiple batched tensors execute together."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, _, _ = batch_ops.batch(
          [inp0, inp1],
          num_batch_threads=1,
          max_batch_size=2,
          batch_timeout_micros=36000000,
          grad_timeout_micros=0,
          batching_queue="")
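      # With multiple input tensors, each invocation's tensors are batched
      # jointly: row i of every batched output corresponds to the same
      # invocation.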
      thread_results = []

      def worker():
        thread_results.extend(
            sess.run([batched], feed_dict={inp0: [1],
                                           inp1: [2]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([batched], feed_dict={inp0: [2], inp1: [3]})
      worker_thread.join()

      # At this point either the worker thread or the main thread received
      # the batch; the other should have empty results.
      if list(thread_results[0][0]):
        batch_t = thread_results[0]
        empty_t = main_results[0]
      else:
        batch_t = main_results[0]
        empty_t = thread_results[0]

      # Assert that the tensors were batched together.
      self.assertAllEqual(sorted(batch_t[0]), [1, 2])
      self.assertAllEqual(sorted(batch_t[1]), [2, 3])
      self.assertAllEqual(empty_t[0], [])
      self.assertAllEqual(empty_t[1], [])

  def testIllegalBatchDifferentDim0Sizes(self):
    """Tests illegally feeding tensors with different dim0 sizes."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      inp0 = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      inp1 = array_ops.placeholder(dtype=dtypes.int32, shape=[2])
      batched, index, _ = batch_ops.batch(
          [inp0, inp1], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=0, grad_timeout_micros=0, batching_queue="")
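      # Batching concatenates along dimension 0, so all tensors in a single
      # invocation must agree on their 0th-dimension size.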
      with self.assertRaises(Exception) as raised:
        _ = sess.run([batched, index], feed_dict={inp0: [0], inp1: [1, 2]})
      self.assertGreater(
          raised.exception.message.find("must have equal 0th-dimension size"),
          0)

  def testBasicUnbatch(self):
    """Tests that batch and unbatch work together."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[3, 10],
          grad_timeout_micros=0, batching_queue="")
      computation = batched[0] + 1
      result = batch_ops.unbatch(computation, index, id_t,
                                 timeout_micros=1000000, shared_name="unbatch")
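      # unbatch() uses the index and id tensors emitted by batch() to route
      # each row of the computed batch back to the session.run call that fed
      # the corresponding input.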
      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBasicUnbatchDecorated(self):
    """Tests that the batch_function decorator works."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      # TODO(apassos): Removing this line causes test flakiness! Ideally should
      # be investigated.
      default_inp = array_ops.placeholder_with_default(2, shape=[])  # pylint: disable=unused-variable

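      # batch_function(num_batch_threads, max_batch_size, batch_timeout_micros)
      # wraps the decorated function so that tensors from concurrent calls are
      # batched together before the function body runs, then unbatched after.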
      @batch_ops.batch_function(1, 10, 100000)
      def computation(in_t):
        self.assertIsNotNone(in_t.shape)
        return in_t + 1

      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      result = computation(inp)
      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBatchDecoratedWithCapturedInput(self):
    """Tests that the batch_function decorator works with captured inputs."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
      captured_inp1 = array_ops.placeholder_with_default(1, shape=[])

      @batch_ops.batch_function(1, 10, 100000)
      def computation(in_t):
        return in_t + captured_inp0 - captured_inp1

      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      result = computation(inp)
      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBatchFunctionOp(self):
    """Tests that the batch_function op works."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:

      @function.Defun(dtypes.int32)
      def computation(in_t):
        return in_t + 1

      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      result = gen_batch_ops.batch_function(
          [inp],
          num_batch_threads=1,
          max_batch_size=10,
          batch_timeout_micros=100000,
          Tout=[dtypes.int32],
          f=computation,
          captured_tensors=computation.captured_inputs)
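      # The raw batch_function op takes the Defun itself (f), any tensors the
      # function captured from the enclosing graph, and the output dtypes
      # (Tout) explicitly.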
      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBatchFunctionOpWithCapturedInput(self):
    """Tests that batch_function op works with captured input."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      captured_inp0 = array_ops.placeholder_with_default(2, shape=[])
      captured_inp1 = array_ops.placeholder_with_default(1, shape=[])
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

      @function.Defun(dtypes.int32)
      def computation(inp):
        return inp + captured_inp0 - captured_inp1

      result = gen_batch_ops.batch_function(
          num_batch_threads=1,
          max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          allowed_batch_sizes=[3, 10],
          batching_queue="",
          f=computation,
          in_tensors=[inp],
          captured_tensors=computation.captured_inputs,
          Tout=[o.type for o in computation.definition.signature.output_arg])
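      # captured_inp0/1 are closed over by `computation`; they reach the op
      # through captured_tensors and are passed along unbatched.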

      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [2]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testBatchFunctionOpWithInputError(self):
    """Tests that the batch_function op surfaces errors in the inputs."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])

      @function.Defun(dtypes.int32, dtypes.int32)
      def computation(in0, in1):
        return in0 + in1

      result = gen_batch_ops.batch_function(
          [inp],  # computation actually expects 2 inputs.
          num_batch_threads=1,
          max_batch_size=10,
          batch_timeout_micros=100000,  # 100ms
          batching_queue="",
          f=computation,
          captured_tensors=computation.captured_inputs,
          Tout=[o.type for o in computation.definition.signature.output_arg])

      with self.assertRaisesRegexp(InvalidArgumentError,
                                   ".*2 arguments.*but 1.*"):
        sess.run([result], feed_dict={inp: [2]})

  def testBasicUnbatchDecoratedWithReshape(self):
    """Tests that the batch_function decorator works with a reshape."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:

      @batch_ops.batch_function(1, 10, 100000)
      def computation(in_t):
        return array_ops.reshape(in_t, [-1]) + 1

      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1, 1])
      result = computation(inp)
      thread_results = []

      def worker():
        thread_results.extend(sess.run([result], feed_dict={inp: [[1]]}))

      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      main_results = sess.run([result], feed_dict={inp: [[2]]})
      worker_thread.join()
      self.assertEqual(thread_results[0], [2])
      self.assertEqual(main_results[0], [3])

  def testUnbatchTimeout(self):
    """Tests that the unbatch timeout works."""
    if context.executing_eagerly():
      return
    with self.cached_session() as sess:
      inp = array_ops.placeholder(dtype=dtypes.int32, shape=[1])
      batched, index, id_t = batch_ops.batch(
          [inp], num_batch_threads=1, max_batch_size=2,
          batch_timeout_micros=36000000, grad_timeout_micros=0,
          batching_queue="")
      computation = batched[0] + 1
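      # A 10-microsecond unbatch timeout: far shorter than the 100ms py_func
      # delay below, so the non-delayed unbatch gives up before the batched
      # tensor arrives.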
      timeout_micros = 10
      result = batch_ops.unbatch(computation, index, id_t, timeout_micros,
                                 shared_name="shared_unbatch")
      # Set up a parallel pipeline that delays the computation, but uses the
      # same unbatch resource object as the non-delayed pipeline.
      computation_delayed = script_ops.py_func(delayed_plus1,
                                               [batched[0]],
                                               dtypes.int32)
      result_delayed = batch_ops.unbatch(computation_delayed,
                                         index,
                                         id_t,
                                         timeout_micros,
                                         shared_name="shared_unbatch")

      thread_results = []
      def worker():
        # A first call using the non-delayed pipeline. The batcher will send an
        # empty tensor along the non-delayed pipeline.
        thread_results.extend(sess.run([result], feed_dict={inp: [1]}))
      worker_thread = threading.Thread(target=worker)
      worker_thread.start()
      time.sleep(0.1)  # Ensure the thread's call starts first.
      # A second call using the delayed pipeline. The batcher will send the
      # batched tensor along the delayed pipeline, thus delaying the arrival of
      # the batched tensor at the unbatch op, relative to the empty tensor.
      #
      # TODO(olston, apassos): Avoid relying on the order in which the batch op
      # emits the empty tensor versus the batched one.
      _ = sess.run([result_delayed], feed_dict={inp: [2]})
      worker_thread.join()
      # The thread's call should hit the timeout, and thus get 0 results.
      self.assertEqual(len(thread_results), 0)


if __name__ == "__main__":
  test.main()