# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""while_v2 and gradient.

This is a version of while_loop that emits a single While op, as well as the
gradient function for While ops produced by while_loop. This will eventually
replace the current tf.while_loop implementation once it reaches feature and
performance parity.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util as util_v1
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import while_v2_indexed_slices_rewriter
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity

# pylint: disable=protected-access

# Controls parallelism in the presence of side-effecting ops like variable
# operations, print, py_function, etc. Can be set to True, False, or
# "stateless_cond" (the intended default; see the TODO below).
# Note that loops without side-effecting operations always execute with maximum
# parallelism, ignoring this setting. When False, loops with side-effecting ops
# execute sequentially, one iteration at a time.
# When True, loops with side-effecting ops may execute parts of different
# iterations in parallel; caution: if the loop condition contains
# side-effecting ops, this mode produces unspecified results.
# Setting it to "stateless_cond" automatically sets this mode to True when
# the loop condition is free of side-effecting ops.
# TODO(b/152548567): Change this to "stateless_cond".
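# A hypothetical override, e.g. from a test or an experiment (illustrative
# sketch only; this module-level flag is not a supported public API):
#
#   from tensorflow.python.ops import while_v2
#   while_v2.glob_stateful_parallelism = "stateless_cond"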
glob_stateful_parallelism = False


def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True,
               back_prop=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars,
      expand_composites=True)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants,
                               expand_composites=False)
    signature = nest.map_structure(
        control_flow_ops._shape_invariant_to_type_spec, loop_vars,
        list(shape_invariants), expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        list(shape_invariants), expand_composites=False)

  else:
    signature = nest.map_structure(
        type_spec.type_spec_from_value, loop_vars, expand_composites=False)
    shape_invariants = nest.map_structure(
        control_flow_ops._get_shape_invariant, loop_vars,
        expand_composites=False)
  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = [tensor_shape.TensorShape([])] * 2 + shape_invariants
    signature = (
        [tensor_spec.TensorSpec.from_tensor(loop_counter),
         tensor_spec.TensorSpec.from_tensor(maximum_iterations_loop_var)] +
        signature)

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      """Extra `cond` wrapper that can handle the extra counter loop_var."""
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
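      # For example, if `orig_loop_vars` is `(ta, x)` where `ta` is a
      # TensorArray, then `args` holds `(ta.flow, x)` and `_pack_sequence_as`
      # rebuilds `(ta', x)`, with `ta'` wrapping the new flow, before `cond`
      # is invoked below.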
      pred = cond(*_pack_sequence_as(orig_loop_vars, args))
      if (tensor_util.is_tf_type(pred) and
          (pred.shape.dims is None or pred.shape.dims)):
        pred = array_ops.squeeze_v2(pred)

      if maximum_iterations is None:
        return pred
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg, pred)

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    if glob_stateful_parallelism == "stateless_cond":
      stateful_parallelism = (not any(
          op._is_stateful for op in cond_graph.get_operations()))
    else:
      stateful_parallelism = glob_stateful_parallelism

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: List of args

      Returns:
        A list of tensors the same length as args.
      """
      # The function was created with a signature rather than tensors, so
      # internal placeholders were created without handle data.
      _copy_handle_data(nest.flatten(loop_vars[2:], expand_composites=True),
                        nest.flatten(args, expand_composites=True))
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs it into the structure of `orig_loop_vars`.
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence_or_composite(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars),
                                 expand_composites=True)

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=signature,
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        acd_record_initial_resource_uses=stateful_parallelism)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      _duplicate_body_captures_in_cond(
          cond_graph, body_graph.external_captures[num_cond_captures:])

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars,
                                             expand_composites=True))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars], expand_composites=True),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars], expand_composites=True))

    num_original_outputs = len(body_graph.outputs)
    if back_prop and util.output_all_intermediates():
      # Export all tensors in the loop body that may be needed for gradient
      # computation. We do this by accumulating the intermediate values in
      # TensorLists.
      intermediate_tensors = _get_intermediates(body_graph)

      for intermediate_tensor in intermediate_tensors:
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=intermediate_tensor.dtype,
            element_shape=intermediate_tensor.shape,
            max_num_elements=maximum_iterations)
        loop_vars.append(tensor_list)
        with cond_graph.as_default():
          # Add a placeholder to cond_graph's inputs corresponding to the
          # tensor_list.
          cond_graph.capture(tensor_list)
        with body_graph.as_default():
          # Push the intermediate tensor to the tensor list. This captures the
          # `tensor_list` as well.
          appended_tensor_list = list_ops.tensor_list_push_back(
              tensor_list, intermediate_tensor)
          # Add this modified tensor list to the list of outputs.
          body_graph.outputs.append(appended_tensor_list)

    flattened_loop_vars = nest.flatten(loop_vars, expand_composites=True)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))
    _check_inputs_outputs_types_match(body_graph, flattened_loop_vars)

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      output_shapes = [t.shape for t in body_graph.outputs]
      orig_loop_vars_range = slice(first_loop_var_index,
                                   first_loop_var_index + num_flattened_outputs)
      output_shapes[orig_loop_vars_range] = nest.flatten(
          shape_invariants, expand_composites=True)[orig_loop_vars_range]

      outputs = _build_while_op(
          flattened_loop_vars,
          cond_graph,
          body_graph,
          output_shapes=output_shapes,
          parallel_iterations=parallel_iterations,
          name=scope,
          num_original_outputs=num_original_outputs,
          stateful_parallelism=stateful_parallelism)
    if not ops.get_default_graph().building_function:
      # In V1 graph mode, return identities for each output of the While op,
      # rather than the output of the While op directly. This makes pruning work
      # if the output of while_loop() is fetched: the lowering pass converts the
      # While outputs into IdentityN outputs, which if fetched will cause all
      # ops in the body to be run (since it takes all exit ops as input). After
      # lowering, each output identity op will end up with only the appropriate
      # exit op as input.
      outputs = tuple(array_ops.identity(t) for t in outputs)

  output_loop_vars = outputs[first_loop_var_index:first_loop_var_index +
                             num_flattened_outputs]
  if not back_prop:
    output_loop_vars = [array_ops.stop_gradient(t) for t in output_loop_vars]
  outputs = _pack_sequence_as(orig_loop_vars, output_loop_vars)

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs, expand_composites=True)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs


@ops.RegisterGradient("StatelessWhile")
@ops.RegisterGradient("While")
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond", "_cond_graph")
  body_graph = _get_graph(while_op, "body", "_body_graph")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  try:
    num_original_outputs = while_op.get_attr("_num_original_outputs")
  except:  # pylint: disable=bare-except
    num_original_outputs = len(while_op.outputs)

  try:
    stateful_parallelism = while_op.get_attr("_stateful_parallelism")
  except:  # pylint: disable=bare-except
    stateful_parallelism = False

  num_intermediates = len(while_op.outputs) - num_original_outputs
  grads = [
      _preprocess_grad(grad, body_out, while_in, while_out)  # pylint: disable=g-complex-comprehension
      for grad, body_out, while_in, while_out in zip(
          grads[:num_original_outputs],
          body_graph.outputs[:num_original_outputs],
          while_op.inputs[:num_original_outputs],
          while_op.outputs[:num_original_outputs])
  ] + [None] * num_intermediates

  # Skip gradients with respect to the captures whenever possible.
  if "skip_input_indices" in op.__dict__ and op.skip_input_indices is not None:
    captures_start_index = (
        len(body_graph.inputs) - len(body_graph.internal_captures))
    for i in op.skip_input_indices:
      if i >= captures_start_index:
        grads[i] = None

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations,
      stateful_parallelism)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    # `body_grad_graph.extra_inputs` here is equivalent to skimming off the new
    # `body_graph.external_captures` added during `_create_grad_func`.
    new_inputs = body_grad_graph.extra_inputs
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    if len(body_graph.output_types) != len(while_op.inputs) + len(new_inputs):
      # Continuing leads to an invalid graph with disconnected inputs.
      raise AssertionError(
          "Inputs and outputs constructed for the forward op of a While "
          "gradient don't match: 'output_types' has length "
          f"{len(body_graph.output_types)}, 'inputs' has length "
          f"{len(while_op.inputs)}, and 'new_inputs' has length "
          f"{len(new_inputs)}. This doesn't make sense, please file a bug.")
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, while_op.outputs[orig_num_params:])

    # Do not ignore grads wrt extra outputs when computing higher order
    # derivatives.
    while_op._set_attr("_num_original_outputs",
                       attr_value_pb2.AttrValue(i=len(while_op.outputs)))

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = _build_while_op(
      loop_vars,
      cond_grad_graph,
      body_grad_graph,
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name,
      num_original_outputs=len(body_grad_graph.outputs),
      stateful_parallelism=stateful_parallelism)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)


def _build_while_op(loop_vars, cond_graph, body_graph, output_shapes,
                    parallel_iterations, name, num_original_outputs,
                    stateful_parallelism):
  """Builds the functional StatelessWhile/While op."""
  cond_stateful_ops = [
      op for op in cond_graph.get_operations() if op._is_stateful
  ]
  body_stateful_ops = [
      op for op in body_graph.get_operations() if op._is_stateful
  ]
  if (cond_stateful_ops or body_stateful_ops):
    op_fn = gen_functional_ops._while
  else:
    op_fn = gen_functional_ops.stateless_while

  def _make_op(inputs):
    while_op, tensors = util.get_op_and_outputs(op_fn(
        inputs,
        util.create_new_tf_function(cond_graph),
        util.create_new_tf_function(body_graph),
        output_shapes=output_shapes,
        parallel_iterations=parallel_iterations,
        name=name))
    _copy_handle_data(body_graph.outputs, tensors)
    util.maybe_set_lowering_attr(while_op)
    util.maybe_propagate_compile_time_consts_in_xla(while_op)
    _set_read_only_resource_inputs_attr(while_op, [cond_graph, body_graph])
    # This is needed so we do not compute derivative wrt these extra outputs.
    while_op._set_attr("_num_original_outputs",
                       attr_value_pb2.AttrValue(i=num_original_outputs))
    while_op._set_attr("_stateful_parallelism",
                       attr_value_pb2.AttrValue(b=stateful_parallelism))
    # The while op may be created inside a tf.function, in which case ops
    # need to capture "through" it when taking gradients; outer_graph is used
    # as a sanity check that capturing only happens from parent to child.
    cond_graph.outer_graph = ops.get_default_graph()
    body_graph.outer_graph = ops.get_default_graph()
    while_op._cond_graph = cond_graph
    while_op._body_graph = body_graph
    return tensors
  return util.run_as_function_for_tape_gradients(_make_op, loop_vars)


def _get_intermediates(func_graph):
  """Returns all tensors in `func_graph` that should be accumulated."""
  # We currently accumulate output tensors of most ops in the function and rely
  # on the pruning pass to get rid of the unused accumulators at runtime.
  # However, this can bloat the GraphDef and make debugging harder so we perform
  # some optimizations.
  #
  # Optimization we currently perform:
  # 1. We do not accumulate tensors which already have an accumulator
  #    in the loop body.
  # 2. We do not accumulate outputs of Identity nodes. When building the
  #    FuncGraph, we add an Identity node for each output (see
  #    `AutomaticControlDependencies.mark_as_return`). Accumulating outputs
  #    of all these nodes bloats the GraphDef quite a bit so we remove those.
  #    Since the gradient of an Identity node does not rely on its forward op's
  #    input this is safe to do.
  #
  # Other possible optimizations:
  # 1. Only accumulate tensors that will be required by the backward pass.
  #    This will require running the gradient pass and hence would increase the
  #    graph building time for the forward pass.
  # 2. Do not accumulate Const nodes created inside the loop body.
  # 3. Do not accumulate loop vars that are returned as-is just like captured
  #    tensors.
  intermediates = []
  reverse_captures = dict((v.ref(), k) for k, v in func_graph.captures)

  for op in func_graph.get_operations():
    if op.type == "Identity":
      continue
    # Accumulating mutexes can cause deadlock.
    if op.type == "MutexLock":
      continue
    for o in op.outputs:
      if (o is not func_graph.inputs[0] and  # Loop counter.
          o.dtype != dtypes.resource and  # Do not accumulate resource tensors.
          _get_accumulator(o) is None and  # Has existing accumulator.
          o.ref() not in reverse_captures
         ):  # Captured value, hence loop invariant.
        intermediates.append(o)
  return intermediates


def _preprocess_grad(grad, body_graph_output, while_op_input, while_op_output):
  """Returns the initial gradient to be used for a given output tensor.

  Args:
    grad: the original gradient Tensor passed to the gradient function.
    body_graph_output: the corresponding Tensor in the body graph.
    while_op_input: the corresponding Tensor input of the While op.
    while_op_output: the corresponding Tensor output of the While op.

  Returns:
    A Tensor or None.
  """
  # Set the incoming gradient of non-trainable inputs to None. It is possible
  # that we receive non-None gradients for non-trainable types in nested while
  # loops because we accumulate outputs of the inner while as variant tensors
  # which are trainable and hence receive zeros_like tensors in the gradient
  # pass. The non-trainable tensors then receive the popped zeros tensor from
  # this zeros variant. The gradient for the loop vars corresponding to these
  # tensors is None or zeros (this happens only if the loop var is accumulated
  # as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
  if not _is_trainable(body_graph_output):
    return None

  # GradientTape initializes resource and variant grads as None instead of
  # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
  # returning None.
  # TODO(b/143286622): The supports_default_grad check is needed
  # because While op emits non-differentiable resource tensors
  # as outputs. Remove this check when that is not the case.
  # Note: We use `while_op_input` instead of `while_op_output` for the call
  # to `supports_default_grad` because `while_op_output` may be missing
  # handle_data if the While is in a restored saved model.
  if (while_op_output.dtype in (dtypes.resource, dtypes.variant) and
      default_gradient.supports_default_grad(while_op_input) and grad is None):
    return _zeros_like(while_op_input, while_op_output)

  # Convert IndexedSlices to dense tensors since it is unlikely that downstream
  # gradient functions will properly handle indexed slices. This is similar to
  # what we do in tf.function gradients.
  if isinstance(grad, ops.IndexedSlices):
    return ops.convert_to_tensor(grad)

  return grad


# TODO(skyewm): make this return constants if op_output's shape is fully
# defined (this can be done by checking the "shape" attr of resource vars).
def _zeros_like(op_input, op_output):
  """Like array_ops.zeros_like() but also accepts resource var handles."""
  if op_output.dtype == dtypes.resource:
    # Note: We use `op_input` instead of `op_output` to get the zeros dtype
    # because `op_output` may be missing handle_data if the While is in a
    # restored saved model.
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(op_output),
        dtype=default_gradient.get_zeros_dtype(op_input))
  return array_ops.zeros_like(op_output)


def _is_trainable(tensor):
  """Returns whether the given tensor is trainable."""
  if not backprop_util.IsTrainable(tensor):
    return False

  # Special case: untrainable accumulator output. The gradients algorithm
  # doesn't know about tensor lists of untrainable elements. In theory the
  # tensor list gradient functions should return None as appropriate, but
  # because we can't return None from the gradient function we filter out
  # untrainable accumulator output here to avoid computing the gradient at all.
  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
    assert tensor.dtype == dtypes.variant
    element_type = tensor.op.get_attr("element_dtype")
    return backprop_util.IsTrainable(element_type)

  return True


def _get_graph(while_op, func_attr_name, attr_graph_name):
  """Returns `FuncGraph` for the given function attribute.

  Args:
    while_op: The While Operation.
    func_attr_name: string
    attr_graph_name: cached forward graph name

  Returns:
    `FuncGraph`
  """
  func_graph = getattr(while_op, attr_graph_name, None)
  if func_graph is None:
    # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
    input_shapes = [
        tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
    ]
    func_name = while_op.get_attr(func_attr_name).name
    func_graph = util.get_func_graph(while_op, input_shapes, func_name)
  func_graph._while = while_op
  return func_graph


def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations, stateful_parallelism):
  """Builds and returns the gradient FuncGraph of `func_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.
    stateful_parallelism: Bool, see tf.while_loop.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  # Build frozen sets so that we do not have linear time lookups in
  # `_is_loop_invariant`. Note: `body_graph.inputs` and `body_graph.outputs`
  # may get updated during gradient computation because we add accumulators to
  # the forward op. However, those are not loop invariants so wouldn't affect
  # the output of `_is_loop_invariant`. Also we would never attempt to capture
  # those accumulators so `_is_loop_invariant` should never receive those new
  # tensors as args.
  body_graph_inputs = object_identity.ObjectIdentitySet(body_graph.inputs)
  body_graph_outputs = object_identity.ObjectIdentitySet(body_graph.outputs)

  args = [counter, maximum_iterations, total_iters] + list(grads)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(ys, xs, args, body_graph),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                         maximum_iterations, while_op,
                                         body_graph_inputs, body_graph_outputs),
      acd_record_initial_resource_uses=stateful_parallelism)

  # Update the list of outputs with tensors corresponding to the captured
  # tensors. We capture 3 types of tensors when building the grad fn:
  # 1. Accumulators for forward graph intermediates which are not loop
  #    invariants. The outputs corresponding to these are populated in
  #    `internal_capture_to_output` by `_WhileBodyGradFuncGraph`.
  # 2. Resources, which are output as is.
  # 3. Forward graph loop invariants, which are output as is.
  for external_capture, internal_capture in grad_func_graph.captures:
    if (ops.tensor_id(internal_capture)
        in grad_func_graph.internal_capture_to_output):
      new_output = grad_func_graph.internal_capture_to_output[ops.tensor_id(
          internal_capture)]
    else:
      raise ValueError(
          f"Tensor {str(internal_capture)} which captures "
          f"{str(external_capture)} is in list of "
          f"internal_captures but not in internal_capture_to_output.")
    grad_func_graph.outputs.append(new_output)
    grad_func_graph.structured_outputs.append(new_output)

  return grad_func_graph, args


def _grad_fn(ys, xs, args, func_graph):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    args: The input arguments.
      args[0] - Loop counter
      args[1] - maximum_iterations.
      args[2] - Total number of iterations.
      args[3:] - Incoming gradients for `ys`.
    func_graph: function.FuncGraph. The corresponding forward-pass function.

  Returns:
    The output gradient Tensors.
  """
  grad_ys = args[3:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # after the forward While op has been rewritten in _resolve_grad_captures.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_util._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph,
      unconnected_gradients="zero")

  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
  # is a tf.StopGradient in the loop body.
  assert all(g is not None for g in grad_outs)
  counter = args[0]
  maximum_iterations = args[1]
  total_iters = args[2]
  return [counter + 1, maximum_iterations, total_iters] + grad_outs


def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
  """Returns the tensors to pass as captured inputs to `body_grad_graph`.

  `body_grad_graph` may have external references to:
  1. Its outer graph containing the input gradients. These are left as-is.
  2. Accumulators captured from the forward-pass graph. These should have been
     added as `while_op` outputs after the gradient graph was built. We replace
     these with the corresponding output of `while_op`, i.e. a tensor in
     `body_graph.outer_graph`. In the case of nested control flow or functions,
     the gradient logic handling `body_grad_graph.outer_graph` will make sure
     the tensor from `body_graph.outer_graph` is also correctly captured.

  Args:
    body_graph: FuncGraph. The forward-pass body function.
    body_grad_graph: FuncGraph. The body gradients function.
    while_op: The forward-pass While Operation calling `body_graph`.

  Returns:
    A list of input tensors to be passed as the captured inputs to
    `body_grad_graph`.
  """
  new_capture_inputs = []
  for t in body_grad_graph.external_captures:
    # Resolve tensors captured from the forward graph to the outputs of the
    # forward while_op.
    if t.graph == body_graph:
      # Captured accumulator or loop invariant.
      for i, output in enumerate(t.graph.outputs):
        if output is t:
          t = while_op.outputs[i]
          break

      # Note: We rely on the capturing logic of the gradient While op graph to
      # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
      # and while_v2 handle this while building their gradient functions.
      assert t.graph == body_graph.outer_graph

    new_capture_inputs.append(t)
  return new_capture_inputs


def _get_structured_grad_output(outputs, grads, body_grad_graph):
  """Returns the values that should be returned from the while grad function.

  Args:
    outputs: the raw Tensor outputs of the grad While op.
    grads: the input gradients to the gradient function.
    body_grad_graph: _WhileBodyGradFuncGraph.

  Returns:
    A list of gradient values. May include Nones.
  """
  result = []
  # outputs[0] is the loop counter.
  # outputs[1] is maximum_iterations.
  # outputs[2] is the total number of loop iterations.
  outputs_idx = 3
  structured_outputs_idx = 3
  for g in grads:
    # Set None as the output gradient for tensors with None input gradient.
    if g is None:
      result.append(None)
      continue
    output = body_grad_graph.structured_outputs[structured_outputs_idx]
    structured_outputs_idx += 1
    if isinstance(output, ops.IndexedSlices):
      # TODO(skyewm): is there a more robust way to determine the order of
      # flattened IndexedSlices components?
      result.append(ops.IndexedSlices(
          values=outputs[outputs_idx],
          indices=outputs[outputs_idx + 1],
          dense_shape=outputs[outputs_idx + 2]))
      outputs_idx += 3
    else:
      assert isinstance(output, ops.Tensor)
      result.append(outputs[outputs_idx])
      outputs_idx += 1

  return result


def _get_accumulator(tensor):
  r"""Returns the TensorList, if any, that accumulates values of `tensor`.

  We try to find a pattern of the form:

     input_tl   tensor
        \        /
    (TensorListPushBack)
            |
        output_tl

  which satisfies the following conditions:

  1. input_tl must be in tensor.graph.inputs.
  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).

  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
  returned if such a pattern is found else None is returned.

  Args:
    tensor: The Tensor to be accumulated.

  Returns:
    A variant tensor in the same graph as `tensor` or None if no accumulator is
    found.
  """
  assert isinstance(tensor.graph, func_graph_module.FuncGraph)

  def get_func_graph_output(t):
    """Returns t or Identity(t) whichever exists in graph outputs else None."""
    for output in tensor.graph.outputs:
      if output is t:
        return t
    # tf.defun adds an Identity for each output, check whether that is the case.
    identity_op = t.consumers()[0]
    if (identity_op.type == "Identity" and
        any(identity_op.outputs[0] is t for t in tensor.graph.outputs)):
      return identity_op.outputs[0]
    return None

  for consumer in tensor.consumers():
    # Find the consumer that is a TensorListPushBack node whose TensorList
    # input is in the list of function inputs.
    if consumer.type != "TensorListPushBack":
      continue

    accum_input_idx = -1
    for accum_input_idx, inp in enumerate(tensor.graph.inputs):
      if inp is consumer.inputs[0]:
        break
    else:
      continue

    output = get_func_graph_output(consumer.outputs[0])
    if output is None:
      # The TensorList output of `consumer` is not in the list of function
      # outputs.
      continue

    for accum_output_idx, out in enumerate(tensor.graph.outputs):
      if out is output:
        if accum_input_idx == accum_output_idx:
          return output
        break

  return None


OptimizedReductionOpsCacheKey = collections.namedtuple(
    "OptimizedReductionOpsCacheKey", [
        "op_type",
        "inputs",
        "dtypes",
        "input_types",
        "name",
        "attrs",
        "op_def",
        "compute_device",
    ])


class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
  """FuncGraph for the gradient function of the body of a While op.

  Contains the logic for capturing the tensors from the body of the forward
  While op which is as follows:
  1. If the tensor is of resource type (these are not accumulated):
     a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
        inputs and outputs at the same index.
     b. Look up the corresponding resource tensor in the forward outer graph and
        try to capture that.
  2. If the tensor is not of resource type:
     a. Create an accumulator for that tensor and output it from the forward
        pass. Note this also requires adding it as an input to the forward pass.
     b. Capture the accumulator from the forward pass in this FuncGraph. This
        will later be resolved to the correct output of the forward While op.
     c. Pop a value from the captured placeholder and use it as the captured
        value for the forward pass tensor.

  This only allows capturing tensors in the forward graph. A ValueError is
  raised if an attempt is made to capture a tensor not in the forward graph.
  To manually capture a tensor that is not in the forward graph, call `capture`
  with `allowlisted=True`.

  Note: The `captures` dict does not contain the forward tensor since it is not
  directly captured. It contains the accumulator corresponding to this forward
  tensor.

  Attributes:
    while_op_needs_rewrite: True if any non-resource intermediates were
      captured, meaning the forward While op needs to be rewritten to output the
      corresponding accumulators.
    extra_inputs: list of EmptyTensorList tensors to be used as initial input to
      the new accumulators in the forward graph. It may also contain external
      captures of the custom gradient function.
    internal_capture_to_output: dict from a tensor_id(captured placeholder) to
      the corresponding tensor that needs to be added to the list of outputs.
      For instance, when capturing an accumulator TensorList this contains the
      TensorList obtained after popping a tensor from the list. Other entries
      in this dict are expected, though not enforced, to be identities.
      This dict is needed because these output tensors need to be added to
      FuncGraph.outputs "after" the tensors returned from the gradient function.
958 """ 959 960 def __init__(self, name, forward_cond_graph, forward_body_graph, 961 maximum_iterations, forward_while_op, body_graph_inputs, 962 body_graph_outputs): 963 super(_WhileBodyGradFuncGraph, self).__init__(name) 964 self.extra_inputs = [] 965 self.internal_capture_to_output = {} 966 # FuncGraph for the body of the forward While op. 967 self._forward_graph = forward_body_graph 968 # FuncGraph for the cond of the forward While op. 969 self._forward_cond_graph = forward_cond_graph 970 self._maximum_iterations = maximum_iterations 971 self._forward_while_op = forward_while_op 972 # Dict from forward intermediate tensor to its indirectly captured tensor 973 # in this graph. Indirect capturing happens in two ways: 974 # 1. For non-resource tensors we capture their accumulators from the forward 975 # outer graph and pop values from that accumulator inside this graph 976 # using TensorListPopBack. 977 # 2. For resource tensors we directly capture their corresponding tensor 978 # in the forward outer graph. 979 self._indirect_captures = {} 980 981 @property 982 def while_op_needs_rewrite(self): 983 return self.extra_inputs 984 985 def _create_op_internal( 986 self, 987 op_type, 988 inputs, 989 dtypes=None, # pylint: disable=redefined-outer-name 990 input_types=None, 991 name=None, 992 attrs=None, 993 op_def=None, 994 compute_device=True): 995 # For a reduction op, if op is in the gradient body graph and its input is 996 # from the forward graph, moving op to the forward graph means we would 997 # store the tensor after the reduction as opposed to the tensor before 998 # reduction, and therefore could significantly reduce memory consumption. 999 # For now, we do this only for a few ops. 1000 # 1001 # We don't do this if any input tensor has already been accumulated. This 1002 # can happen if we output all intermediates in the forward pass. 1003 # 1004 # If in XLA context, do not move constant ops to forward pass as pushing to 1005 # and popping from a TensorList removes the constant property of an op and 1006 # breaks XLA compilation, which requires certain inputs to be compile-time 1007 # constant for certain ops. 1008 # 1009 # This optimization is currently also disabled when under a persistent tape, 1010 # since it leads to an unbounded number of side outputs. With caching it may 1011 # be possible to re-enable it. 
    optimized_reduction_ops = {
        "Shape", "Size", "Rank", "TensorListElementShape", "TensorListLength"
    }
    if (op_type in optimized_reduction_ops and
        not util.output_all_intermediates() and
        all(input.graph is self._forward_graph for input in inputs) and
        all(_get_accumulator(input) is None for input in inputs) and
        not util_v1.GraphOrParentsInXlaContext(self._forward_graph) and
        not util.graph_wrapped_for_higher_order_tape_gradients(
            self._forward_graph)):
      return self._move_op_to_forward_graph(
          op_type,
          inputs,
          dtypes=dtypes,
          input_types=input_types,
          name=name,
          attrs=attrs,
          op_def=op_def,
          compute_device=compute_device)

    return super(_WhileBodyGradFuncGraph, self)._create_op_internal(
        op_type,
        inputs,
        dtypes=dtypes,
        input_types=input_types,
        name=name,
        attrs=attrs,
        op_def=op_def,
        compute_device=compute_device)

  def _move_op_to_forward_graph(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # We have a cache of reduction ops that have already been moved to the
    # forward graph, and we will check it first to avoid moving an op twice.
    if not hasattr(self._forward_graph, "_optimized_reduction_ops_cache"):
      self._forward_graph._optimized_reduction_ops_cache = {}
    cache_key = self._get_optimized_reduction_ops_cache_key(
        op_type, inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)
    cached_op = self._forward_graph._optimized_reduction_ops_cache.get(
        cache_key)
    if cached_op is not None:
      # This op has already been moved to the forward graph and we have it in
      # the cache.
      return cached_op

    with self._forward_graph.as_default():
      # `name` was built using name_scope stack of gradient graph and may not
      # be unique in the forward graph. `Graph.create_op` does not uniquify
      # names which are name scopes i.e. end in `/`. To ensure that the op
      # created gets a unique name in the forward graph we get rid of the
      # trailing slash.
      name = ops.name_from_scope_name(name)
      result = self._forward_graph._create_op_internal(
          op_type,
          inputs,
          dtypes=dtypes,
          input_types=input_types,
          name=name,
          attrs=attrs,
          op_def=op_def,
          compute_device=compute_device)

      # Store the op we just moved to the forward graph so that it does
      # not need to be added there again.
      self._forward_graph._optimized_reduction_ops_cache[cache_key] = result
      return result

  def _get_optimized_reduction_ops_cache_key(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # We need all elements of CacheKey to be hashable.
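    # For example, Tensor objects are not hashable, so `inputs` is keyed by
    # their `.ref()` wrappers below; proto-valued `attrs` and `op_def` are
    # serialized to bytes for the same reason.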
    inputs = tuple(map(lambda t: t.ref(), inputs))

    if dtypes is not None:
      dtypes = tuple(dtypes)

    if input_types is not None:
      input_types = tuple(input_types)

    if attrs is not None:
      hashable_attrs = []
      for attr_name, attr_value in sorted(attrs.items()):
        hashable_attrs.append((attr_name, attr_value.SerializeToString()))
      attrs = tuple(hashable_attrs)

    if op_def is not None:
      op_def = op_def.SerializeToString()

    return OptimizedReductionOpsCacheKey(op_type, inputs, dtypes, input_types,
                                         name, attrs, op_def, compute_device)

  def _capture_helper(self, tensor, name):
    """Implements the capturing described in the class docstring."""
    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    if tensor.graph is not self._forward_graph:
      already_captured = self.captured(tensor)
      captured_tensor = super(_WhileBodyGradFuncGraph, self)._capture_helper(
          tensor, name)
      if not already_captured:
        # Adds the captured tensor to the list of outputs so that the input
        # and output signatures match.
        self.internal_capture_to_output[ops.tensor_id(
            captured_tensor)] = captured_tensor
        self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(ops.tensor_id(tensor))
    if captured_tensor is not None:
      return captured_tensor

    # No need to accumulate loop invariants. Capture them directly.
    # The captured tensor gets resolved to the corresponding while output in
    # `_resolve_grad_captures`.
    if _is_loop_invariant(tensor, self._forward_graph.inputs,
                          self._forward_graph.outputs):
      captured_tensor = super(_WhileBodyGradFuncGraph,
                              self)._capture_helper(tensor, name)
      # Add to `internal_capture_to_output` so that this gets added to the list
      # of outputs.
      self.internal_capture_to_output[ops.tensor_id(
          captured_tensor)] = captured_tensor
      self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor
      return captured_tensor

    # Do not accumulate Const nodes. Instead copy them directly in the backward
    # graph.
    # TODO(srbs): This just checks for `Const` nodes. Consider checking for
    # graph compile time consts in general.
    # TODO(srbs): Consider making this a loop input.
    if constant_op.is_constant(tensor):
      real_value = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      self._indirect_captures[ops.tensor_id(tensor)] = real_value
      return real_value

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      #
      # Note: We clear the control dependencies to avoid a cycle in case a
      # control tensor has an input path to an output of the forward While.
1184 # 1185 # E.g.: 1186 # x = tf.while_loop(...) 1187 # y = f(x) 1188 # with tf.control_dependencies([y]): 1189 # tf.gradients(y, x) 1190 # 1191 # Since the EmptyTensorList is fed back into the forward While, not 1192 # removing the control edge would cause a cycle. 1193 with self._forward_graph.outer_graph.as_default(): 1194 with util.clear_control_inputs(): 1195 tensor_list = list_ops.empty_tensor_list( 1196 element_dtype=tensor.dtype, 1197 element_shape=tensor.shape, 1198 max_num_elements=self._maximum_iterations, 1199 name=_build_accumulator_name(tensor)) 1200 self.extra_inputs.append(tensor_list) 1201 1202 # Push the intermediate tensor to the tensor list. This captures 1203 # `tensor_list`. 1204 with self._forward_graph.as_default(): 1205 accumulator = list_ops.tensor_list_push_back(tensor_list, tensor) 1206 # Add the modified tensor list to the list of outputs. This output will be 1207 # all the accumulated values. 1208 self._forward_graph.outputs.append(accumulator) 1209 1210 # Capture in the cond graph as well so the forward cond and body inputs 1211 # match. 1212 with self._forward_cond_graph.as_default(): 1213 self._forward_cond_graph.capture(tensor_list) 1214 1215 # Capture the accumulator tensor list in the gradient graph directly from 1216 # the forward graph -- we'll later modify this to capture the final list 1217 # output by the forward While op instead. 1218 captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper( 1219 accumulator, name) 1220 1221 # Pop the intermediate value from the tensor list in the gradient graph. 1222 new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back( 1223 captured_accumulator, element_dtype=tensor.dtype) 1224 1225 self._indirect_captures[ops.tensor_id(tensor)] = captured_tensor 1226 self.internal_capture_to_output[ops.tensor_id( 1227 captured_accumulator)] = new_tensor_list 1228 return captured_tensor 1229 1230 def _resource_capture_helper(self, tensor): 1231 """Returns the captured resource tensor. 1232 1233 Resource-type tensors are not accumulated. If a resource tensor exists in 1234 the loop body it must either be a loop input or an output of a nested While 1235 op inside the loop body which had captured the external resource. 1236 1237 Args: 1238 tensor: the external resource Tensor to be captured. 1239 1240 Returns: 1241 Tensor in this graph. 1242 """ 1243 assert tensor.dtype == dtypes.resource 1244 1245 forward_graph_input_names = [t.name for t in self._forward_graph.inputs] 1246 forward_graph_name_to_opdef = { 1247 op.name: op.node_def for op in self._forward_graph.get_operations()} 1248 index = util.resource_input_index( 1249 tensor.name, forward_graph_input_names, 1250 forward_graph_name_to_opdef, 1251 self._forward_graph._functions) 1252 1253 input_placeholder = self._forward_graph.inputs[index] 1254 tensor_in_outer_graph = self._forward_graph._while.inputs[index] 1255 1256 assert input_placeholder.dtype == dtypes.resource 1257 assert tensor_in_outer_graph.dtype == dtypes.resource 1258 # This must be a loop invariant. However, infrastructure 1259 # (e.g. tf.vectorized_map) may insert identity nodes, function calls, conds, 1260 # etc. which take and return the resource tensor unmodified; this means that 1261 # the Python objects may differ. 
    if index != util.resource_input_index(
        self._forward_graph.outputs[index].name, forward_graph_input_names,
        forward_graph_name_to_opdef,
        self._forward_graph._functions):
      raise AssertionError(
          f"Resource tensors must be loop invariants: {tensor_in_outer_graph}")

    self._indirect_captures[ops.tensor_id(tensor)] = self.capture(
        tensor_in_outer_graph)
    return self._indirect_captures[ops.tensor_id(tensor)]


def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
  for (t, shape, input_t) in zip(output_tensors, shape_invariants,
                                 input_tensors):
    if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
      raise ValueError(
          f"Input tensor `{input_t.name}` enters the loop with shape {shape}, "
          f"but has shape {t.shape} after one iteration. To allow the shape to "
          "vary across iterations, use the `shape_invariants` argument of "
          "tf.while_loop to specify a less-specific shape.")


def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
  """Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
  assert len(cond_graph.inputs) == num_flattened_loop_vars, (
      "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(cond_graph.outputs) == 1, (
      "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
  assert len(body_graph.inputs) == num_flattened_loop_vars, (
      "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(body_graph.outputs) == num_flattened_loop_vars, (
      "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
                                                   num_flattened_loop_vars))


def _check_inputs_outputs_types_match(body_graph, flattened_loop_vars):
  for inp, out, loop_var in zip(body_graph.inputs, body_graph.outputs,
                                flattened_loop_vars):
    if inp.dtype != out.dtype:
      raise TypeError(
          f"Loop var {loop_var.name} enters the loop with type {inp.dtype} "
          f"but has type {out.dtype} after 1 iteration. {loop_var.name} type "
          "should remain constant.")


def _build_cond_placeholders_name_prefix(cond_graph):
  return cond_graph.unique_name(cond_graph.name + "___redundant_placeholder")


def _duplicate_body_captures_in_cond(cond_graph, body_graph_captures):
  """Creates placeholders for body captures in cond_graph.

  This is needed to match signatures of cond and body graphs.

  Args:
    cond_graph: cond branch graph
    body_graph_captures: Tensors which were captured when building the
      `body_graph`.
  """
  types = [t.dtype.as_datatype_enum for t in body_graph_captures]
  # TODO(srbs): Providing a unique prefix does not ensure that there is no
  # conflict between the placeholder names and existing nodes in the graph.
  # However passing a list of strings may not be performant.
  # Ideally we should move `Graph.unique_name` to C++ or make
  # `Graph._names_in_use` a trie so that we can find a unique prefix.
  # TODO(b/143286622): This should not be required once captures are separated
  # from regular loop vars.
  placeholders = c_api.TF_CreatePlaceholders(
      cond_graph._c_graph, types,
      compat.as_str(_build_cond_placeholders_name_prefix(cond_graph)))
  placeholder_ops = [
      _OperationWithOutputs(ph.oper, cond_graph)
      for ph in placeholders
  ]

  tensors = []
  for op, ph, dtype in zip(placeholder_ops, placeholders, types):
    tensor = ops.Tensor._create_with_tf_output(op, 0, dtype, ph)
    op._outputs = [tensor]
    tensors.append(tensor)

  # Update `cond_graph._captures` and `cond_graph.inputs` to contain the
  # newly created placeholders.
  tuples = zip(body_graph_captures, tensors)
  keys = [id(t) for t in body_graph_captures]
  cond_graph._captures.update(zip(keys, tuples))
  cond_graph.inputs.extend(tensors)


def _copy_handle_data(src_tensors, tgt_tensors):
  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
    handle_data_util.copy_handle_data(src_t, tgt_t)


def _graph_name(graph):
  if isinstance(graph, func_graph_module.FuncGraph):
    return graph.name
  return "Base"


def _pack_sequence_as(structure_with_tas, loop_vars):
  """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""

  def flow_to_tensor_array(flow, ta):  # pylint: disable=missing-docstring
    return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance(  # pylint: disable=g-long-ternary
        ta, tensor_array_ops.TensorArray) else flow)

  flattened_loop_vars = [
      flow_to_tensor_array(*z)
      for z in zip(nest.flatten(loop_vars, expand_composites=True),
                   nest.flatten(structure_with_tas, expand_composites=True))
  ]
  return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars,
                               expand_composites=True)


def _tensor_array_to_flow(loop_vars):

  def f(maybe_ta):
    if isinstance(maybe_ta, tensor_array_ops.TensorArray):
      return maybe_ta.flow
    return maybe_ta

  return nest.map_structure(f, loop_vars, expand_composites=True)


def _build_maximum_iterations_loop_var(maximum_iterations):
  if maximum_iterations is None:
    # Default value for max_num_elements to EmptyTensorList meaning that the
    # list size is unbounded.
    maximum_iterations = -1
  # EmptyTensorList expects `max_num_elements` to be of type int32.
  return ops.convert_to_tensor(
      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")


def _build_accumulator_name(tensor):
  # Tensor name may be of the form "pow/y:0". Name scope does not allow ":".
  return "{}/accumulator".format(tensor.name).replace(":", "_")


def _is_loop_invariant(tensor, inputs, outputs):
  return (any(tensor is t for t in inputs) and
          any(tensor is t for t in outputs))


class _OperationWithOutputs(ops.Operation):
  """Operation with pre-built `TF_Output`s.

  The C API for creating the extra placeholders for the cond graph returns
  SWIG wrapped TF_Output* pointers which we can use directly for
  `Operation.outputs`. The default constructor for `Operation` does not provide
  a way of specifying pre-built output tensors and always creates them. This is
  a performance overhead. It is not clear if adding that feature to the
  `Operation` API would be generally useful so for now we just have our own
  lightweight `Operation` implementation. Note that this also does not extract
  a stacktrace, since we don't expect this operation to be used.

  TODO(b/143286622): This should not be required once captures are separated
  from regular loop vars.
  """

  def __init__(self, c_op, g):
    self._c_op = c_op
    self._graph = g
    self._outputs = None  # Initialized by _duplicate_body_captures_in_cond().
    self._id_value = g._add_op(self, self.name)
    self._is_stateful = False


def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: While Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  read_only_indices = set(range(len(op.inputs)))
  for branch_graph in branch_graphs:
    if not read_only_indices:
      break
    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
        branch_graph)
    read_only_indices = read_only_indices.intersection(branch_read_only_indices)

  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(read_only_indices))

# pylint: enable=protected-access
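
# Illustrative usage (sketch only; user code should go through tf.while_loop,
# which dispatches to this module when v2 control flow is enabled):
#
#   i0 = constant_op.constant(0)
#   x0 = constant_op.constant(1.0)
#   _, result = while_loop(
#       cond=lambda i, x: i < 10,
#       body=lambda i, x: (i + 1, x * 2.0),
#       loop_vars=(i0, x0))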