• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Operations for automatic batching and unbatching."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import function
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import tensor_spec
24from tensorflow.python.ops import gen_batch_ops
25# pylint: disable=wildcard-import
26from tensorflow.python.ops.gen_batch_ops import *
27# pylint: enable=wildcard-import
28from tensorflow.python.util import nest
29from tensorflow.python.util.tf_export import tf_export
30
31
@tf_export("nondifferentiable_batch_function")
def batch_function(num_batch_threads,
                   max_batch_size,
                   batch_timeout_micros,
                   allowed_batch_sizes=None,
                   max_enqueued_batches=10,
                   autograph=True,
                   enable_large_batch_splitting=True):
  """Batches the computation done by the decorated function.

  So, for example, in the following code

  ```python
  @batch_function(1, 2, 3)
  def layer(a):
    return tf.matmul(a, a)

  b = layer(w)
  ```

  if more than one session.run call is simultaneously trying to compute `b`
  the values of `w` will be gathered, non-deterministically concatenated
  along the first axis, and only one thread will run the computation. See the
  documentation of the `Batch` op for more details.

  Assumes that all arguments of the decorated function are Tensors which will
  be batched along their first dimension.

  SparseTensor is not supported. The return value of the decorated function
  must be a Tensor or a list/tuple of Tensors.

  Args:
    num_batch_threads: Number of scheduling threads for processing batches
     of work. Determines the number of batches processed in parallel.
    max_batch_size: Batch sizes will never be bigger than this.
    batch_timeout_micros: Maximum number of microseconds to wait before
     outputting an incomplete batch.
    allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,
     does nothing. Otherwise, supplies a list of batch sizes, causing the op
     to pad batches up to one of those sizes. The entries must increase
     monotonically, and the final entry must equal max_batch_size.
    max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10.
    autograph: Whether to use autograph to compile python and eager style code
     for efficient graph-mode execution.
    enable_large_batch_splitting: The value of this option doesn't affect
     processing output given the same input; it affects implementation details
     as stated below: 1. Improve batching efficiency by eliminating unnecessary
     padding. 2. When enabled, `max_batch_size` bounds only the size of an
     *input* batch, while `allowed_batch_sizes` bounds the size of each task
     actually processed. For example, a caller may pass an input of size 128
     when the largest allowed execution size is 32: the implementation can
     split the 128 into 4 x 32, schedule concurrent processing, and return
     the concatenated results corresponding to the full 128.

  Returns:
    The decorated function will return the unbatched computation output
    Tensors.

  Raises:
    ValueError: (from the decorated function) if any positional argument is
      not a `tf.Tensor`.
  """

  def decorator(fn):  # pylint: disable=missing-docstring

    def decorated(*args):  # pylint: disable=missing-docstring
      # Validate up front: a non-Tensor argument (e.g. a plain list) has no
      # `.dtype`/`.shape`, so deferring this check until after the TensorSpec
      # comprehension below would surface an AttributeError instead of this
      # intended, actionable error.
      for a in args:
        if not isinstance(a, ops.Tensor):
          raise ValueError("All arguments to functions decorated with "
                           "`batch_function` are supposed to be Tensors; "
                           "found %s" % repr(a))

      @function.defun(autograph=autograph)
      def computation(*computation_args):
        return fn(*computation_args)

      # Trace the function once with specs matching the actual arguments so
      # a concrete (graph) function can be handed to the BatchFunction op.
      computation = computation.get_concrete_function(
          *[tensor_spec.TensorSpec(dtype=x.dtype, shape=x.shape, name=str(i))
            for i, x in enumerate(args)])

      with ops.name_scope("batch") as name:
        outputs = gen_batch_ops.batch_function(
            num_batch_threads=num_batch_threads,
            max_batch_size=max_batch_size,
            batch_timeout_micros=batch_timeout_micros,
            allowed_batch_sizes=allowed_batch_sizes,
            max_enqueued_batches=max_enqueued_batches,
            shared_name=name,
            enable_large_batch_splitting=enable_large_batch_splitting,
            f=computation,
            in_tensors=list(args),
            captured_tensors=computation.captured_inputs,
            Tout=[o.dtype for o in computation.outputs])
        # Restore the decorated function's original output structure
        # (e.g. tuple/list/nest) from the op's flat output list.
        return nest.pack_sequence_as(
            computation.structured_outputs, outputs, expand_composites=True)

    return decorated

  return decorator
126