# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Library for constructing a training loop, suitable for TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_function


def while_loop(condition, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`. Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed from `infeed_queue`
  if `infeed_queue` is not None; `condition` never receives infeed.
  `condition` must return a single boolean value that determines whether
  iteration continues. `body` must return an updated list of values for the
  loop-carried tensors, optionally followed by Operations that must run on
  every iteration.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to `body` (not to `condition`).
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature, or if the
      dtypes of the Tensors returned by body differ from those of `inputs`.
    ValueError: if body returns an Operation before a Tensor, or if an
      infeed queue is used without an enclosing TPU shard context.
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x)
                                      for x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  # `body` receives the loop-carried values plus any infeed tuple elements.
  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s, but the loop body needs %s" % (
              input_arity, str([i.name for i in inputs]), body_arg_error))
    else:
      raise TypeError(
          "Supplied loop body function cannot be called with the specified "
          "inputs. You specified %d inputs: %s and %d additional inputs from "
          "infeed, but the computation needs %s" % (input_arity, str(
              [i.name for i in inputs]), infeed_queue.number_of_tuple_elements,
                                                   body_arg_error))
  # `condition` only ever receives the loop-carried values (no infeed).
  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s" % (input_arity, str([i.name for i in inputs]),
                                  condition_arg_error))
    else:
      raise TypeError(
          "Supplied loop condition function cannot be called with the "
          "specified inputs. You specified %d inputs: %s, but the loop "
          "condition needs %s. Note that infeed is not passed to the loop "
          "condition." % (input_arity, str([i.name for i in inputs]),
                          condition_arg_error))

  def condition_wrapper(*inputs):
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue is not None:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = list(infeed_queue.generate_dequeue_op())
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    # The partition above preserves order, so equality holds only when every
    # Tensor precedes every Operation.
    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed. It is kept in a list so that
    # `control_flow_ops.tuple` below always receives a sequence of tensors
    # (a bare graph-mode Tensor is not iterable), matching the loop_vars
    # structure `[array_ops.constant(0)]` used for arity-0 loops.
    if not output_tensors:
      output_tensors = [array_ops.constant(0)]

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  # parallel_iterations=1 keeps iterations strictly sequential.
  return control_flow_ops.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)


def repeat(n, body, inputs=None, infeed_queue=None, name=None):
  """Builds a training loop that executes a fixed number of iterations.

  The set of loop-carried tensors correspond to `inputs`.
  `body` must be a function that takes and returns the values of the
  loop-carried tensors. An integer iteration counter is prepended to the
  loop-carried values internally and is hidden from `body`.

  Args:
    n: the number of loop iterations
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop or
      None (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple
      of arguments as inputs to `body`.
    name: (Deprecated) Does nothing.
  Returns:
    The final values of the loop-carried tensors. If there are no
    loop-carried tensors, returns the Operation that produces the final
    value of the internal iteration counter, so the caller still has
    something to run.
  Raises:
    TypeError: if `body` has the wrong signature, or returns Tensors whose
      dtypes differ from those of `inputs`.
    ValueError: if `body` returns an Operation before a Tensor.
  """
  def _convert_to_list(xs):
    if not isinstance(xs, (list, tuple)):
      return [xs]
    else:
      return list(xs)

  def cond(i, *args):
    del args
    return i < n

  def body_wrapper(i, *args):
    # Advance the hidden counter and forward the user's loop-carried values.
    return [i + 1] + _convert_to_list(body(*args))

  inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
  outputs = while_loop(
      cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
  outputs = _convert_to_list(outputs)
  if len(outputs) == 1:
    # Returns the Op rather than an empty list.
    return outputs[0].op
  else:
    # Strip the hidden iteration counter.
    return outputs[1:]