# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""while_v2 and gradient.

This is a version of while_loop that emits a single While op, as well as the
gradient function for While ops produced by while_loop. This will eventually
replace the current tf.while_loop implementation once it reaches feature and
performance parity.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util as util_v1
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import while_v2_indexed_slices_rewriter
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity

# pylint: disable=protected-access


def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
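  # Illustrative sketch only (not executed here): for a caller passing a
  # TensorArray as a loop var, e.g.
  #   ta = tensor_array_ops.TensorArray(dtypes.float32, size=0,
  #                                     dynamic_size=True)
  #   while_loop(cond, body, [i, ta])
  # only `ta.flow` is threaded through the While op below; the TensorArray is
  # rebuilt around the flow by `_pack_sequence_as` before `cond`/`body` run.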
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)

  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tf_type(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # The function was created with a signature rather than tensors, so
      # internal placeholders were created without handle data.
      _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True),
                        nest.flatten(args, expand_composites=True))
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      _duplicate_body_captures_in_cond(
          cond_graph, body_graph.external_captures[num_cond_captures:])

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
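    # At this point (see the construction above):
    #   loop_vars          = [counter, max_iters, *flow-converted loop vars,
    #                         *external captures]
    #   body_graph.outputs = [counter + 1, max_iters, *body outputs,
    #                         *internal captures]
    # so the user-visible loop vars start at index 2 in both lists.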
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      outputs = _build_while_op(
          flattened_loop_vars,
          cond_graph,
          body_graph,
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope,
          num_original_outputs=num_original_outputs)
    if not ops.get_default_graph().building_function:
      # In V1 graph mode, return identities for each output of the While op,
      # rather than the output of the While op directly. This makes pruning work
      # if the output of while_loop() is fetched: the lowering pass converts the
      # While outputs into IdentityN outputs, which if fetched will cause all
      # ops in the body to be run (since it takes all exit ops as input). After
      # lowering, each output identity op will end up with only the appropriate
      # exit op as input.
      outputs = tuple(array_ops.identity(t) for t in outputs)

  output_loop_vars = outputs[first_loop_var_index:first_loop_var_index +
                             num_flattened_outputs]
  if not back_prop:
    output_loop_vars = [array_ops.stop_gradient(t) for t in output_loop_vars]
  outputs = _pack_sequence_as(orig_loop_vars, output_loop_vars)

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs


@ops.RegisterGradient("StatelessWhile")
@ops.RegisterGradient("While")
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond", "_cond_graph")
  body_graph = _get_graph(while_op, "body", "_body_graph")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  try:
    num_original_outputs = while_op.get_attr("_num_original_outputs")
  except:  # pylint: disable=bare-except
    num_original_outputs = len(while_op.outputs)

  num_intermediates = len(while_op.outputs) - num_original_outputs
  grads = [
      _preprocess_grad(grad, body_out, while_in, while_out)  # pylint: disable=g-complex-comprehension
      for grad, body_out, while_in, while_out in zip(
          grads[:num_original_outputs],
          body_graph.outputs[:num_original_outputs],
          while_op.inputs[:num_original_outputs],
          while_op.outputs[:num_original_outputs])
  ] + [None] * num_intermediates

  # Skip gradients with respect to the captures whenever possible.
  if "skip_input_indices" in op.__dict__ and op.skip_input_indices is not None:
    captures_start_index = (
        len(body_graph.inputs) - len(body_graph.internal_captures))
    for i in op.skip_input_indices:
      if i >= captures_start_index:
        grads[i] = None

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new
    # `body_graph.external_captures` added during `_create_grad_func`.
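    # The rewrite below grows the already-created forward While op in place:
    # these tensor lists become additional op inputs and outputs, and the
    # serialized cond/body function attrs are refreshed to match the updated
    # signatures.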
    new_inputs = body_grad_graph.extra_inputs
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs):
      # Continuing leads to an invalid graph with disconnected inputs.
      raise AssertionError(
          "Inputs and outputs constructed for the forward op of a While "
          "gradient don't match. This doesn't make sense, please file a bug.")
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])

    # Do not ignore grads wrt extra outputs when computing higher order
    # derivatives.
    while_op._set_attr("_num_original_outputs",
                       attr_value_pb2.AttrValue(i=len(while_op.outputs)))

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = _build_while_op(
      loop_vars,
      cond_grad_graph,
      body_grad_graph,
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name,
      num_original_outputs=len(body_grad_graph.outputs))

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)


def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
                    parallel_iterations, name, num_original_outputs):
  """Builds the functional StatelessWhile/While op."""
  cond_stateful_ops = [
      op for op in cond_graph.get_operations() if op._is_stateful
  ]
  body_stateful_ops = [
      op for op in body_graph.get_operations() if op._is_stateful
  ]
  if (cond_stateful_ops or body_stateful_ops):
    op_fn = gen_functional_ops._while
  else:
    op_fn = gen_functional_ops.stateless_while

  def _make_op(inputs):
    while_op, tensors = util.get_op_and_outputs(op_fn(
        inputs,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=output_shapes,
        parallel_iterations=parallel_iterations,
        name=name))
    _copy_handle_data(body_graph.outputs, tensors)
    util.maybe_set_lowering_attr(while_op)
    util.maybe_propagate_compile_time_consts_in_xla(while_op)
    _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
    # This is needed so we do not compute derivative wrt these extra outputs.
    while_op._set_attr("_num_original_outputs",
                       attr_value_pb2.AttrValue(i=num_original_outputs))
    # The while op may be created inside a tf.function, in which case ops
    # needs to capture "through" it when taking gradients; outer_graph is used
    # as a sanity check that capturing only happens from parent to child.
    cond_graph.outer_graph = ops.get_default_graph()
    body_graph.outer_graph = ops.get_default_graph()
    while_op._cond_graph = cond_graph
    while_op._body_graph = body_graph
    return tensors
  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)


def _get_intermediates(func_graph):
  """Returns all tensors in `func_graph` that should be accumulated."""
  # We currently accumulate output tensors of most ops in the function and rely
  # on the pruning pass to get rid of the unused accumulators at runtime.
  # However, this can bloat the GraphDef and make debugging harder so we perform
  # some optimizations.
  #
  # Optimization we currently perform:
  # 1. We do not accumulate tensors which already have an accumulator
  #    in the loop body.
  # 2. We do not accumulate outputs of Identity nodes. When building the
  #    FuncGraph, we add an Identity node for each output (see
  #    `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
  #    of all these nodes bloats the GraphDef quite a bit so we remove those.
  #    Since the gradient of an Identity node does not rely on its forward op's
  #    input this is safe to do.
  #
  # Other possible optimizations:
  # 1. Only accumulate tensors that will be required by the backward pass.
  #    This will require running the gradient pass and hence would increase the
  #    graph building time for the forward pass.
  # 2. Do not accumulate Const nodes created inside the loop body.
  # 3. Do not accumulate loop vars that are returned as-is just like captured
  #    tensors.
  intermediates = []
  reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)

  for op in func_graph.get_operations():
    if op.type == "Identity":
      continue
    # Accumulating mutexes can cause deadlock.
    if op.type == "MutexLock":
      continue
    for o in op.outputs:
      if (o is not func_graph.inputs[0] and  # Loop counter.
          o.dtype != dtypes.resource and  # Do not accumulate resource tensors.
          _get_accumulator(o) is None and  # Has existing accumulator.
          o.ref() not in reverse_captures
         ):  # Captured value, hence loop invariant.
        intermediates.append(o)
  return intermediates


def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
  """Returns the initial gradient to be used for a given output tensor.

  Args:
    grad: the original gradient Tensor passed to the gradient function.
    body_graph_output: the corresponding Tensor in the body graph.
    while_op_input: the corresponding Tensor input of the While op.
    while_op_output: the corresponding Tensor output of the While op.

  Returns:
    A Tensor or None.
  """
  # Set the incoming gradient of non-trainable inputs to None. It is possible
  # that we receive non-None gradients for non-trainable types in nested while
  # loops because we accumulate outputs of the inner while as variant tensors
  # which are trainable and hence receive zeros_like tensors in the gradient
  # pass. The non-trainable tensors then receive the popped zeros tensor from
  # this zeros variant. The gradient for the loop vars corresponding to these
  # tensors is None or zeros (this happens only if the loop var is accumulated
  # as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
  if not _is_trainable(body_graph_output):
    return None

  # GradientTape initializes resource and variant grads as None instead of
  # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
  # returning None.
  # TODO(b/143286622): The supports_default_grad check is needed
  # because While op emits non-differentiable resource tensors
  # as outputs. Remove this check when that is not the case.
  # Note: We use `while_op_input` instead of `while_op_output` for the call
  # to `supports_default_grad` because `while_op_output` may be missing
  # handle_data if the While is in a restored saved model.
  if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
      default_gradient.supports_default_grad(while_op_input) and grad is None):
    return _zeros_like(while_op_input, while_op_output)

  # Convert IndexedSlices to dense tensors since it is unlikely that downstream
  # gradient functions will properly handle indexed slices. This is similar to
  # what we do in tf.function gradients.
  if isinstance(grad, ops.IndexedSlices):
    return ops.convert_to_tensor(grad)

  return grad


# TODO(skyewm): make this return constants if op_output's shape is fully
# defined (this can be done by checking the "shape" attr of resource vars).
def _zeros_like(op_input, op_output):
  """Like array_ops.zeros_like() but also accepts resource var handles."""
  if op_output.dtype == dtypes.resource:
    # Note: We use `op_input` instead of `op_output` to get the zeros dtype
    # because `op_output` may be missing handle_data if the While is in a
    # restored saved model.
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(op_output),
        dtype=default_gradient.get_zeros_dtype(op_input))
  return array_ops.zeros_like(op_output)


def _is_trainable(tensor):
  """Returns whether the given tensor is trainable."""
  if not backprop_util.IsTrainable(tensor):
    return False

  # Special case: untrainable accumulator output. The gradients algorithm
  # doesn't know about tensor lists of untrainable elements. In theory the
  # tensor list gradient functions should return None as appropriate, but
  # because we can't return None from the gradient function we filter out
  # untrainable accumulator output here to avoid computing the gradient at all.
  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
    assert tensor.dtype == dtypes.variant
    element_type = tensor.op.get_attr("element_dtype")
    return backprop_util.IsTrainable(element_type)

  return True


def _get_graph(while_op, func_attr_name, attr_graph_name):
  """Returns `FuncGraph` for the given function attribute.

  Args:
    while_op: The While Operation.
    func_attr_name: string
    attr_graph_name: cached forward graph name

  Returns:
    `FuncGraph`
  """
  func_graph = getattr(while_op, attr_graph_name, None)
  if func_graph is None:
    # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
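    # The While op's inputs and outputs correspond one-to-one to the loop
    # vars, so its `output_shapes` attr also describes the input shapes used
    # to rebuild the cond/body FuncGraphs from their function defs below.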
    input_shapes = [
        tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
    ]
    func_name = while_op.get_attr(func_attr_name).name
    func_graph = util.get_func_graph(while_op, input_shapes, func_name)
  func_graph._while = while_op
  return func_graph


def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  # Build frozen sets so that we do not have linear time lookups in
  # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
  # may get updated during gradient computation because we add accumulators to
  # the forward op. However, those are not loop invariants so wouldn't affect
  # the output of `_is_loop_invariant`. Also we would never attempt to capture
  # those accumulators so `_is_loop_invariant` should never receive those new
  # tensors as args.
  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)

  args = [counter, maximum_iterations, total_iters] + list(grads)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(ys, xs, args, body_graph),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                         maximum_iterations, while_op,
                                         body_graph_inputs,
                                         body_graph_outputs))

  # Update the list of outputs with tensors corresponding to the captured
  # tensors. We capture 3 types of tensors when building the grad fn:
  # 1. Accumulators for forward graph intermediates which are not loop
  #    invariants. The outputs corresponding to these are populated in
  #    `internal_capture_to_output` by `_WhileBodyGradFuncGraph`.
  # 2. Resources, which are output as is.
  # 3. Forward graph loop invariants, which are output as is.
  for external_capture, internal_capture in grad_func_graph.captures:
    if (ops.tensor_id(internal_capture)
        in grad_func_graph.internal_capture_to_output):
      new_output = grad_func_graph.internal_capture_to_output[ops.tensor_id(
          internal_capture)]
    else:
      raise ValueError(
          "Tensor %s which captures %s is in list of "
          "internal_captures but not in internal_capture_to_output." %
          (str(internal_capture), str(external_capture)))
    grad_func_graph.outputs.append(new_output)
    grad_func_graph.structured_outputs.append(new_output)

  return grad_func_graph, args


def _grad_fn(ys, xs, args, func_graph):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - Total number of iterations.
      args[2] - maximum_iterations.
      args[3:] - Incoming gradients for `ys`.
    func_graph: function.FuncGraph. The corresponding forward-pass function.

  Returns:
    The output gradient Tensors.
  """
  grad_ys = args[3:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # after the forward While op has been rewritten in _resolve_grad_captures.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_util._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph,
      unconnected_gradients="zero")

  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
  # is a tf.StopGradient in the loop body.
  assert all(g is not None for g in grad_outs)
  counter = args[0]
  maximum_iterations = args[1]
  total_iters = args[2]
  return [counter + 1, maximum_iterations, total_iters] + grad_outs


def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
  """Returns the tensors to pass as captured inputs to `body_grad_graph`.

  `body_grad_graph` may have external references to:
  1. Its outer graph containing the input gradients. These are left as-is.
  2. Accumulators captured from the forward-pass graph. These should have been
     added as `while_op` outputs after the gradient graph was built. We replace
     these with the corresponding output of `while_op`, i.e. a tensor in
     `body_graph.outer_graph`. In the case of nested control flow or functions,
     the gradient logic handling `body_grad_graph.outer_graph` will make sure
     the tensor from `body_graph.outer_graph` is also correctly captured.

  Args:
    body_graph: FuncGraph. The forward-pass body function.
    body_grad_graph: FuncGraph. The body gradients function.
    while_op: The forward-pass While Operation calling `body_graph`.

  Returns:
    A list of input tensors to be passed as the captured inputs to
    `body_grad_graph`.
  """
  new_capture_inputs = []
  for t in body_grad_graph.external_captures:
    # Resolve tensors captured from the forward graph to the outputs of the
    # forward while_op.
    if t.graph == body_graph:
      # Captured accumulator or loop invariant.
      for i, output in enumerate(t.graph.outputs):
        if output is t:
          t = while_op.outputs[i]
          break

      # Note: We rely on the capturing logic of the gradient While op graph to
      # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
      # and while_v2 handle this while building their gradient functions.
      assert t.graph == body_graph.outer_graph

    new_capture_inputs.append(t)
  return new_capture_inputs


def _get_structured_grad_output(outputs, grads, body_grad_graph):
  """Returns the values that should be returned from the while grad function.

  Args:
    outputs: the raw Tensor outputs of the grad While op.
    grads: the input gradients to the gradient function.
    body_grad_graph: _WhileBodyGradFuncGraph.

  Returns:
    A list of gradient values. May include Nones.
  """
  result = []
  # outputs[0] is the loop counter.
  # outputs[1] is maximum_iterations.
  # outputs[2] is the total number of loop iterations.
  outputs_idx = 3
  structured_outputs_idx = 3
  for g in grads:
    # Set None as the output gradient for tensors with None input gradient.
    if g is None:
      result.append(None)
      continue
    output = body_grad_graph.structured_outputs[structured_outputs_idx]
    structured_outputs_idx += 1
    if isinstance(output, ops.IndexedSlices):
      # TODO(skyewm): is there a more robust way to determine the order of
      # flattened IndexedSlices components?
      result.append(ops.IndexedSlices(
          values=outputs[outputs_idx],
          indices=outputs[outputs_idx + 1],
          dense_shape=outputs[outputs_idx + 2]))
      outputs_idx += 3
    else:
      assert isinstance(output, ops.Tensor)
      result.append(outputs[outputs_idx])
      outputs_idx += 1

  return result


def _get_accumulator(tensor):
  r"""Returns TensorList if any containing accumulated values of tensor.

  We try to find a pattern of the form:

     input_tl   tensor
        \        /
     (TensorListPushBack)
            |
        output_tl

  which satisfies the following conditions:

  1. input_tl must be in tensor.graph.inputs.
  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).

  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
  returned if such a pattern is found else None is returned.

  Args:
    tensor: The Tensor to be accumulated.

  Returns:
    A variant tensor in the same graph as `tensor` or None if no accumulator is
    found.
  """
  assert isinstance(tensor.graph, func_graph_module.FuncGraph)

  def get_func_graph_output(t):
    """Returns t or Identity(t) whichever exists in graph outputs else None."""
    for output in tensor.graph.outputs:
      if output is t:
        return t
    # tf.defun adds an Identity for each output, check whether that is the case.
    identity_op = t.consumers()[0]
    if (identity_op.type == "Identity" and
        any(identity_op.outputs[0] is t for t in tensor.graph.outputs)):
      return identity_op.outputs[0]
    return None

  for consumer in tensor.consumers():
    # Find the consumer that is a TensorListPushBack node whose TensorList input
    # is in the list of function inputs.
    if consumer.type != "TensorListPushBack":
      continue

    accum_input_idx = -1
    for accum_input_idx, inp in enumerate(tensor.graph.inputs):
      if inp is consumer.inputs[0]:
        break
    else:
      continue

    output = get_func_graph_output(consumer.outputs[0])
    if output is None:
      # The TensorList output of `consumer` is not in the list of function
      # outputs.
      continue

    for accum_output_idx, out in enumerate(tensor.graph.outputs):
      if out is output:
        if accum_input_idx == accum_output_idx:
          return output
        break

  return None


OptimizedReductionOpsCacheKey = collections.namedtuple(
    "OptimizedReductionOpsCacheKey", [
        "op_type",
        "inputs",
        "dtypes",
        "input_types",
        "name",
        "attrs",
        "op_def",
        "compute_device",
    ])


class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
  """FuncGraph for the gradient function of the body of a While op.

  Contains the logic for capturing the tensors from the body of the forward
  While op which is as follows:
  1. If the tensor is of resource type (these are not accumulated):
     a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
        inputs and outputs at the same index.
     b. Lookup the corresponding resource tensor in the forward outer graph and
        try to capture that.
  2. If the tensor is not of resource type:
     a. Create an accumulator for that tensor and output it from the forward
        pass. Note this also requires adding it as an input to the forward pass.
     b. Capture the accumulator from the forward pass in this FuncGraph. This
        will later be resolved to the correct output of the forward While op.
     c. Pop a value from the captured placeholder and use it as the captured
        value for the forward pass tensor.

  This only allows capturing tensors in the forward graph. A ValueError is
  raised if an attempt is made to capture a tensor not in the forward graph.
  To manually capture a tensor that is not in the forward graph, call `capture`
  with `allowlisted=True`.

  Note: The `captures` dict does not contain the forward tensor since it is not
  directly captured. It contains the accumulator corresponding to this forward
  tensor.

  Attributes:
    while_op_needs_rewrite: True if any non-resource intermediates were
      captured, meaning the forward While op needs to be rewritten to output
      the corresponding accumulators.
    extra_inputs: list of EmptyTensorList tensors to be used as initial input
      to the new accumulators in the forward graph. It may also contain
      external captures of the custom gradient function.
    internal_capture_to_output: dict from a tensor_id(captured placeholder) to
      the corresponding tensor that needs to be added to the list of outputs.
      For instance, when capturing an accumulator TensorList this contains the
      TensorList obtained after popping a tensor from the list. Other entries
      in this dict are expected, though not enforced, to be identities.
      This dict is needed because these output tensors need to be added to
      FuncGraph.outputs "after" the tensors returned from the gradient
      function.
  """

  def __init__(self, name, forward_cond_graph, forward_body_graph,
               maximum_iterations, forward_while_op, body_graph_inputs,
               body_graph_outputs):
    super(_WhileBodyGradFuncGraph, self).__init__(name)
    self.extra_inputs = []
    self.internal_capture_to_output = {}
    # FuncGraph for the body of the forward While op.
    self._forward_graph = forward_body_graph
    # FuncGraph for the cond of the forward While op.
    self._forward_cond_graph = forward_cond_graph
    self._maximum_iterations = maximum_iterations
    self._forward_while_op = forward_while_op
    # Dict from forward intermediate tensor to its indirectly captured tensor
    # in this graph. Indirect capturing happens in two ways:
    # 1. For non-resource tensors we capture their accumulators from the
    #    forward outer graph and pop values from that accumulator inside this
    #    graph using TensorListPopBack.
    # 2. For resource tensors we directly capture their corresponding tensor
    #    in the forward outer graph.
    self._indirect_captures = {}

  @property
  def while_op_needs_rewrite(self):
    return self.extra_inputs

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # For a reduction op, if op is in the gradient body graph and its input is
    # from the forward graph, moving op to the forward graph means we would
    # store the tensor after the reduction as opposed to the tensor before
    # reduction, and therefore could significantly reduce memory consumption.
    # For now, we do this only for a few ops.
    #
    # We don't do this if any input tensor has already been accumulated. This
    # can happen if we output all intermediates in the forward pass.
    #
    # If in XLA context, do not move constant ops to forward pass as pushing to
    # and popping from a TensorList removes the constant property of an op and
    # breaks XLA compilation, which requires certain inputs to be compile-time
    # constant for certain ops.
    #
    # This optimization is currently also disabled when under a persistent tape,
    # since it leads to an unbounded number of side outputs. With caching it may
    # be possible to re-enable it.
    optimized_reduction_ops = {
        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
    }
    if (op_type in optimized_reduction_ops and
        not util.output_all_intermediates() and
        all(input.graph is self._forward_graph for input in inputs) and
        all(_get_accumulator(input) is None for input in inputs) and
        not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
        not util.graph_wrapped_for_higher_order_tape_gradients(
            self._forward_graph)):
      return self._move_op_to_forward_graph(
          op_type,
          inputs,
          dtypes=dtypes,
          input_types=input_types,
          name=name,
          attrs=attrs,
          op_def=op_def,
          compute_device=compute_device)

    return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
        op_type,
        inputs,
        dtypes=dtypes,
        input_types=input_types,
        name=name,
        attrs=attrs,
        op_def=op_def,
        compute_device=compute_device)

  def _move_op_to_forward_graph(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # We have a cache of reduction ops that have already been moved to the
    # forward graph, and we will check it first to avoid moving an op twice.
    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
      self._forward_graph._optimized_reduction_ops_cache = {}
    cache_key = self._get_optimized_reduction_ops_cache_key(
        op_type, inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)
    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
        cache_key)
    if cached_op is not None:
      # This op has already been moved to the forward graph and we have it in
      # the cache.
      return cached_op

    with self._forward_graph.as_default():
      # `name` was built using name_scope stack of gradient graph and may not
      # be unique in the forward graph. `Graph.create_op` does not uniquify
      # names which are name scopes i.e. end in `/`. To ensure that the op
      # created gets a unique name in the forward graph we get rid of the
      # trailing slash.
      name = ops.name_from_scope_name(name)
      result = self._forward_graph._create_op_internal(
          op_type,
          inputs,
          dtypes=dtypes,
          input_types=input_types,
          name=name,
          attrs=attrs,
          op_def=op_def,
          compute_device=compute_device)

      # Store the op we just moved to the forward graph so that it does
      # not need to be added there again.
      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
      return result

  def _get_optimized_reduction_ops_cache_key(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # We need all elements of CacheKey to be hashable.
    inputs = tuple(map(lambda t: t.ref(), inputs))

    if dtypes is not None:
      dtypes = tuple(dtypes)

    if input_types is not None:
      input_types = tuple(input_types)

    if attrs is not None:
      hashable_attrs = []
      for attr_name, attr_value in sorted(attrs.items()):
        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
      attrs = tuple(hashable_attrs)

    if op_def is not None:
      op_def = op_def.SerializeToString()

    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
                                         name, attrs, op_def, compute_device)

  def _capture_helper(self, tensor, name):
    """Implements the capturing described in the class docstring."""
    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    if tensor.graph is not self._forward_graph:
      already_captured = self.captured(tensor)
      captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper(
          tensor, name)
      if not already_captured:
        # Adds the captured tensor to the list of outputs so that the input
        # and output signatures match.
        self.internal_capture_to_output[ops.tensor_id(
            captured_tensor)] = captured_tensor
        self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    # No need to accumulate loop invariants. Capture them directly.
    # The captured tensor gets resolved to the corresponding while output in
    # `_resolve_grad_captures`.
    if _is_loop_invariant(tensor, self._forward_graph.inputs,
                          self._forward_graph.outputs):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      # Add to `internal_capture_to_output` so that this gets added to the list
      # of outputs.
      self.internal_capture_to_output[ops.tensor_id(
          captured_tensor)] = captured_tensor
      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    # Do not accumulate Const nodes. Instead copy them directly in the backward
    # graph.
    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
    # graph compile time consts in general.
    # TODO(srbs): Consider making this a loop input.
    if constant_op.is_constant(tensor):
      real_value = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      self._indirect_captures[ops.tensor_id(tensor)] = real_value
      return real_value

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the forward While.
      #
      # E.g.:
      # x = tf.while_loop(...)
      # y = f(x)
      # with tf.control_dependencies([y]):
      #   tf.gradients(y, x)
      #
      # Since the EmptyTensorList is fed back into the forward While, not
      # removing the control edge would cause a cycle.
      with self._forward_graph.outer_graph.as_default():
        with util.clear_control_inputs():
          tensor_list = list_ops.empty_tensor_list(
              element_dtype=tensor.dtype,
              element_shape=tensor.shape,
              max_num_elements=self._maximum_iterations,
              name=_build_accumulator_name(tensor))
      self.extra_inputs.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will be
      # all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
    self.internal_capture_to_output[ops.tensor_id(
        captured_accumulator)] = new_tensor_list
    return captured_tensor

  def _resource_capture_helper(self, tensor):
    """Returns the captured resource tensor.

    Resource-type tensors are not accumulated. If a resource tensor exists in
    the loop body it must either be a loop input or an output of a nested While
    op inside the loop body which had captured the external resource.

    Args:
      tensor: the external resource Tensor to be captured.

    Returns:
      Tensor in this graph.
    """
    assert tensor.dtype == dtypes.resource

    index = util.resource_input_index(
        tensor.name, [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)

    input_placeholder = self._forward_graph.inputs[index]
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]

    assert input_placeholder.dtype == dtypes.resource
    assert tensor_in_outer_graph.dtype == dtypes.resource
    # This must be a loop invariant.
    assert input_placeholder is self._forward_graph.outputs[index], (
        "Resource tensors must be loop invariants %s." % tensor_in_outer_graph)

    self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
        tensor_in_outer_graph)
    return self._indirect_captures[ops.tensor_id(tensor)]


def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
  for (t, shape, input_t) in zip(output_tensors, shape_invariants,
                                 input_tensors):
    if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
      raise ValueError(
          "Input tensor '%s' enters the loop with shape %s, but has "
          "shape %s after one iteration. To allow the shape to vary across "
          "iterations, use the `shape_invariants` argument of tf.while_loop to "
          "specify a less-specific shape." % (input_t.name, shape, t.shape))


def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
  """Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
  assert len(cond_graph.inputs) == num_flattened_loop_vars, (
      "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(cond_graph.outputs) == 1, (
      "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
  assert len(body_graph.inputs) == num_flattened_loop_vars, (
      "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(body_graph.outputs) == num_flattened_loop_vars, (
      "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
                                                   num_flattened_loop_vars))


def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
  for inp, out, loop_var in zip(body_graph.inputs, body_graph.outputs,
                                flattened_loop_vars):
    if inp.dtype != out.dtype:
      raise TypeError("Loop var {} enters the loop with type {} "
                      "but has type {} after 1 iteration.".format(
                          loop_var.name, inp.dtype, out.dtype))


def _build_cond_placeholders_name_prefix(cond_graph):
  return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")


def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
  """Creates placeholders for body captures in cond_graph.

  This is needed to match signatures of cond and body graphs.

  Args:
    cond_graph: cond branch graph
    body_graph_captures: Tensors which were captured when building the
      `body_graph`.
1275 """ 1276 types = [t.dtype.as_datatype_enum for t in body_graph_captures] 1277 # TODO(srbs): Providing a unique prefix does not ensure that there is no 1278 # conflict between the placeholder names and existing nodes in the graph. 1279 # However passing a list of strings may not be performant. 1280 # Ideally we should move `Graph.unique_name` to C++ or make 1281 # `Graph._names_in_use` a trie so that we can find a unique prefix. 1282 # TODO(b/143286622): This should not be required once captures are separated 1283 # from regular loop vars. 1284 placeholders = c_api.TF_CreatePlaceholders( 1285 cond_graph._c_graph, types, 1286 compat.as_str(_build_cond_placeholders_name_prefix(cond_graph))) 1287 placeholder_ops = [ 1288 _OperationWithOutputs(ph.oper, cond_graph) 1289 for ph in placeholders 1290 ] 1291 1292 tensors = [] 1293 for op, ph, dtype in zip(placeholder_ops, placeholders, types): 1294 tensor = ops.Tensor._create_with_tf_output(op, 0, dtype, ph) 1295 op._outputs = [tensor] 1296 tensors.append(tensor) 1297 1298 # Update `cond_graph._captures` and `cond_graph.inputs` to contain the 1299 # newly created placeholders. 1300 tuples = zip(body_graph_captures, tensors) 1301 keys = [id(t) for t in body_graph_captures] 1302 cond_graph._captures.update(zip(keys, tuples)) 1303 cond_graph.inputs.extend(tensors) 1304 1305 1306def _copy_handle_data(src_tensors, tgt_tensors): 1307 for src_t, tgt_t in zip(src_tensors, tgt_tensors): 1308 custom_gradient.copy_handle_data(src_t, tgt_t) 1309 1310 1311def _graph_name(graph): 1312 if isinstance(graph, func_graph_module.FuncGraph): 1313 return graph.name 1314 return "Base" 1315 1316 1317def _pack_sequence_as(structure_with_tas, loop_vars): 1318 """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays.""" 1319 1320 def flow_to_tensor_array(flow, ta): # pylint: disable=missing-docstring 1321 return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance( # pylint: disable=g-long-ternary 1322 ta, tensor_array_ops.TensorArray) else flow) 1323 1324 flattened_loop_vars = [ 1325 flow_to_tensor_array(*z) 1326 for z in zip(nest.flatten(loop_vars, expand_composites=True), 1327 nest.flatten(structure_with_tas, expand_composites=True)) 1328 ] 1329 return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars, 1330 expand_composites=True) 1331 1332 1333def _tensor_array_to_flow(loop_vars): 1334 1335 def f(maybe_ta): 1336 if isinstance(maybe_ta, tensor_array_ops.TensorArray): 1337 return maybe_ta.flow 1338 return maybe_ta 1339 1340 return nest.map_structure(f, loop_vars, expand_composites=True) 1341 1342 1343def _build_maximum_iterations_loop_var(maximum_iterations): 1344 if maximum_iterations is None: 1345 # Default value for max_num_elements to EmptyTensorList meaning that the 1346 # list size is unbounded. 1347 maximum_iterations = -1 1348 # EmptyTensorList expects `max_num_elements` to be of type int32. 1349 return ops.convert_to_tensor( 1350 maximum_iterations, dtype=dtypes.int32, name="maximum_iterations") 1351 1352 1353def _build_accumulator_name(tensor): 1354 # Tensor name may be of the form "pow/y:0". Name scope does not allow ":". 1355 return "{}/accumulator".format(tensor.name).replace(":", "_") 1356 1357 1358def _is_loop_invariant(tensor, inputs, outputs): 1359 return (any(tensor is t for t in inputs) and 1360 any(tensor is t for t in outputs)) 1361 1362 1363class _OperationWithOutputs(ops.Operation): 1364 """Operation with pre-built `TF_Output`s. 

  The C API for creating the extra placeholders for the cond graph returns
  SWIG wrapped TF_Output* pointers which we can use directly for
  `Operation.outputs`. The default constructor for `Operation` does not provide
  a way of specifying pre-built output tensors and always creates them. This is
  a performance overhead. It is not clear if adding that feature to the
  `Operation` API would be generally useful so for now we just have our own
  lightweight `Operation` implementation. Note that this also does not extract
  a stacktrace, since we don't expect this operation to be used.

  TODO(b/143286622): This should not be required once captures are separated
  from regular loop vars.
  """

  def __init__(self, c_op, g):
    self._c_op = c_op
    self._graph = g
    self._outputs = None  # Initialized by _duplicate_body_captures_in_cond().
    self._id_value = g._add_op(self, self.name)
    self._is_stateful = False


def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: While Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  read_only_indices = set(range(len(op.inputs)))
  for branch_graph in branch_graphs:
    if not read_only_indices:
      break
    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
        branch_graph)
    read_only_indices = read_only_indices.intersection(branch_read_only_indices)

  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(read_only_indices))

# pylint: enable=protected-access