# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""while_v2 and gradient.

This is a version of while_loop that emits a single While op, as well as the
gradient function for While ops produced by while_loop. This will eventually
replace the current tf.while_loop implementation once it reaches feature and
performance parity.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import while_v2_indexed_slices_rewriter
from tensorflow.python.util import nest

# pylint: disable=protected-access

# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
# control dependencies on external nodes with at least 1 output.
# Another idea is to create const nodes outside the loop and add control edges
# to them and then pass those in as data inputs. This should probably be
# handled in the CapturingGraph itself.


def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               maximum_iterations=None,
               name=None,
               return_same_structure=True):
  """Like tf.while_loop, except emits a single While op."""
  # Keep the original loop_vars around to know which args were TensorArrays.
  orig_loop_vars = loop_vars
  # Cache its length since we use it at multiple places below.
  len_orig_loop_vars = len(orig_loop_vars)

  # Convert TensorArrays to their flow variables. These get converted back to
  # TensorArrays before calling `cond` and `body`. See `wrapped_cond` and
  # `wrapped_body` below.
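  # (A TensorArray's `flow` is a scalar float tensor that chains the array's
  # reads/writes, so only these flow tensors are threaded through the While op;
  # the TensorArray objects themselves never cross the op boundary.)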
  loop_vars = list(_tensor_array_to_flow(orig_loop_vars))
  loop_vars = nest.map_structure(
      ops.internal_convert_to_tensor_or_indexed_slices, loop_vars)
  if shape_invariants is not None:
    nest.assert_same_structure(orig_loop_vars, shape_invariants)
  else:
    shape_invariants = nest.map_structure(lambda t: t.shape, loop_vars)

  if not name:
    name = "while"

  with ops.name_scope(name) as scope:
    with ops.name_scope(None):
      cond_name = util.unique_fn_name(scope, "cond")
      body_name = util.unique_fn_name(scope, "body")
    maximum_iterations_loop_var = _build_maximum_iterations_loop_var(
        maximum_iterations)
    loop_counter = constant_op.constant(
        0,
        dtype=maximum_iterations_loop_var.dtype
        if maximum_iterations is not None else None,
        name="loop_counter")
    # Add loop counter needed for computing gradients.
    loop_vars = [loop_counter, maximum_iterations_loop_var] + loop_vars

    shape_invariants = type(shape_invariants)(
        [tensor_shape.scalar(), tensor_shape.scalar()]) + shape_invariants

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies

    # Build a `cond` wrapper that can handle the extra counter loop_var.
    def wrapped_cond(loop_counter, maximum_iterations_arg, *args):
      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs them into the structure of `orig_loop_vars`.
      if maximum_iterations is None:
        return cond(*_pack_sequence_as(orig_loop_vars, args))
      else:
        return math_ops.logical_and(
            loop_counter < maximum_iterations_arg,
            cond(*_pack_sequence_as(orig_loop_vars, args)))

    # NOTE(skyewm): we set collections to the outer graph's collections for
    # compatibility with TPUEstimator.
    cond_graph = func_graph_module.func_graph_from_py_func(
        cond_name,
        wrapped_cond,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileCondFuncGraph(
            cond_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)

    def wrapped_body(loop_counter, maximum_iterations_arg, *args):
      """Loop body augmented with counter update.

      Args:
        loop_counter: Loop counter which needs to be incremented in the body.
        maximum_iterations_arg: Maximum iterations of the loop.
        *args: The remaining (flattened) loop variables.

      Returns:
        A list of tensors the same length as args.
      """
      # Capture the tensors already captured in cond_graph so that they appear
      # in the same order in body_graph.external_captures.
      for t in cond_graph.external_captures:
        ops.get_default_graph().capture(t)

      # Convert the flow variables in `args` to TensorArrays. `args` should
      # already have the same structure as `orig_loop_vars` but currently there
      # is no nest.zip so we call `_pack_sequence_as` which flattens both
      # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
      # and packs them into the structure of `orig_loop_vars`.
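      # For example (illustrative), if `orig_loop_vars` were [x, ta] with `ta`
      # a TensorArray, then `args` here is (x_placeholder, ta_flow_placeholder)
      # and the call below receives [x_placeholder, TensorArray(ta_flow)].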
      outputs = body(*_pack_sequence_as(orig_loop_vars, args))
      if not nest.is_sequence(outputs):
        outputs = [outputs]
      # Compare the structure of input and output of body converting the
      # top-level tuples to list to be compatible with legacy while_loop.
      nest.assert_same_structure(list(outputs), list(orig_loop_vars))

      outputs = _tensor_array_to_flow(outputs)

      # TODO(srbs): Update lowering code to create _Enter nodes with
      # is_constant=True for inputs that are directly passed to outputs.
      return [loop_counter + 1, maximum_iterations_arg] + list(outputs)

    body_graph = func_graph_module.func_graph_from_py_func(
        body_name,
        wrapped_body,
        [],  # We provide signature instead of args.
        {},
        signature=_build_signature(loop_vars, shape_invariants),
        func_graph=util.WhileBodyFuncGraph(
            body_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies)
    # Add external captures of body to the list of loop vars.
    # Note that external tensors will be treated as loop invariants, i.e.,
    # the value of that tensor in each iteration is the same as it was at the
    # beginning of the loop execution.
    loop_vars = loop_vars + body_graph.external_captures
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    body_graph.outputs.extend(body_graph.internal_captures)

    # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
    # that it expects to receive those as arguments.
    with cond_graph.as_default():
      num_cond_captures = len(cond_graph.external_captures)
      assert (cond_graph.external_captures ==
              body_graph.external_captures[:num_cond_captures])
      for body_capture in body_graph.external_captures[num_cond_captures:]:
        assert body_capture not in cond_graph.captures
        cond_graph.capture(body_capture)

    # Make sure that the shapes of the loop outputs are compatible with the
    # shape invariants, or the shapes of the loop vars if the invariants are not
    # specified.
    num_flattened_outputs = len(nest.flatten(orig_loop_vars))
    # First var is loop counter and second var is maximum_iterations.
    first_loop_var_index = 2
    _check_shapes_compat(
        body_graph.outputs[first_loop_var_index:first_loop_var_index +
                           num_flattened_outputs],
        nest.flatten(
            shape_invariants[first_loop_var_index:first_loop_var_index +
                             len_orig_loop_vars]),
        nest.flatten(loop_vars[first_loop_var_index:first_loop_var_index +
                               len_orig_loop_vars]))
    flattened_loop_vars = nest.flatten(loop_vars)
    _check_num_inputs_outputs(cond_graph, body_graph,
                              len(flattened_loop_vars))

    with ops.control_dependencies(
        list(cond_graph.control_captures) + list(body_graph.control_captures)):
      outputs = gen_functional_ops._while(
          flattened_loop_vars,
          util.create_new_tf_function(cond_graph),
          util.create_new_tf_function(body_graph),
          output_shapes=[t.shape for t in body_graph.outputs],
          parallel_iterations=parallel_iterations,
          name=scope)

    _copy_handle_data(body_graph.outputs, outputs)
    util.maybe_set_lowering_attr(outputs[0].op)
    util.maybe_propagate_compile_time_consts_in_xla(outputs[0].op)

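    # At this point `outputs` mirrors `flattened_loop_vars`:
    # [loop_counter, maximum_iterations, <flattened orig_loop_vars>,
    #  <external captures of the body>].
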
    # Return identities for each output of the While op, rather than the output
    # of the While op directly. This makes pruning work if the output of
    # while_loop() is fetched: the lowering pass converts the While outputs into
    # IdentityN outputs, which if fetched will cause all ops in the body to be
    # run (since it takes all exit ops as input). After lowering, each output
    # identity op will end up with only the appropriate exit op as input.
    outputs = tuple(array_ops.identity(t) for t in outputs)

    outputs = _pack_sequence_as(
        orig_loop_vars, outputs[first_loop_var_index:first_loop_var_index +
                                num_flattened_outputs])

  if return_same_structure:
    return outputs

  flattened_outputs = nest.flatten(outputs)
  if len(flattened_outputs) == 1:
    return flattened_outputs[0]
  else:
    return outputs


@ops.RegisterGradient("While")
def _WhileGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a While op produced by while_loop."""
  # Note that op is not always the same as while_op because the gradient tape,
  # for eager mode compatibility, forgets information about the proper op. Since
  # the loop cannot run in eager mode, however, we can safely introspect into
  # the graph here.
  while_op = op.outputs[0].op
  cond_graph = _get_graph(while_op, "cond")
  body_graph = _get_graph(while_op, "body")
  orig_num_params = len(body_graph.outputs)

  maximum_iterations = op.inputs[1]
  parallel_iterations = op.get_attr("parallel_iterations")

  grads = [_preprocess_grad(grad, body_out, while_out)
           for grad, body_out, while_out
           in zip(grads, body_graph.outputs, while_op.outputs)]

  # We compute the gradient for the sub-graph between trainable ys and xs
  # with non-None incoming gradients. We later pad the None's to the list of
  # outputs.
  ys, xs, non_none_grads = zip(*[(y, x, grad) for (y, x, grad) in zip(
      body_graph.outputs, body_graph.inputs, grads) if grad is not None])

  body_grad_graph, args = _create_grad_func(
      ys, xs, non_none_grads, cond_graph, body_graph,
      util.unique_grad_fn_name(body_graph.name), op, maximum_iterations)

  if body_grad_graph.while_op_needs_rewrite:
    # Modify 'op' to output the intermediate accumulators needed by the grad
    # function.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    cond_graph.name += "_rewritten"
    body_graph.name += "_rewritten"

    new_inputs = body_grad_graph.empty_tensor_lists
    new_outputs = body_graph.outputs[orig_num_params:]

    while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
    while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
    while_op._set_type_list_attr("T", body_graph.output_types)
    while_op._set_shape_list_attr("output_shapes", body_graph.output_shapes)
    while_op._add_while_inputs(new_inputs)
    while_op._add_outputs([t.dtype for t in new_outputs],
                          [t.shape for t in new_outputs])
    _copy_handle_data(new_outputs, op.outputs[orig_num_params:])

  captured_inputs = _resolve_grad_captures(body_graph, body_grad_graph,
                                           while_op)
  loop_vars = args + captured_inputs

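  # The gradient loop vars are laid out as
  # [grad_counter, maximum_iterations, total_forward_iterations,
  #  <incoming grads>, <resolved captures>] (see _create_grad_func).
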
  # This modifies body_grad_graph.
  loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
      grads, body_grad_graph, loop_vars, while_op.inputs)

  def grad_cond(counter, unused_maximum_iterations_arg, forward_loop_iters,
                *unused_args):
    return counter < forward_loop_iters

  grad_cond_name = util.unique_grad_fn_name(op.get_attr("cond").name)
  cond_grad_graph = func_graph_module.func_graph_from_py_func(
      grad_cond_name, grad_cond, loop_vars, {},
      func_graph=util.WhileCondFuncGraph(grad_cond_name))

  _check_num_inputs_outputs(cond_grad_graph, body_grad_graph, len(loop_vars))

  outputs = gen_functional_ops._while(
      loop_vars,
      util.create_new_tf_function(cond_grad_graph),
      util.create_new_tf_function(body_grad_graph),
      output_shapes=[t.shape for t in body_grad_graph.outputs],
      parallel_iterations=parallel_iterations,
      name="%s_grad" % while_op.name)
  grad_op = outputs[0].op

  _copy_handle_data(body_grad_graph.outputs, outputs)
  util.maybe_set_lowering_attr(grad_op)
  util.maybe_propagate_compile_time_consts_in_xla(grad_op)

  # See comment in while_loop.
  outputs = [array_ops.identity(t) for t in outputs]
  return _get_structured_grad_output(outputs, grads, body_grad_graph)


def _preprocess_grad(grad, body_graph_output, while_op_output):
  """Returns the initial gradient to be used for a given output tensor.

  Args:
    grad: the original gradient Tensor passed to the gradient function.
    body_graph_output: the corresponding Tensor in the body graph.
    while_op_output: the corresponding Tensor output of the While op.

  Returns:
    A Tensor or None.
  """
  # Set the incoming gradient of non-trainable inputs to None. It is possible
  # that we receive non-None gradients for non-trainable types in nested while
  # loops because we accumulate outputs of the inner while as variant tensors
  # which are trainable and hence receive zeros_like tensors in the gradient
  # pass. The non-trainable tensors then receive the popped zeros tensor from
  # this zeros variant. The gradient for the loop vars corresponding to these
  # tensors is None or zeros (this happens only if the loop var is accumulated
  # as well) in _grad_fn so we reset these.
  # TODO(b/118712257): Remove once we can handle None output grads in _grad_fn.
  if not _is_trainable(body_graph_output):
    return None

  # GradientTape initializes resource and variant grads as None instead of
  # zeros. Set to zeros so _GradientsHelper computes the gradients instead of
  # returning None.
  if (while_op_output.dtype in (dtypes.resource, dtypes.variant)
      and grad is None):
    return _zeros_like(while_op_output)

  return grad


# TODO(skyewm): make this return constants if op_output's shape is fully
# defined (this can be done by checking the "shape" attr of resource vars).
def _zeros_like(op_output):
  """Like array_ops.zeros_like() but also accepts resource var handles."""
  if op_output.dtype == dtypes.resource:
    return array_ops.zeros(
        gen_resource_variable_ops.variable_shape(op_output))
  return array_ops.zeros_like(op_output)


def _is_trainable(tensor):
  """Returns whether the given tensor is trainable."""
  if not gradients_util.IsTrainable(tensor):
    return False

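  # (gradients_util.IsTrainable is essentially a dtype check: floating-point,
  # complex, resource and variant tensors are considered trainable; integer
  # and string tensors are not.)
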
  # Special case: untrainable accumulator output. The gradients algorithm
  # doesn't know about tensor lists of untrainable elements. In theory the
  # tensor list gradient functions should return None as appropriate, but
  # because we can't return None from the gradient function we filter out
  # untrainable accumulator output here to avoid computing the gradient at all.
  if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
    assert tensor.dtype == dtypes.variant
    element_type = tensor.op.get_attr("element_dtype")
    return gradients_util.IsTrainable(element_type)

  return True


# TODO(srbs): Pull this into common utils for cond_v2 and while_v2.
def _get_graph(while_op, func_attr_name):
  """Returns `FuncGraph` for the given function attribute.

  Args:
    while_op: The While Operation.
    func_attr_name: string

  Returns:
    `FuncGraph`
  """
  # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
  input_shapes = [
      tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
  ]
  func_name = while_op.get_attr(func_attr_name).name
  fdef = while_op.graph._get_function(func_name).definition
  # `while_op.graph` may not be the same as `ops.get_default_graph()` e.g.
  # if the `while_op` is in the body of another if/while/defun. We build the
  # `func_graph` with `while_op.graph` as its `outer_graph`. This resembles how
  # the `FuncGraph` was built in the forward pass. We need this so that we can
  # appropriately capture references to outer tensors in the nested grad graphs.
  with while_op.graph.as_default():
    func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
  func_graph._while = while_op
  return func_graph


def _create_grad_func(ys, xs, grads, cond_graph, body_graph, name, while_op,
                      maximum_iterations):
  """Builds and returns the gradient FuncGraph of `body_graph` and its args.

  The returned grad_func_graph must be called with the returned
  args + grad_func_graph.captures.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grads: The incoming grads for `ys`.
    cond_graph: FuncGraph for the forward cond function.
    body_graph: FuncGraph for the forward body function.
    name: Name of the returned gradient function.
    while_op: The forward While op.
    maximum_iterations: Tensor. The maximum number of iterations.

  Returns:
    2-tuple of (grad_func_graph, args).
  """
  assert len(ys) == len(grads)

  total_iters = while_op.outputs[0]
  counter = constant_op.constant(
      0, dtype=total_iters.dtype, name="grad_counter")

  args = [counter, maximum_iterations, total_iters] + list(grads)
  # Note: The returned function does not have `args` in the list of
  # `external_captures`.
  grad_func_graph = func_graph_module.func_graph_from_py_func(
      name,
      lambda *args: _grad_fn(ys, xs, args, body_graph),
      args, {},
      func_graph=_WhileBodyGradFuncGraph(name, cond_graph, body_graph,
                                         maximum_iterations))

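  # `grad_func_graph.internal_captures` now holds the placeholders created
  # while differentiating the forward body: captured accumulator TensorLists
  # and directly captured resource handles. Each must become an output below
  # so the gradient While op keeps matching inputs and outputs.
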
  # Add the popped accumulators to the list of outputs.
  for internal_capture in grad_func_graph.internal_captures:
    if internal_capture in grad_func_graph.popped_tensor_lists:
      new_output = grad_func_graph.popped_tensor_lists[internal_capture]
    elif internal_capture.dtype == dtypes.resource:
      new_output = internal_capture
    else:
      raise ValueError("Tensor %s is in the list of internal_captures but is"
                       " neither a resource nor in popped_tensor_lists." %
                       str(internal_capture))
    grad_func_graph.outputs.append(new_output)
    grad_func_graph.structured_outputs.append(new_output)

  return grad_func_graph, args


def _grad_fn(ys, xs, args, func_graph):
  """Computes the gradient of `func_graph` in the current graph.

  This function builds the gradient graph of the corresponding forward-pass
  `func_graph` by differentiating `func_graph`'s outputs w.r.t. its inputs.

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    args: The input arguments.
      args[0] - Loop counter.
      args[1] - maximum_iterations.
      args[2] - Total number of forward-pass iterations.
      args[3:] - Incoming gradients for `ys`.
    func_graph: function.FuncGraph. The corresponding forward-pass function.

  Returns:
    The output gradient Tensors.
  """
  grad_ys = args[3:]

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # after the forward While op has been rewritten in _resolve_grad_captures.
  # TODO(srbs): Mark GradientsHelper as public?
  grad_outs = gradients_util._GradientsHelper(
      ys, xs, grad_ys=grad_ys, src_graph=func_graph,
      unconnected_gradients="zero")

  # TODO(b/118712257): Handle the case when grad_outs has None's e.g. when there
  # is a tf.StopGradient in the loop body.
  assert all(g is not None for g in grad_outs)
  counter = args[0]
  maximum_iterations = args[1]
  total_iters = args[2]
  return [counter + 1, maximum_iterations, total_iters] + grad_outs


def _resolve_grad_captures(body_graph, body_grad_graph, while_op):
  """Returns the tensors to pass as captured inputs to `body_grad_graph`.

  `body_grad_graph` may have external references to:
  1. Its outer graph containing the input gradients. These are left as-is.
  2. Accumulators captured from the forward-pass graph. These should have been
     added as `while_op` outputs after the gradient graph was built. We replace
     these with the corresponding output of `while_op`, i.e. a tensor in
     `body_graph.outer_graph`. In the case of nested control flow or functions,
     the gradient logic handling `body_grad_graph.outer_graph` will make sure
     the tensor from `body_graph.outer_graph` is also correctly captured.

  Args:
    body_graph: FuncGraph. The forward-pass body function.
    body_grad_graph: FuncGraph. The body gradients function.
    while_op: The forward-pass While Operation calling `body_graph`.

  Returns:
    A list of input tensors to be passed as the captured inputs to
    `body_grad_graph`.
  """
  new_capture_inputs = []
  for t in body_grad_graph.external_captures:
    # All values captured by gradient computation should be from the forward
    # graph or a captured resource variable (note that input gradients are
    # regular non-captured inputs).
    if t.graph == body_graph:
      # Captured accumulator
      t = while_op.outputs[t.graph.outputs.index(t)]
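      # (The forward While op was extended in _WhileGrad to output these
      # accumulators, so the body-output index is also valid on
      # `while_op.outputs`.)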
      # Note: We rely on the capturing logic of the gradient While op graph to
      # correctly capture the tensors in `body_graph.outer_graph`. Both cond_v2
      # and while_v2 handle this while building their gradient functions.
      assert t.graph == body_graph.outer_graph
    else:
      # Captured resource variable
      assert t.dtype == dtypes.resource

    new_capture_inputs.append(t)
  return new_capture_inputs


def _get_structured_grad_output(outputs, grads, body_grad_graph):
  """Returns the values that should be returned from the while grad function.

  Args:
    outputs: the raw Tensor outputs of the grad While op.
    grads: the input gradients to the gradient function.
    body_grad_graph: _WhileBodyGradFuncGraph.

  Returns:
    A list of gradient values. May include Nones.
  """
  result = []
  # outputs[0] is the loop counter.
  # outputs[1] is maximum_iterations.
  # outputs[2] is the total number of loop iterations.
  outputs_idx = 3
  structured_outputs_idx = 3
  for g in grads:
    # Set None as the output gradient for tensors with None input gradient.
    if g is None:
      result.append(None)
      continue
    output = body_grad_graph.structured_outputs[structured_outputs_idx]
    structured_outputs_idx += 1
    if isinstance(output, ops.IndexedSlices):
      # TODO(skyewm): is there a more robust way to determine the order of
      # flattened IndexedSlices components?
      result.append(ops.IndexedSlices(
          values=outputs[outputs_idx],
          indices=outputs[outputs_idx + 1],
          dense_shape=outputs[outputs_idx + 2]))
      outputs_idx += 3
    else:
      assert isinstance(output, ops.Tensor)
      result.append(outputs[outputs_idx])
      outputs_idx += 1

  return result


def _get_accumulator(tensor):
  r"""Returns TensorList if any containing accumulated values of tensor.

  We try to find a pattern of the form:

     input_tl   tensor
        \        /
    (TensorListPushBack)
            |
        output_tl

  which satisfies the following conditions:

  1. input_tl must be in tensor.graph.inputs.
  2. output_tl or Identity(output_tl) must be in tensor.graph.outputs.
  3. tensor.graph.input_index(input_tl) == tensor.graph.output_index(output_tl).

  output_tl or Identity(output_tl) (whichever is in tensor.graph.outputs) is
  returned if such a pattern is found else None is returned.

  Args:
    tensor: The Tensor to be accumulated.

  Returns:
    A variant tensor in the same graph as `tensor` or None if no accumulator is
    found.
  """
  assert isinstance(tensor.graph, func_graph_module.FuncGraph)

  def get_func_graph_output(t):
    """Returns t or Identity(t) whichever exists in graph outputs else None."""
    if t in tensor.graph.outputs:
      return t
    # tf.defun adds an Identity for each output, check whether that is the case.
    identity_op = t.consumers()[0]
    if (identity_op.type == "Identity" and
        identity_op.outputs[0] in tensor.graph.outputs):
      return identity_op.outputs[0]
    return None

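  # Scan the consumers of `tensor` for the push-back pattern described in the
  # docstring. The matching input/output index in condition 3 is what makes the
  # TensorList a loop-carried accumulator rather than a list that merely passes
  # through the body.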
  for consumer in tensor.consumers():
    # Find the consumer that is a TensorListPushBack node whose TensorList input
    # is in the list of function inputs.
    if (consumer.type != "TensorListPushBack" or
        consumer.inputs[0] not in tensor.graph.inputs):
      continue

    output = get_func_graph_output(consumer.outputs[0])
    if output is None:
      # The TensorList output of `consumer` is not in the list of function
      # outputs.
      continue

    accum_input_idx = tensor.graph.inputs.index(consumer.inputs[0])
    accum_output_idx = tensor.graph.outputs.index(output)
    if accum_input_idx == accum_output_idx:
      return output
  return None


class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph):
  """FuncGraph for the gradient function of the body of a While op.

  Contains the logic for capturing the tensors from the body of the forward
  While op which is as follows:
  1. If the tensor is of resource type (these are not accumulated):
     a. Ensure that the tensor is a loop invariant, i.e., it exists in both loop
        inputs and outputs at the same index.
     b. Look up the corresponding resource tensor in the forward outer graph and
        try to capture that.
  2. If the tensor is not of resource type:
     a. Create an accumulator for that tensor and output it from the forward
        pass. Note this also requires adding it as an input to the forward pass.
     b. Capture the accumulator from the forward pass in this FuncGraph. This
        will later be resolved to the correct output of the forward While op.
     c. Pop a value from the captured placeholder and use it as the captured
        value for the forward pass tensor.

  This only allows capturing tensors in the forward graph. A ValueError is
  raised if an attempt is made to capture a tensor not in the forward graph.
  To manually capture a tensor that is not in the forward graph, call
  `capture` with `whitelisted=True`.

  Note: The `captures` dict does not contain the forward tensor since it is not
  directly captured. It contains the accumulator corresponding to this forward
  tensor.

  Attributes:
    while_op_needs_rewrite: True if any non-resource intermediates were
      captured, meaning the forward While op needs to be rewritten to output
      the corresponding accumulators.
    empty_tensor_lists: list of EmptyTensorList tensors to be used as initial
      input to the new accumulators in the forward graph.
    popped_tensor_lists: dict from the captured accumulator placeholder to the
      TensorList obtained after popping the intermediate tensor from it. The
      values of this dict need to be added to the list of outputs.
  """

  def __init__(self, name, forward_cond_graph, forward_body_graph,
               maximum_iterations):
    super(_WhileBodyGradFuncGraph, self).__init__(name)
    self.empty_tensor_lists = []
    self.popped_tensor_lists = {}
    # FuncGraph for the body of the forward While op.
    self._forward_graph = forward_body_graph
    # FuncGraph for the cond of the forward While op.
    self._forward_cond_graph = forward_cond_graph
    self._maximum_iterations = maximum_iterations
    # Dict from forward intermediate tensor to its indirectly captured tensor
    # in this graph. Indirect capturing happens in two ways:
    # 1. For non-resource tensors we capture their accumulators from the
    #    forward outer graph and pop values from that accumulator inside this
    #    graph using TensorListPopBack.
    # 2. For resource tensors we directly capture their corresponding tensor
    #    in the forward outer graph.
    self._indirect_captures = {}

  @property
  def while_op_needs_rewrite(self):
    return self.empty_tensor_lists
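
  # Note: `while_op_needs_rewrite` is truthy iff `empty_tensor_lists` is
  # non-empty, i.e. at least one new accumulator was created; _WhileGrad uses
  # it to decide whether the forward While op must be rewritten with extra
  # inputs/outputs.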

  def capture(self, tensor, name=None, whitelisted=False):
    """Selectively captures external tensors.

    If `whitelisted` is False only allows capturing tensors in the
    `_forward_graph`.

    Args:
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.
      whitelisted: If False (default), only allows capturing tensors from the
        forward graph.

    Returns:
      The placeholder in this graph for the tensor.

    Raises:
      ValueError: If attempting to capture an external tensor not in the
        forward graph with `whitelisted` set to False.
    """
    if (not whitelisted and tensor.graph is not self and
        tensor.graph != self._forward_graph):
      raise ValueError("Attempting to capture tensor %s which is not in the "
                       "forward graph but in %s." %
                       (str(tensor), _graph_name(tensor.graph)))
    return super(_WhileBodyGradFuncGraph, self).capture(tensor, name)

  def _capture_helper(self, tensor, name):
    if tensor.graph is not self._forward_graph:
      return super(_WhileBodyGradFuncGraph, self)._capture_helper(tensor, name)

    while tensor.op.type == "Identity":
      # We do not accumulate the output of identity nodes so we try to capture
      # the input of the Identity node instead.
      tensor = tensor.op.inputs[0]

    captured_tensor = self._indirect_captures.get(tensor)
    if captured_tensor is not None:
      return captured_tensor

    # Resource tensors are not accumulated and handled specially.
    if tensor.dtype == dtypes.resource:
      return self._resource_capture_helper(tensor)

    # Create or find an existing accumulator output for `tensor` in the forward
    # graph, and fetch from this accumulator in the gradient graph to get the
    # raw intermediate value.
    accumulator = _get_accumulator(tensor)
    if accumulator is None:
      # Create the initial empty tensor list.
      with self._forward_graph.outer_graph.as_default():
        tensor_list = list_ops.empty_tensor_list(
            element_dtype=tensor.dtype, element_shape=tensor.shape,
            max_num_elements=self._maximum_iterations)
      self.empty_tensor_lists.append(tensor_list)

      # Push the intermediate tensor to the tensor list. This captures
      # `tensor_list`.
      with self._forward_graph.as_default():
        accumulator = list_ops.tensor_list_push_back(tensor_list, tensor)
      # Add the modified tensor list to the list of outputs. This output will
      # be all the accumulated values.
      self._forward_graph.outputs.append(accumulator)

      # Capture in the cond graph as well so the forward cond and body inputs
      # match.
      with self._forward_cond_graph.as_default():
        self._forward_cond_graph.capture(tensor_list)

    # Capture the accumulator tensor list in the gradient graph directly from
    # the forward graph -- we'll later modify this to capture the final list
    # output by the forward While op instead.
    captured_accumulator = super(_WhileBodyGradFuncGraph, self)._capture_helper(
        accumulator, name)

    # Pop the intermediate value from the tensor list in the gradient graph.
    new_tensor_list, captured_tensor = list_ops.tensor_list_pop_back(
        captured_accumulator, element_dtype=tensor.dtype)

    self._indirect_captures[tensor] = captured_tensor
    self.popped_tensor_lists[captured_accumulator] = new_tensor_list
    return captured_tensor

  def _resource_capture_helper(self, tensor):
    """Returns the captured resource tensor.

    Resource-type tensors are not accumulated. If a resource tensor exists in
    the loop body it must either be a loop input or an output of a nested While
    op inside the loop body which had captured the external resource.
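
    The handle is resolved by mapping `tensor` back to an input index of the
    forward graph, checking that the input is a loop invariant (the same tensor
    appears at the matching output index), and then capturing the outer-graph
    tensor that feeds that input of the forward While op.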

    Args:
      tensor: the external resource Tensor to be captured.

    Returns:
      Tensor in this graph.
    """
    assert tensor.dtype == dtypes.resource

    index = self._resource_input_index(
        tensor.name,
        [t.name for t in self._forward_graph.inputs],
        {op.name: op.node_def for op in self._forward_graph.get_operations()},
        self._forward_graph._functions)

    input_placeholder = self._forward_graph.inputs[index]
    tensor_in_outer_graph = self._forward_graph._while.inputs[index]

    assert input_placeholder.dtype == dtypes.resource
    assert tensor_in_outer_graph.dtype == dtypes.resource
    # This must be a loop invariant.
    assert input_placeholder == self._forward_graph.outputs[index], (
        "Resource tensors must be loop invariants %s." %
        tensor_in_outer_graph)

    self._indirect_captures[tensor] = self.capture(
        tensor_in_outer_graph, whitelisted=True)
    return self._indirect_captures[tensor]

  def _resource_input_index(self, tensor_name, input_names, node_defs,
                            functions):
    """Returns the index of the input corresponding to `tensor_name`.

    This method is used to find the corresponding index of an arbitrary
    resource tensor in a function (the function could be a loop body). We
    assume that resource handles are never created in functions, so that every
    resource tensor can be traced back to a function input.

    The awkward signature of this method is to make it work with both
    FuncGraphs and FunctionDefs. This is so we can recurse on function call ops
    without building the corresponding FuncGraph (note that even if a FuncGraph
    for a FunctionDef already exists, the input/output/node names may have been
    changed when the FuncGraph was serialized to the FunctionDef, which makes
    it unusable with this algorithm).

    Args:
      tensor_name: the name of the resource tensor to be resolved to an input.
      input_names: a list of the names of all inputs to the function.
      node_defs: a dict mapping op name -> NodeDef for every op in the function.
      functions: a dict mapping function name -> _EagerDefinedFunction.

    Returns:
      The index into input_names corresponding to `tensor_name`.
    """
    while tensor_name not in input_names:
      # FunctionDefs and graphs use different tensor naming conventions.
      parts = tensor_name.split(":")
      if len(parts) == 3:
        op_name, _, output_idx = parts
      elif len(parts) == 2:
        op_name, output_idx = parts
      else:
        assert len(parts) == 1
        op_name = parts[0]
        output_idx = 0
      output_idx = int(output_idx)
      node_def = node_defs[op_name]

      if node_def.op == "While":
        # Captured resources occur at the same index in the lists of inputs and
        # outputs of a while op. So we look up the input of `tensor.op` at the
        # same index as the index of `tensor` in `tensor.op.outputs`.
        tensor_name = node_def.input[output_idx]
      elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
        # Functions output any captured resource tensors used by their
        # gradients. `tensor_name` is one of these outputs from a nested
        # function call, so recursively find the corresponding input in the
        # nested FunctionDef.
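        # (`fdef.ret` maps each output arg name of the nested function to the
        # internal tensor that produces it, which is what lets the recursion
        # below keep resolving names inside that function.)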
        func_name = node_def.attr["f"].func.name
        fdef = functions[func_name].definition
        output_arg_name = fdef.signature.output_arg[output_idx].name
        output_tensor_name = fdef.ret[output_arg_name]
        input_index = self._resource_input_index(
            output_tensor_name,
            [arg.name for arg in fdef.signature.input_arg],
            {ndef.name: ndef for ndef in fdef.node_def},
            functions)
        tensor_name = node_def.input[input_index]
      else:
        # We assume there are no other op types that will "forward" resource
        # handles like this, so all other handles must have been created by the
        # op. (Note that cond_v2 wraps resource handle outputs in optionals,
        # which we'll end up accumulating).
        raise ValueError(
            "Taking gradient of a while loop which creates "
            "a resource in its body is not supported: %s" % op_name)

    return input_names.index(tensor_name)


def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
  for (t, shape, input_t) in zip(output_tensors, shape_invariants,
                                 input_tensors):
    if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
      raise ValueError(
          "Input tensor '%s' enters the loop with shape %s, but has "
          "shape %s after one iteration. To allow the shape to vary across "
          "iterations, use the `shape_invariants` argument of tf.while_loop to "
          "specify a less-specific shape." % (input_t.name, shape, t.shape))


def _check_num_inputs_outputs(cond_graph, body_graph, num_flattened_loop_vars):
  """Checks the number of inputs/outputs of `cond_graph` and `body_graph`."""
  assert len(cond_graph.inputs) == num_flattened_loop_vars, (
      "cond_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(cond_graph.outputs) == 1, (
      "cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
  assert len(body_graph.inputs) == num_flattened_loop_vars, (
      "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
                                                    num_flattened_loop_vars))
  assert len(body_graph.outputs) == num_flattened_loop_vars, (
      "body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
                                                   num_flattened_loop_vars))


def _copy_handle_data(src_tensors, tgt_tensors):
  for src_t, tgt_t in zip(src_tensors, tgt_tensors):
    custom_gradient.copy_handle_data(src_t, tgt_t)


# TODO(srbs): This method should be in control_flow_util but that introduces
# a circular dependency ops -> control_flow_util -> ops.
def _is_in_xla_context():
  """Returns whether the current context is inside an XLA context."""
  outer_graph = ops.get_default_graph()
  # The `_control_flow_context` is not copied when building a FuncGraph so
  # we look it up from the base graph.
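  # FuncGraphs chain to their enclosing graph via `outer_graph`; walk up that
  # chain to the root graph, whose control flow context records whether we are
  # inside an XLA context.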
  while isinstance(outer_graph, func_graph_module.FuncGraph):
    outer_graph = outer_graph.outer_graph
  cur_ctxt = outer_graph._get_control_flow_context()  # pylint: disable=protected-access
  return control_flow_util.GetContainingXLAContext(cur_ctxt) is not None


def _graph_name(graph):
  if isinstance(graph, func_graph_module.FuncGraph):
    return graph.name
  return "Base"


def _pack_sequence_as(structure_with_tas, loop_vars):
  """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""

  def flow_to_tensor_array(flow, ta):  # pylint: disable=missing-docstring
    return (tensor_array_ops.build_ta_with_new_flow(ta, flow) if isinstance(  # pylint: disable=g-long-ternary
        ta, tensor_array_ops.TensorArray) else flow)

  flattened_loop_vars = [
      flow_to_tensor_array(*z)
      for z in zip(nest.flatten(loop_vars), nest.flatten(structure_with_tas))
  ]
  return nest.pack_sequence_as(structure_with_tas, flattened_loop_vars)


def _tensor_array_to_flow(loop_vars):

  def f(maybe_ta):
    if isinstance(maybe_ta, tensor_array_ops.TensorArray):
      return maybe_ta.flow
    return maybe_ta

  return nest.map_structure(f, loop_vars)


def _build_signature(loop_vars, shape_invariants):
  return nest.pack_sequence_as(loop_vars, [
      tensor_spec.TensorSpec(s, t.dtype, name=t.op.name)
      for s, t in zip(nest.flatten(shape_invariants), nest.flatten(loop_vars))
  ])


def _build_maximum_iterations_loop_var(maximum_iterations):
  if maximum_iterations is None:
    # Default value for the max_num_elements argument of EmptyTensorList,
    # meaning the accumulator lists are unbounded in size.
    maximum_iterations = -1
  # EmptyTensorList expects `max_num_elements` to be of type int32.
  return ops.convert_to_tensor(
      maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")

# pylint: enable=protected-access