# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""cond_v2 and gradient.

This is a version of cond that emits a single If op, as well as the gradient
function for If ops produced by cond_v2. This will eventually replace the
current tf.cond implementation once it reaches feature and performance parity.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2 as util
from tensorflow.python.ops import default_gradient
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest


# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify
# that they aren't part of the official public API. These protected members
# often need to be used by implementation code however. Rather than litter the
# code with pylint comments, we ignore protected access violations for
# readability.
# pylint: disable=protected-access

_COND = 1
_CASE = 2


def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    pred = ops.convert_to_tensor(pred)
    if (tensor_util.is_tf_type(pred) and
        (pred.shape.dims is None or pred.shape.dims)):
      pred = array_ops.squeeze_v2(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    verify_captures(_COND, [true_graph, false_graph])
    return _build_cond(
        pred,
        true_graph,
        false_graph,
        true_graph.external_captures,
        false_graph.external_captures,
        building_gradient=False,
        name=scope)


@ops.RegisterGradient("StatelessIf")
@ops.RegisterGradient("If")
def _IfGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of an If op produced by cond_v2."""
  # Get the if operator (this logic handles the case where op is a MockOp)
  if_op = op.outputs[0].op
  true_graph, false_graph = get_func_graphs(if_op)
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  assert true_graph.outer_graph == if_op.graph
  assert false_graph.outer_graph == if_op.graph

  # Create grad functions that compute the gradient of the true/false forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  true_grad_graph = _create_grad_func(
      true_graph, grads, util.unique_grad_fn_name(true_graph.name))
  false_grad_graph = _create_grad_func(
      false_graph, grads, util.unique_grad_fn_name(false_graph.name))

  # Replaces output None grads with zeros if at least one branch has non-None
  # grad at that index.
  _create_zeros_for_none_grads([true_graph, false_graph],
                               [true_grad_graph, false_grad_graph])

  if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite):
    # Modify 'op' to output the intermediates needed by the grad functions. Note
    # that all needed intermediates are wrapped in optionals. Each optional
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    # NOTE(skyewm): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      true_intermediates = true_grad_graph.xla_intermediates
      false_intermediates = false_grad_graph.xla_intermediates
      extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla(
          [true_graph, false_graph], [true_intermediates, false_intermediates])
    else:
      true_intermediates = true_grad_graph.wrapped_intermediates
      false_intermediates = false_grad_graph.wrapped_intermediates
      # Make outputs match by adding none optionals.
      extra_true_outputs, extra_false_outputs = _make_intermediates_match(
          [true_graph, false_graph], [true_intermediates, false_intermediates])

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    # TODO(skyewm): indicate it's an internal bug if this fails.
    _check_same_outputs(_COND, [true_graph, false_graph])

    true_graph.name += "_rewritten"
    false_graph.name += "_rewritten"

    if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph))
    if_op._set_func_attr("else_branch",
                         util.create_new_tf_function(false_graph))
    if_op._set_type_list_attr("Tout", true_graph.output_types)
    if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes)
    if_op._add_outputs(
        [t.dtype for t in extra_true_outputs],
        [t.shape for t in extra_true_outputs])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph)
  false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph)

  # This modifies true_grad_graph and false_grad_graph.
  _make_output_composite_tensors_match(_COND,
                                       [true_grad_graph, false_grad_graph])

  outputs = _build_cond(
      if_op.inputs[0],
      true_grad_graph,
      false_grad_graph,
      true_grad_inputs,
      false_grad_inputs,
      building_gradient=True,
  )

  # The predicate has no gradient.
  return [None] + outputs


def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                building_gradient,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    building_gradient: Whether this is a gradient If op.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include added
    intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
  _check_same_outputs(_COND, [true_graph, false_graph])

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match([true_graph, false_graph],
                                   [true_inputs, false_inputs])
  # We do not output intermediates of the gradient If op since this is just
  # for backwards compatibility with existing code.
  if not building_gradient and util.output_all_intermediates():
    # Add all intermediate tensors as function outputs so they're available for
    # the gradient computation. Since the outputs of the two functions must
    # match, we wrap all the intermediates in optionals. Each intermediate
    # output will have a value iff its corresponding branch is taken.

    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Wrap intermediates in optionals.
    wrapped_true_intermediates = _wrap_intermediates(true_graph,
                                                     true_intermediates)
    wrapped_false_intermediates = _wrap_intermediates(false_graph,
                                                      false_intermediates)

    # Make outputs match by adding none optionals.
    extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
        [true_graph, false_graph],
        [wrapped_true_intermediates, wrapped_false_intermediates])

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    _check_same_outputs(_COND, [true_graph, false_graph])

  # Create the If op.
  with ops.control_dependencies(
      list(true_graph.control_captures) + list(false_graph.control_captures)):
    true_stateful_ops = [
        op for op in true_graph.get_operations() if op._is_stateful
    ]
    false_stateful_ops = [
        op for op in false_graph.get_operations() if op._is_stateful
    ]
    if (true_stateful_ops or false_stateful_ops):
      op_fn = gen_functional_ops._if
    else:
      op_fn = gen_functional_ops.stateless_if

    def _make_op(inputs):
      if_op, tensors = util.get_op_and_outputs(op_fn(
          pred,
          inputs, [t.dtype for t in true_graph.outputs],
          util.create_new_tf_function(true_graph),
          util.create_new_tf_function(false_graph),
          output_shapes=_get_output_shapes(true_graph.outputs,
                                           false_graph.outputs),
          name=name))
      _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs)
      # `if_op` is None if this is a `StatelessIf` op with no outputs.
      if if_op is not None:
        # The true and false graphs have already been created, and we need that
        # to happen before we know which tensors will be captured and so whether
        # to wrap the cond in a tf.function. Post-hoc mutation of the branch
        # `outer_graph` properties seems like the only option if we want to
        # conditionally wrap in a function.
        true_graph.outer_graph = ops.get_default_graph()
        false_graph.outer_graph = ops.get_default_graph()
        if_op._true_graph = true_graph
        if_op._false_graph = false_graph
        util.maybe_set_lowering_attr(if_op)
        util.maybe_propagate_compile_time_consts_in_xla(if_op)
        _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph])
        # Prevent fetching since the variant outputs can't be fetched directly.
        if_op.graph.prevent_fetching(if_op)
      return tensors
    tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  return _pack_sequence_as(true_graph.structured_outputs, tensors)


def get_func_graphs(op):
  """Returns `FuncGraph`s for the input op branches.

  Args:
    op: The If or Case Operation.

  Returns:
    A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches
    for Case).
  """

  def _get_func_graph_for_branch(name_attr_list, cached_attr_name=None):
    """Generates and returns a FuncGraph for the given branch."""
    func_graph = None
    if cached_attr_name is not None:
      func_graph = getattr(op, cached_attr_name, None)
    inputs = op.inputs[1:]  # First input is pred.
    if func_graph is None:
      input_shapes = [t.shape for t in inputs]
      func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name)
    for external_t, internal_t in zip(inputs, func_graph.inputs):
      handle_data_util.copy_handle_data(external_t, internal_t)
    func_graph.reset_captures(zip(inputs, func_graph.inputs))
    # Link the op so that the gradient code can use it.
    func_graph._forward_cond = op
    return func_graph

  if op.type in ["If", "StatelessIf"]:
    return (_get_func_graph_for_branch(
        op.get_attr("then_branch"), "_true_graph"),
            _get_func_graph_for_branch(
                op.get_attr("else_branch"), "_false_graph"))
  elif op.type in ["Case", "StatelessCase"]:
    return [_get_func_graph_for_branch(branch_fn, "_branch_graph_{}".format(i))
            for i, branch_fn in enumerate(op.get_attr("branches"))]
  else:
    raise ValueError("Unsupported op type: {}".format(op.type))


def _grad_fn(func_graph, grads):
  """The gradient function for each conditional branch.

  This function builds the gradient graph of the corresponding forward-pass
  conditional branch in `func_graph`. This is done by differentiating
  func_graph's outputs w.r.t. its inputs.

  Args:
    func_graph: FuncGraph. The corresponding forward-pass function.
    grads: The list of input gradient Tensors.

  Returns:
    The output gradient Tensors.
  """
  # Filter out untrainable function outputs.
  # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes
  # cause _GradientsHelper to raise an exception (e.g. the implementation
  # doesn't expect 'ys' to contain boolean tensors).
  assert len(func_graph.outputs) == len(grads)
  ys = []
  grad_ys = []
  for y, grad_y in zip(func_graph.outputs, grads):
    if not backprop_util.IsTrainable(y):
      continue
    ys.append(y)
    grad_ys.append(grad_y)

  # Build the gradient graph. Note that this builds the gradient computation of
  # func_graph in the current graph, which requires capturing tensors from
  # func_graph. The captured func_graph tensors are resolved to external tensors
  # in _resolve_grad_inputs.
  result = gradients_util._GradientsHelper(
      ys, func_graph.inputs, grad_ys=grad_ys,
      src_graph=func_graph)

  return result


def _create_grad_func(func_graph, grads, name):
  """Returns the FuncGraph representation of _grad_fn."""
  return func_graph_module.func_graph_from_py_func(
      name,
      lambda: _grad_fn(func_graph, grads), [], {},
      func_graph=_CondGradFuncGraph(name, func_graph))


def _resolve_grad_inputs(cond_graph, grad_graph):
  """Returns the tensors to pass as inputs to `grad_graph`.

  The `grad_graph` may have external references to
  1. Its outer graph containing the input gradients. These references are kept
     as is.
  2. Tensors in the forward pass graph. These tensors may not be "live"
     when the gradient is being computed. We replace such references by their
     corresponding tensor in `cond_graph.outer_graph`. In the case of nested
     control flow or functions, the gradient logic handling
     `grad_graph.outer_graph` will make sure the tensor from
     `cond_graph.outer_graph` is also correctly captured.

  Args:
    cond_graph: FuncGraph. The forward-pass function.
    grad_graph: FuncGraph. The gradients function.

  Returns:
    A list of input tensors to be passed to grad_graph.
  """
  new_inputs = []

  for t in grad_graph.external_captures:
    # `t` must either be in `grad_graph.outer_graph` or in the forward
    # `cond_graph`.
    if t.graph != grad_graph.outer_graph:
      assert t.graph == cond_graph
      # `internal_captures` are not treated as intermediates and hence not added
      # to If op outputs. So we get the outer tensor corresponding to those
      # from the list of `external_captures`.
      for i, output in enumerate(t.graph.outputs):
        if output is t:
          t = t.graph._forward_cond.outputs[i]
          break
      else:
        for i, output in enumerate(t.graph.internal_captures):
          if output is t:
            t = t.graph.external_captures[i]
            break
        else:
          raise ValueError("Could not find external tensor capture {tensor} in "
                           "captures or outputs".format(tensor=t))

      # Note: We rely on the capturing logic of the gradient If op graph to
      # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
      # and while_v2 handle this while building their gradient functions.
      assert t.graph == cond_graph.outer_graph
    new_inputs.append(t)

  return new_inputs


def _get_intermediates(func_graph):
  """Returns intermediate tensors of `func_graph` for gradient computation."""
  intermediates = []
  for op in func_graph.get_operations():
    for t in op.outputs:
      if t in func_graph.inputs: continue
      if t in func_graph.outputs: continue
      if t.dtype is dtypes.resource:
        continue
      # Accumulating mutexes can cause deadlock.
      if op.type == "MutexLock":
        continue
      intermediates.append(t)
  return intermediates


def _make_intermediates_match(branch_graphs, branch_optionals):
  """Returns new optionals lists that have matching signatures.

  This is done by mirroring each list in the other using none optionals.
  There is no merging of like optionals.

  Args:
    branch_graphs: `list` of `FuncGraph`.
    branch_optionals: `list` of `list`s of optional `Tensor`s from other
      branch_graphs

  Returns:
    A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the
    same number of `Tensor`s, all of which will be optionals of the same
    shape/type.
  """
  new_branch_optionals = []
  # Since the intermediates are optionals with dtype variant, we only need
  # enough room for the longest list of intermediates.
  intermediates_size = max(len(o) for o in branch_optionals)
  for i, branch_graph in enumerate(branch_graphs):
    other_optionals = _create_none_optionals(
        branch_graph, intermediates_size - len(branch_optionals[i]))
    new_branch_optionals.append(branch_optionals[i] + other_optionals)
  return new_branch_optionals


def _make_intermediates_match_xla(branch_graphs, branch_intermediates):
  """Like _make_intermediates_match but for the XLA case."""
  new_branch_intermediates = []
  for i, branch_graph in enumerate(branch_graphs):
    other_fakeparams = _create_fakeparams(
        branch_graph,
        sum((bi for bi in branch_intermediates
             if bi is not branch_intermediates[i]), []))
    num_preceding = sum(len(bi) for bi in branch_intermediates[:i])
    new_branch_intermediates.append(other_fakeparams[:num_preceding] +
                                    branch_intermediates[i] +
                                    other_fakeparams[num_preceding:])
  return new_branch_intermediates


def _make_inputs_match(branch_graphs, branch_inputs):
  """Modifies branch_graphs so they have the same input signature.

  This method reorders and/or adds parameters to each graph in branch_graphs so
  they have the same input signature, and updates the 'inputs' and 'captured'
  fields of each graph accordingly. It uses the input tensors from the outer
  graph to avoid duplicating shared arguments.

  Args:
    branch_graphs: a `list` of `FuncGraph`
    branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The
      inputs for the corresponding graph in `branch_graphs`.

  Returns:
    A new list of Tensors from the outer graph that are the new inputs for each
    branch_graph. This is a deduped version of `sum(branch_inputs)`.
  """
  assert len(branch_graphs) == len(branch_inputs)
  added_inputs = set()
  new_inputs = []
  for branch_in in branch_inputs:
    for tensor in branch_in:
      tensor_id = ops.tensor_id(tensor)
      if tensor_id not in added_inputs:
        added_inputs.add(tensor_id)
        new_inputs.append(tensor)

  for branch_graph, branch_in in zip(branch_graphs, branch_inputs):
    input_ids = [ops.tensor_id(t) for t in branch_in]
    branch_input_to_param = dict(zip(input_ids, branch_graph.inputs))
    input_list = []
    for in_t in new_inputs:
      param = branch_input_to_param.get(ops.tensor_id(in_t))
      if param is None:
        param = _create_dummy_input(branch_graph, in_t)
      input_list.append(param)

    branch_graph.inputs = input_list

    # Rewrite the FuncGraphs' state to reflect the new inputs.
    branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs))

  return new_inputs


def _create_zeros_for_none_grads(forward_graphs, grad_graphs):
  """Creates zeros for None out grads if at least one branch has non-None grad.

  Args:
    forward_graphs: List of forward FuncGraphs.
    grad_graphs: List of grad FuncGraphs.
  """
  assert len(forward_graphs) == len(grad_graphs)
  branch_outputs = [g.structured_outputs for g in grad_graphs]
  num_outputs_per_branch = [len(outs) for outs in branch_outputs]
  assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch
  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    if (any(t is None for t in branch_outs) and
        any(t is not None for t in branch_outs)):
      for branch_index, t in enumerate(branch_outs):
        if t is None:
          with grad_graphs[branch_index].as_default():
            zeros = default_gradient.zeros_like(
                forward_graphs[branch_index].inputs[output_idx])
            grad_graphs[branch_index].structured_outputs[output_idx] = zeros

  for grad_graph in grad_graphs:
    grad_graph.outputs = [
        t for t in func_graph_module.flatten(grad_graph.structured_outputs)
        if t is not None
    ]


def _make_output_composite_tensors_match(op_type, branch_graphs):
  """Modifies each branch_graph's outputs to have the same output signature.

  Currently the only transformation implemented is turning a Tensor into an
  equivalent IndexedSlices if the other branch returns an IndexedSlices.
  Updates branch_graph.{outputs,structured_outputs} for each branch_graph in
  branch_graphs.

  Args:
    op_type: _COND or _CASE
    branch_graphs: `list` of `FuncGraph`

  Raises:
    TypeError: if a set of outputs cannot be rewritten.
  """
  # Note: since this is only used for gradient graphs, we do not expect the
  # outputs to be structured (e.g. nested lists), and thus do not need to use
  # nest.flatten, etc.
  assert branch_graphs
  branch_outputs = [g.structured_outputs for g in branch_graphs]
  outputs_per_branch = list(len(outs) for outs in branch_outputs)
  assert len(set(outputs_per_branch)) == 1, outputs_per_branch

  for output_idx, branch_outs in enumerate(zip(*branch_outputs)):
    if len(set(type(out) for out in branch_outs)) == 1:
      continue
    if not any(isinstance(out, ops.IndexedSlices) for out in branch_outs):
      continue
    for branch_idx, branch_out in enumerate(branch_outs):
      if isinstance(branch_out, ops.IndexedSlices):
        continue
      elif isinstance(branch_out, ops.Tensor):
        with branch_graphs[branch_idx].as_default():
          branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices(
              branch_out)
      else:
        raise TypeError(
            "Cannot reconcile {op_name} {output_idx}-th outputs:\n"
            "  outputs from all branches: {outputs}".format(
                op_name="tf.cond" if op_type == _COND else "tf.switch_case",
                output_idx=output_idx,
                outputs=branch_outs))

  for branch_graph, branch_outs in zip(branch_graphs, branch_outputs):
    branch_graph.structured_outputs = branch_outs
    branch_graph.outputs = [
        t for t in func_graph_module.flatten(branch_outs) if t is not None
    ]


def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
  """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
  assert branch_graphs
  # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`.
  indexed_slice_indices = []
  current_index = 0
  # Note that this still contains Nones. We leave those in so that error
  # messages contain the correct indices. We handle the Nones later when
  # updating `current_index`.
  branch_outputs_flat_with_composites = [
      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
      for branch_graph in branch_graphs
  ]
  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
  assert len(set(outs_per_branch)) == 1, outs_per_branch
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for output_idx, branch_outs in enumerate(
      zip(*branch_outputs_flat_with_composites)):
    if len(set(isinstance(out, ops.IndexedSlices) for out in branch_outs)) != 1:
      raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
                      "  branches returned: {outputs}".format(
                          op_name="cond" if op_type == _COND else "switch_case",
                          output_idx=output_idx,
                          outputs=branch_outs))
    if isinstance(branch_outs[0], ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    if nest.is_sequence_or_composite(branch_outs[0]):
      current_index += len(nest.flatten(branch_outs[0], expand_composites=True))
    elif branch_outs[0] is not None:
      # `FuncGraph.outputs` does not contain Nones so no need to update the
      # counter in that case.
      current_index += 1

  if not indexed_slice_indices:
    return

  # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus
  # the Nones.
  if current_index != len(branch_graphs[0].outputs):
    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" %
                     (current_index, len(branch_graphs[0].outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
           for bg in branch_graphs):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" %
                      str([bg.outputs[index].dtype for bg in branch_graphs]))
    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
      for branch_graph in branch_graphs:
        if branch_graph.outputs[index].dtype == dtypes.int32:
          with branch_graph.as_default():
            branch_graph.outputs[index] = math_ops.cast(
                branch_graph.outputs[index], dtypes.int64)

  for branch_graph in branch_graphs:
    branch_graph.structured_outputs = _pack_sequence_as(
        branch_graph.structured_outputs, branch_graph.outputs)


def _pack_sequence_as(structured_outputs, op_outputs):
  """Packs the outputs of the gradient If/Case op.

  The branch functions may contain None's in the list of `structured_outputs`.
  `op_outputs` has those outputs missing. So we need to add those Nones to the
  list of `op_outputs` and then pack it in the same structure as
  `structured_outputs`.

  Args:
    structured_outputs: structured_outputs from one of the branch functions.
    op_outputs: List of output tensors of the op.

  Returns:
    `op_outputs` packed like `structured_outputs`.
  """
  outputs_with_nones = []
  counter = 0
  for output in nest.flatten(structured_outputs, expand_composites=True):
    if output is None:
      outputs_with_nones.append(None)
    else:
      outputs_with_nones.append(op_outputs[counter])
      counter += 1
  return func_graph_module.pack_sequence_as(structured_outputs,
                                            outputs_with_nones)


def _wrap_intermediates(func_graph, intermediates):
  with func_graph.as_default():
    return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]


def _create_dummy_input(func_graph, template_tensor):
  """Creates tensors in func_graph to represent template_tensors.

  Args:
    func_graph: FuncGraph.
    template_tensor: a tensor in the outer graph.

  Returns:
    A tensor in func_graph.
  """
  with func_graph.as_default():
    return array_ops.placeholder(
        template_tensor.dtype, shape=template_tensor.shape)


def _create_none_optionals(func_graph, n):
  """Creates `n` `None` optionals in func_graph.

  Args:
    func_graph: FuncGraph.
    n: `int` the number of `None` optionals to make.

  Returns:
    A list of tensors in func_graph.
  """
  with func_graph.as_default():
    return [gen_dataset_ops.optional_none() for _ in range(n)]


def _create_fakeparams(func_graph, template_tensors):
  """Create FakeParams for the XLA case."""
  with func_graph.as_default():
    return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape)
            for t in template_tensors]


def _check_same_outputs(op_type, graphs):
  """Raises an error if `graphs` have different outputs."""

  def error(branch_idx, error_detail):
    raise TypeError(
        "{b0_name} and {bn_name} arguments to {op_name} must have the same "
        "number, type, and overall structure of return values.\n"
        "\n"
        "{b0_name} output: {b0_out}\n"
        "{bn_name} output: {bn_out}\n"
        "\n"
        "Error details:\n"
        "{detail}".format(
            b0_name="true_fn" if op_type == _COND else "branches[0]",
            bn_name=("false_fn" if op_type == _COND else
                     "branches[{}]".format(branch_idx)),
            op_name="tf.cond" if op_type == _COND else "tf.switch_case",
            b0_out=graphs[0].structured_outputs,
            bn_out=graphs[branch_idx].structured_outputs,
            detail=error_detail))

  for b in range(1, len(graphs)):
    try:
      nest.assert_same_structure(
          graphs[0].structured_outputs,
          graphs[b].structured_outputs,
          expand_composites=True)
    except (ValueError, TypeError) as e:
      error(b, str(e))

    op_type_str = "cond" if op_type == _COND else "case"
    if len(graphs[0].outputs) != len(graphs[b].outputs):
      raise ValueError("Lengths of branch outputs of {op_type} must match.\n"
                       "len(graphs[0].outputs): {len_0}\n"
                       "len(graphs[{b}].outputs): {len_b}\n".format(
                           op_type=op_type_str,
                           len_0=len(graphs[0].outputs),
                           b=b,
                           len_b=len(graphs[b].outputs)))
    for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs):
      if b0_out.dtype != bn_out.dtype:
        error(b, "%s and %s have different types" % (b0_out, bn_out))


def _get_output_shapes(*branch_graph_outputs):
  output_shapes = []
  for out_by_branch in zip(*branch_graph_outputs):
    shape = out_by_branch[0].shape
    for other_out in out_by_branch[1:]:
      shape = shape.most_specific_compatible_shape(other_out.shape)
    output_shapes.append(shape)
  return output_shapes


def _copy_handle_data(external_tensors, *branch_graph_outputs):
  """Combines shapes in handle data and sets metadata on `external_tensors`."""
  for tensors in zip(external_tensors, *branch_graph_outputs):
    external = tensors[0]
    internal = tensors[1:]
    internal_handle_data = []
    for tensor in internal:
      handle_data = handle_data_util.get_resource_handle_data(tensor)
      # NOTE: Assumes handle data has only one ShapeAndType entry. It's
      # unclear how to combine different lengths across branches.
      if not handle_data.is_set or len(handle_data.shape_and_type) != 1:
        break
      internal_handle_data.append(handle_data)
    else:  # There is handle data, so we need to combine it.
      combined_shape = tensor_shape.TensorShape(None)
      combined_dtype = None
      for handle_data in internal_handle_data:
        handle_shape = tensor_shape.TensorShape(
            handle_data.shape_and_type[0].shape)
        combined_shape = combined_shape.most_specific_compatible_shape(
            handle_shape)
        if combined_dtype is None:
          combined_dtype = handle_data.shape_and_type[0].dtype
        elif handle_data.shape_and_type[0].dtype != combined_dtype:
          # Variants from different branches have different dtypes. The
          # combined variant has no static dtype.
          combined_dtype = types_pb2.DT_INVALID
      combined_handle_data = internal_handle_data[0]
      combined_handle_data.shape_and_type[0].shape.CopyFrom(
          combined_shape.as_proto())
      combined_handle_data.shape_and_type[0].dtype = combined_dtype
      handle_data_util.set_handle_data(external, combined_handle_data)


def verify_captures(op_type, branch_graphs):
  """Verify that a branch's tensor is not accessed in another branch fn."""
  # Note: It is technically not possible for lower-branch_index branches to
  # capture tensors from higher-branch_index branches, because of the order of
  # branch graph construction, but we check all for completeness and to
  # guard against potential future changes.
  other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)}
  for i, branch_graph in enumerate(branch_graphs):
    for t in branch_graph.external_captures:
      if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs:
        branch_names = ["true_fn", "false_fn"] if op_type == _COND else [
            "branch {}".format(bi) for bi in range(len(branch_graphs))]
        raise ValueError(
            "Tensor {tname} in {b0name} is accessed from {b1name}.".format(
                tname=t.name,
                b0name=branch_names[other_branch_graphs[t.graph]],
                b1name=branch_names[i]))


class _CondGradFuncGraph(util.CondBranchFuncGraph):
  """FuncGraph for the gradient function of the branch of an If op.

  Handles wrapping and unwrapping intermediate values that are captured by the
  gradient computation in optionals.

  Attributes:
    op_needs_rewrite: True if any intermediates were captured, meaning the
      forward If op needs to be rewritten to output the wrapped intermediates.
  """

  def __init__(self, name, forward_graph):
    super(_CondGradFuncGraph, self).__init__(
        name, collections=ops.get_default_graph()._collections)  # pylint: disable=protected-access
    self.op_needs_rewrite = False
    self._forward_graph = forward_graph
    # Maps from forward intermediate tensor -> the unwrapped captured
    # intermediate.
    self._indirect_captures = {}
    # Maps unwrapped intermediate -> optional-wrapped intermediate in the
    # forward graph.
    self._wrapped_intermediates = collections.OrderedDict()
    # Raw intermediates captured from the forward graph. Populated iff we're in
    # an XLA context.
    self._xla_intermediates = []
    # Maps forward intermediate constant valued tensor's id to the constant
    # created in this graph for that tensor.
    self._captured_constants = {}

  @property
  def wrapped_intermediates(self):
    """The optional-wrapped intermediates captured from the forward graph."""
    return list(self._wrapped_intermediates.values())

  @property
  def xla_intermediates(self):
    """Raw intermediates captured from the forward graph if XLA is enabled."""
    return self._xla_intermediates

  def _capture_helper(self, tensor, name):
    if (tensor.graph is not self._forward_graph or
        any(tensor is t for t in self._forward_graph.inputs) or
        any(tensor is t for t in self._forward_graph.outputs)):
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    tensor_id = ops.tensor_id(tensor)

    # If `tensor` is a graph-building time constant, we create a constant with
    # the same value in the backward graph instead of capturing it.
    if tensor_id in self._captured_constants:
      return self._captured_constants[tensor_id]
    elif constant_op.is_constant(tensor):
      self._captured_constants[tensor_id] = constant_op.constant(
          tensor_util.constant_value(tensor), dtype=tensor.dtype)
      return self._captured_constants[tensor_id]

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so capture intermediates directly.
      # TODO(skyewm,jpienaar): can XLA support optionals?
      if all(tensor is not capture for capture in self.external_captures):
        self.xla_intermediates.append(tensor)
        self.op_needs_rewrite = True
      return super(_CondGradFuncGraph, self)._capture_helper(tensor, name)

    captured_tensor = self._indirect_captures.get(tensor_id)
    if captured_tensor is not None:
      return captured_tensor

    # 'tensor' is an uncaptured intermediate in the forward graph.
    # If it is not a resource, we wrap it in an optional in the forward graph
    # and capture the optional normally. We then unwrap the captured optional
    # value in the gradient graph to get the raw intermediate value.
    # If it is a resource, we trace the resource up to the input in the forward
    # graph and capture that.

    if tensor.dtype == dtypes.resource:
      # Index of the forward graph input corresponding to the resource tensor.
      index = util.resource_input_index(
          tensor.name, [t.name for t in self._forward_graph.inputs],
          {op.name: op.node_def for op in self._forward_graph.get_operations()},
          self._forward_graph._functions)
      # This gets mapped to the corresponding If op input in
      # `_resolve_grad_inputs`.
      captured_tensor = super(_CondGradFuncGraph, self)._capture_helper(
          self._forward_graph.inputs[index], name)
    else:
      if tensor_id not in self._wrapped_intermediates:
        # If the gradient has already been computed for this If op, 'tensor' may
        # already be wrapped.
        for consumer in tensor.consumers():
          if (consumer.type == "OptionalFromValue" and
              any(consumer.outputs[0] is output
                  for output in self._forward_graph.outputs)):
            optional = consumer.outputs[0]
            break
        else:
          # 'tensor' hasn't been wrapped, do it now.
          with self._forward_graph.as_default():
            optional = gen_dataset_ops.optional_from_value([tensor])
          self.op_needs_rewrite = True
        self._wrapped_intermediates[tensor_id] = optional

      optional = self._wrapped_intermediates[tensor_id]
      captured_optional = super(_CondGradFuncGraph,
                                self)._capture_helper(optional, name)
      captured_tensor = gen_dataset_ops.optional_get_value(
          captured_optional, [tensor.dtype], [tensor.shape])[0]

    self._indirect_captures[tensor_id] = captured_tensor
    return captured_tensor


def indexed_case(branch_index,
                 branch_fns,
                 name="indexed_case",
                 lower_using_switch_merge=None):
  """Like cond_v2, except emits a Case op instead of an If."""
  if isinstance(branch_index, int):
    raise TypeError("branch_index must not be a Python int", branch_index)

  with ops.name_scope(name) as scope:
    branch_names = [
        util.unique_fn_name(scope, "branch{}".format(b))
        for b in range(len(branch_fns))
    ]

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    branch_index = ops.convert_to_tensor(branch_index, name="branch_index")

    branch_graphs = []
    for branch_name, branch_fn in zip(branch_names, branch_fns):
      branch_graphs.append(
          func_graph_module.func_graph_from_py_func(
              branch_name,
              branch_fn,
              [],
              {},
              func_graph=util.CondBranchFuncGraph(
                  branch_name,
                  collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
              add_control_dependencies=add_control_dependencies,
              op_return_value=branch_index))

    verify_captures(_CASE, branch_graphs)
    return _build_case(
        branch_index,
        branch_graphs, [g.external_captures for g in branch_graphs],
        name=scope,
        lower_using_switch_merge=lower_using_switch_merge)


@ops.RegisterGradient("Case")
@ops.RegisterGradient("StatelessCase")
def _CaseGrad(op, *grads):  # pylint: disable=invalid-name
  """The gradient of a Case op produced by tf.switch_case."""
  # Get the Case operator (this logic handles the case where op is a MockOp)
  case_op = op.outputs[0].op
  branch_graphs = get_func_graphs(case_op)
  assert branch_graphs
  # Note: op.graph != ops.get_default_graph() when we are computing the gradient
  # of a nested cond.
  for branch_graph in branch_graphs:
    assert branch_graph.outer_graph == case_op.graph

  # Create grad functions that compute the gradient of the branch forward
  # graphs. These functions will capture tensors from the forward pass
  # functions.
  branch_grad_graphs = []
  for branch_graph in branch_graphs:
    branch_grad_graphs.append(
        _create_grad_func(branch_graph, grads,
                          util.unique_grad_fn_name(branch_graph.name)))
  # Replaces output None grads with zeros if at least one branch has non-None
  # grad at that index.
  _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs)

  if any(g.op_needs_rewrite for g in branch_grad_graphs):
    # Modify 'op' to output the intermediates needed by the grad functions. Note
    # that all needed intermediates are wrapped in optionals. Each optional
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    # NOTE(bjp): if there are any active sessions, this modification to `op`
    # may make them unrunnable!

    if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()):
      # XLA does not yet support optionals, so output intermediates directly and
      # make them match via FakeParams, which can be converted to zeros in XLA.
      # TODO(bjp,jpienaar): can XLA support optionals?
      branches_intermediates = [
          branch_grad_graph.xla_intermediates
          for branch_grad_graph in branch_grad_graphs
      ]
      extra_branch_outputs = _make_intermediates_match_xla(
          branch_graphs, branches_intermediates)
    else:
      branch_intermediates = [
          g.wrapped_intermediates for g in branch_grad_graphs
      ]
      # Make outputs match by adding none optionals.
      extra_branch_outputs = _make_intermediates_match(branch_graphs,
                                                       branch_intermediates)

    for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs):
      branch_graph.outputs.extend(extra_outputs)
    # TODO(bjp): indicate it's an internal bug if this fails.
    _check_same_outputs(_CASE, branch_graphs)

    for branch_graph in branch_graphs:
      branch_graph.name += "_rewritten"

    case_op._set_func_list_attr("branches", [
        util.create_new_tf_function(branch_graph)
        for branch_graph in branch_graphs
    ])
    case_op._set_type_list_attr("Tout", branch_graphs[0].output_types)
    case_op._set_shape_list_attr("output_shapes",
                                 branch_graphs[0].output_shapes)
    case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]],
                         [t.shape for t in extra_branch_outputs[0]])

  # Resolve references to forward graph tensors in grad graphs and ensure
  # they are in-scope, i.e., belong to one of outer graphs of the grad graph.
  branches_grad_inputs = [
      _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph,
      branch_grad_graph in zip(branch_graphs, branch_grad_graphs)
  ]

  # This modifies the graphs in branch_grad_graphs.
  _make_output_composite_tensors_match(_CASE, branch_grad_graphs)

  try:
    lowering = case_op._get_attr_bool("_lower_using_switch_merge")
  except errors_impl.NotFoundError:
    lowering = None

  outputs = _build_case(
      case_op.inputs[0],
      branch_grad_graphs,
      branches_grad_inputs,
      name="gradient",
      lower_using_switch_merge=lowering)

  # The predicate has no gradient.
  return [None] + outputs


def _build_case(branch_index,
                branch_graphs,
                branch_inputs,
                name=None,
                lower_using_switch_merge=None):
  """Creates a `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediate values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must
  have the same output types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.
    lower_using_switch_merge: Lower this op using switch merge ops (optional).

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
  _check_same_outputs(_CASE, branch_graphs)

  # Add inputs to branch_graphs to make them match. Note that this modifies the
  # graphs in `branch_graphs`.
  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

  stateful_ops = []
  for bg in branch_graphs:
    stateful_ops.extend([
        op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op)
    ])

  if stateful_ops:
    op_fn = gen_functional_ops.case
  else:
    op_fn = gen_functional_ops.stateless_case

  # Create the Case op.
  with ops.control_dependencies(
      sum((list(bg.control_captures) for bg in branch_graphs), [])):

    def _make_op(inputs):
      case_op, tensors = util.get_op_and_outputs(op_fn(
          branch_index,
          inputs, [t.dtype for t in branch_graphs[0].outputs],
          [util.create_new_tf_function(g) for g in branch_graphs],
          output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
          name=name))
      _copy_handle_data(tensors, *[g.outputs for g in branch_graphs])
      if case_op is not None:
        util.maybe_set_lowering_attr(case_op, lower_using_switch_merge)
        util.maybe_propagate_compile_time_consts_in_xla(case_op)
        _set_read_only_resource_inputs_attr(case_op, branch_graphs)
        # Prevent fetching since the variant outputs can't be fetched directly.
        case_op.graph.prevent_fetching(case_op)

        # Store the branch graphs so they can be reused during the gradient
        # pass.
        for i, bg in enumerate(branch_graphs):
          bg.outer_graph = ops.get_default_graph()
          setattr(case_op, "_branch_graph_{}".format(i), bg)

      return tensors
    tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs)

  # Return identities for each output of the Case op, rather than the output of
  # the Case op directly. This makes pruning work if the output of switch_case()
  # is fetched: the lowering pass converts the Case outputs into IdentityN
  # outputs, which if fetched will cause all ops in the taken branch to be run
  # (since it takes all merge ops as input). After lowering, each output
  # identity op will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors)


def _set_read_only_resource_inputs_attr(op, branch_graphs):
  """Sets the list of resource inputs which are read-only.

  This is used by AutomaticControlDependencies.

  Args:
    op: If or Case Operation.
    branch_graphs: List of branch FuncGraphs.
  """
  # The first entry in `op.inputs` is the predicate which is not passed to
  # branch graphs so len(branch_graph[i].inputs) == len(op.inputs) - 1.
  read_only_indices = set(range(len(op.inputs) - 1))
  for branch_graph in branch_graphs:
    assert len(branch_graph.inputs) == len(op.inputs) - 1, "should never happen"
    if not read_only_indices:
      break
    branch_read_only_indices = acd.get_read_only_resource_input_indices_graph(
        branch_graph)
    read_only_indices = read_only_indices.intersection(branch_read_only_indices)
  # Convert indices in `branch_graphs[i].inputs` to `op.inputs`.
  read_only_indices = [i + 1 for i in read_only_indices]
  ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR,
                        sorted(read_only_indices))
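
# -----------------------------------------------------------------------------
# Illustrative usage sketch: how cond_v2 and indexed_case above are typically
# invoked from graph-building code. The tensor names (`pred`, `x`, `idx`) are
# hypothetical; the block is kept commented out so importing this module is
# unaffected.
#
#   with ops.Graph().as_default():
#     pred = array_ops.placeholder(dtypes.bool, shape=[])
#     x = constant_op.constant(3.0)
#     # Emits a single If op whose branches are the two lambdas; `x` is
#     # captured from the outer graph by both branch functions.
#     cond_out = cond_v2(pred, lambda: x * 2.0, lambda: x + 1.0)
#     # Emits a single Case op; the branch index must be a Tensor, not a
#     # Python int.
#     idx = constant_op.constant(1)
#     case_out = indexed_case(idx, [lambda: x, lambda: -x])
# -----------------------------------------------------------------------------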