• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15
16"""Library for constructing a training loop, suitable for TPUs."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.compiler.xla import xla
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import control_flow_ops
26from tensorflow.python.tpu import tensor_tracer
27from tensorflow.python.tpu import tpu_function
28
29
def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed from
  `infeed_queue` if `infeed_queue` is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors, optionally followed by Operations that should run
  on every iteration.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to `body`. Infeed is not passed to `condition`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors, as returned by
    `control_flow_ops.while_loop`. When `inputs` is empty, the loop carries a
    single dummy constant instead.

  Raises:
    TypeError: if body or condition has the wrong signature, or if the dtypes
      of the tensors returned by `body` do not match the dtypes of `inputs`.
    ValueError: if `body` returns a Tensor after an Operation, or if
      `infeed_queue` is given but the TPU context has no number of shards.
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
                                      x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  # Validates the signature of `body`; it receives the loop-carried tensors
  # plus the infeed tuple elements, if any.
  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s, but the loop body needs %s" % (
              input_arity, str([i.name for i in inputs]), body_arg_error))
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s and %d additional inputs from "
          "infeed, but the computation needs %s" % (input_arity, str(
              [i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
                                                    body_arg_error))

  # Validates the signature of `condition`; it never receives infeed.
  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s" % (input_arity, str([i.name for i in inputs]),
                                  condition_arg_error))
    else:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s. Note that infeed is not passed to the loop "
          "condition." % (input_arity, str([i.name for i in inputs]),
                          condition_arg_error))

  def condition_wrapper(*inputs):
    """Wrapper around `condition` that hides the arity-0 dummy value."""
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended. The check is `is not None`
    # so that it agrees with the argument-count validation above.
    if infeed_queue is not None:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = list(infeed_queue.generate_dequeue_op())
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    # The partitioned lists only reproduce `outputs` when every Tensor
    # precedes every Operation.
    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed. It is kept in a list so that it mirrors
    # the dummy loop input below and can be passed to
    # `control_flow_ops.tuple`, which takes a list of tensors.
    if not output_tensors:
      output_tensors = [array_ops.constant(0)]

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)
179
180
def repeat(n, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop that runs `body` for `n` iterations.

  An iteration counter starting at 0 is prepended to the loop-carried values
  and the loop is built with `while_loop`, continuing while the counter is
  less than `n`. `body` never sees the counter: it takes the remaining
  loop-carried values (plus any extra arguments `while_loop` appends from
  `infeed_queue`) and returns their updated values.

  Args:
    n: the number of loop iterations
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop or
      None (equivalent to an empty list). A single non-list value is treated
      as a one-element list.
    infeed_queue: if not None, the infeed queue forwarded to `while_loop`,
      from which a tuple of arguments is appended to the inputs of `body`.
    name: (Deprecated) Does nothing.
  Returns:
    The final values of the loop-carried tensors, without the iteration
    counter. If the counter is the only loop-carried value, the Operation
    producing its final value is returned instead of an empty list.
  Raises:
    TypeError: propagated from `while_loop` when `body` has the wrong
      signature or returns values whose dtypes differ from the inputs.
    ValueError: propagated from `while_loop` when the outputs of `body` are
      ordered incorrectly or infeed is used without a shard context.
  """
  def _as_list(value):
    # Lists and tuples are copied into a new list; anything else is wrapped.
    return list(value) if isinstance(value, (list, tuple)) else [value]

  def _counter_below_n(counter, *unused_loop_vars):
    # Only the leading counter decides whether the loop continues.
    return counter < n

  def _step(counter, *loop_vars):
    # Advances the counter and runs the user body on everything after it.
    return [counter + 1] + _as_list(body(*loop_vars))

  loop_inputs = [0]
  if inputs is not None:
    loop_inputs += _as_list(inputs)

  results = _as_list(
      while_loop(
          _counter_below_n,
          _step,
          inputs=loop_inputs,
          infeed_queue=infeed_queue,
          name=name))

  if len(results) == 1:
    # Returns the Op rather than an empty list.
    return results[0].op
  return results[1:]
223