1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""cond_v2 and gradient. 16 17This is a version of cond that emits a single If op, as well as the gradient 18function for If ops produced by cond_v2. This will eventually replace the 19current tf.cond implementation once it reaches feature and performance parity. 20""" 21 22import collections 23 24from tensorflow.core.framework import types_pb2 25from tensorflow.python.eager import backprop_util 26from tensorflow.python.framework import auto_control_deps 27from tensorflow.python.framework import auto_control_deps_utils as acd 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import errors_impl 31from tensorflow.python.framework import func_graph as func_graph_module 32from tensorflow.python.framework import indexed_slices 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import tensor_shape 35from tensorflow.python.framework import tensor_util 36from tensorflow.python.framework import type_spec 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import control_flow_util 39from tensorflow.python.ops import control_flow_util_v2 as util 40from tensorflow.python.ops import default_gradient 41from tensorflow.python.ops import gen_dataset_ops 42from tensorflow.python.ops import gen_functional_ops 43from tensorflow.python.ops import gradients_util 44from tensorflow.python.ops import handle_data_util 45from tensorflow.python.ops import math_ops 46from tensorflow.python.util import nest 47 48 49# NOTE(skyewm): TensorFlow uses protected class methods and fields to signify 50# that they aren't part of the official public API. These protected members 51# often need to be used by implementation code however. Rather than litter the 52# code with pylint comments, we ignore protected access violations for 53# readability. 54# pylint: disable=protected-access 55 56_COND = 1 57_CASE = 2 58 59 60def cond_v2(pred, true_fn, false_fn, name="cond"): 61 """Like tf.cond, except emits a single If op.""" 62 if isinstance(pred, bool): 63 raise TypeError("pred must not be a Python bool", pred) 64 65 if not name: 66 name = "cond" 67 68 with ops.name_scope(name) as scope: 69 true_name = util.unique_fn_name(scope, "true") 70 false_name = util.unique_fn_name(scope, "false") 71 72 # Automatic control dependencies are added in defuns, but not in v1 73 # graphs. Propagate that behavior here. 74 add_control_dependencies = ops.get_default_graph()._add_control_dependencies 75 pred = ops.convert_to_tensor(pred) 76 if (tensor_util.is_tf_type(pred) and 77 (pred.shape.dims is None or pred.shape.dims)): 78 pred = array_ops.squeeze_v2(pred) 79 80 true_graph = func_graph_module.func_graph_from_py_func( 81 true_name, 82 true_fn, [], {}, 83 func_graph=util.CondBranchFuncGraph( 84 true_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 85 add_control_dependencies=add_control_dependencies, 86 op_return_value=pred) 87 false_graph = func_graph_module.func_graph_from_py_func( 88 false_name, 89 false_fn, [], {}, 90 func_graph=util.CondBranchFuncGraph( 91 false_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 92 add_control_dependencies=add_control_dependencies, 93 op_return_value=pred) 94 95 verify_captures(_COND, [true_graph, false_graph]) 96 return _build_cond( 97 pred, 98 true_graph, 99 false_graph, 100 true_graph.external_captures, 101 false_graph.external_captures, 102 building_gradient=False, 103 name=scope) 104 105 106@ops.RegisterGradient("StatelessIf") 107@ops.RegisterGradient("If") 108def _IfGrad(op, *grads): # pylint: disable=invalid-name 109 """The gradient of an If op produced by cond_v2.""" 110 # Get the if operator (this logic handles the case where op is a MockOp) 111 if_op = op.outputs[0].op 112 true_graph, false_graph = get_func_graphs(if_op) 113 # Note: op.graph != ops.get_default_graph() when we are computing the gradient 114 # of a nested cond. 115 assert true_graph.outer_graph == if_op.graph 116 assert false_graph.outer_graph == if_op.graph 117 118 # Create grad functions that compute the gradient of the true/false forward 119 # graphs. These functions will capture tensors from the forward pass 120 # functions. 121 true_grad_graph = _create_grad_func( 122 true_graph, grads, util.unique_grad_fn_name(true_graph.name)) 123 false_grad_graph = _create_grad_func( 124 false_graph, grads, util.unique_grad_fn_name(false_graph.name)) 125 126 # Replaces output None grads with zeros if at least one branch has non-None 127 # grad at that index. 128 _create_zeros_for_none_grads([true_graph, false_graph], 129 [true_grad_graph, false_grad_graph]) 130 131 if (true_grad_graph.op_needs_rewrite or false_grad_graph.op_needs_rewrite): 132 # Modify 'op' to output the intermediates needed by the grad functions. Note 133 # that all needed intermediates are wrapped in optionals. Each optional 134 # intermediate output will have a value iff its corresponding branch is 135 # taken. 136 # NOTE(skyewm): if there are any active sessions, this modification to `op` 137 # may make them unrunnable! 138 139 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): 140 # XLA does not yet support optionals, so output intermediates directly and 141 # make them match via FakeParams, which can be converted to zeros in XLA. 142 # TODO(skyewm,jpienaar): can XLA support optionals? 143 true_intermediates = true_grad_graph.xla_intermediates 144 false_intermediates = false_grad_graph.xla_intermediates 145 extra_true_outputs, extra_false_outputs = _make_intermediates_match_xla( 146 [true_graph, false_graph], [true_intermediates, false_intermediates]) 147 else: 148 true_intermediates = true_grad_graph.wrapped_intermediates 149 false_intermediates = false_grad_graph.wrapped_intermediates 150 # Make outputs match by adding none optionals. 151 extra_true_outputs, extra_false_outputs = _make_intermediates_match( 152 [true_graph, false_graph], [true_intermediates, false_intermediates]) 153 154 true_graph.outputs.extend(extra_true_outputs) 155 false_graph.outputs.extend(extra_false_outputs) 156 # TODO(skyewm): indicate it's an internal bug if this fails. 157 _check_same_outputs(_COND, [true_graph, false_graph]) 158 159 true_graph.name += "_rewritten" 160 false_graph.name += "_rewritten" 161 162 if_op._set_func_attr("then_branch", util.create_new_tf_function(true_graph)) 163 if_op._set_func_attr("else_branch", 164 util.create_new_tf_function(false_graph)) 165 if_op._set_type_list_attr("Tout", true_graph.output_types) 166 if_op._set_shape_list_attr("output_shapes", true_graph.output_shapes) 167 if_op._add_outputs( 168 [t.dtype for t in extra_true_outputs], 169 [t.shape for t in extra_true_outputs]) 170 171 # Resolve references to forward graph tensors in grad graphs and ensure 172 # they are in-scope, i.e., belong to one of outer graphs of the grad graph. 173 true_grad_inputs = _resolve_grad_inputs(true_graph, true_grad_graph) 174 false_grad_inputs = _resolve_grad_inputs(false_graph, false_grad_graph) 175 176 # This modifies true_grad_graph and false_grad_graph. 177 _make_output_composite_tensors_match(_COND, 178 [true_grad_graph, false_grad_graph]) 179 180 outputs = _build_cond( 181 if_op.inputs[0], 182 true_grad_graph, 183 false_grad_graph, 184 true_grad_inputs, 185 false_grad_inputs, 186 building_gradient=True, 187 ) 188 189 # The predicate has no gradient. 190 return [None] + outputs 191 192 193def _build_cond(pred, 194 true_graph, 195 false_graph, 196 true_inputs, 197 false_inputs, 198 building_gradient, 199 name=None): 200 """Creates an If op from the specified predicate, branch functions and inputs. 201 202 Note that this modifies true_graph and false_graph to make the inputs match, 203 and to output all intermediates values so they're available for the gradient 204 computation. 205 206 true_graph and false_graph need not have the same input types, but they must 207 have the same output types. 208 209 Args: 210 pred: boolean Tensor 211 true_graph: FuncGraph 212 false_graph: FuncGraph 213 true_inputs: a list of Tensors to be passed to true_graph as input. 214 false_inputs: a list of Tensors to be passed to false_graph as input. 215 building_gradient: Whether this is a gradient If op. 216 name: the name for the If op. 217 218 Returns: 219 A list of Tensors which are the outputs of the If op. Does not include added 220 intermediate outputs. 221 """ 222 _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph]) 223 _check_same_outputs(_COND, [true_graph, false_graph]) 224 225 # Add inputs to true_graph and false_graph to make them match. Note that 226 # this modifies true_graph and false_graph. 227 cond_inputs = _make_inputs_match([true_graph, false_graph], 228 [true_inputs, false_inputs]) 229 # We do not output intermediates of the gradient If op since this is just 230 # for backwards compatibility with existing code. 231 if not building_gradient and util.output_all_intermediates(): 232 # Add all intermediate tensors as function outputs so they're available for 233 # the gradient computation. Since the outputs of the two functions must 234 # match, we wrap all the intermediates in optionals. Each intermediate 235 # output will have a value iff its corresponding branch is taken. 236 237 true_intermediates = _get_intermediates(true_graph) 238 false_intermediates = _get_intermediates(false_graph) 239 240 # Wrap intermediates in optionals. 241 wrapped_true_intermediates = _wrap_intermediates(true_graph, 242 true_intermediates) 243 wrapped_false_intermediates = _wrap_intermediates(false_graph, 244 false_intermediates) 245 246 # Make outputs match by adding none optionals. 247 extra_true_outputs, extra_false_outputs = _make_intermediates_match( # pylint: disable=unbalanced-tuple-unpacking 248 [true_graph, false_graph], 249 [wrapped_true_intermediates, wrapped_false_intermediates]) 250 251 true_graph.outputs.extend(extra_true_outputs) 252 false_graph.outputs.extend(extra_false_outputs) 253 _check_same_outputs(_COND, [true_graph, false_graph]) 254 255 # Create the If op. 256 with ops.control_dependencies( 257 list(true_graph.control_captures) + list(false_graph.control_captures)): 258 true_stateful_ops = [ 259 op for op in true_graph.get_operations() if op._is_stateful 260 ] 261 false_stateful_ops = [ 262 op for op in false_graph.get_operations() if op._is_stateful 263 ] 264 if (true_stateful_ops or false_stateful_ops): 265 op_fn = gen_functional_ops._if 266 else: 267 op_fn = gen_functional_ops.stateless_if 268 269 def _make_op(inputs): 270 if_op, tensors = util.get_op_and_outputs(op_fn( 271 pred, 272 inputs, [t.dtype for t in true_graph.outputs], 273 util.create_new_tf_function(true_graph), 274 util.create_new_tf_function(false_graph), 275 output_shapes=_get_output_shapes(true_graph.outputs, 276 false_graph.outputs), 277 name=name)) 278 _copy_handle_data(tensors, true_graph.outputs, false_graph.outputs) 279 # `if_op` is None if this is a `StatelessIf` op with no outputs. 280 if if_op is not None: 281 # The true and false graphs have already been created, and we need that 282 # to happen before we know which tensors will be captured and so whether 283 # to wrap the cond in a tf.function. Post-hoc mutation of the branch 284 # `outer_graph` properties seems like the only option if we want to 285 # conditionally wrap in a function. 286 true_graph.outer_graph = ops.get_default_graph() 287 false_graph.outer_graph = ops.get_default_graph() 288 if_op._true_graph = true_graph 289 if_op._false_graph = false_graph 290 util.maybe_set_lowering_attr(if_op) 291 util.maybe_propagate_compile_time_consts_in_xla(if_op) 292 _set_read_only_resource_inputs_attr(if_op, [true_graph, false_graph]) 293 # Prevent fetching since the variant outputs can't be fetched directly. 294 if_op.graph.prevent_fetching(if_op) 295 return tensors 296 tensors = util.run_as_function_for_tape_gradients(_make_op, cond_inputs) 297 298 # Return identities for each output of the If op, rather than the output of 299 # the If op directly. This makes pruning work if the output of cond() is 300 # fetched: the lowering pass converts the If outputs into IdentityN outputs, 301 # which if fetched will cause all ops in the taken branch to be run (since 302 # it takes all merge ops as input). After lowering, each output identity op 303 # will end up with only the appropriate merge op as input. 304 # TODO(b/79984175): this doesn't have to be a tuple once we covert to the 305 # correct output structure 306 tensors = [array_ops.identity(t) for t in tensors] 307 308 structured_output_specs = _get_compatible_structured_output_specs(true_graph, 309 false_graph) 310 return _pack_sequence_as(structured_output_specs, tensors) 311 312 313def get_func_graphs(op): 314 """Returns `FuncGraph`s for the input op branches. 315 316 Args: 317 op: The If or Case Operation. 318 319 Returns: 320 A tuple of the `FuncGraph`s of the then_branch and else_branch (all branches 321 for Case). 322 """ 323 324 def _get_func_graph_for_branch(name_attr_list, cached_attr_name=None): 325 """Generates and returns a FuncGraph for the given branch.""" 326 func_graph = None 327 if cached_attr_name is not None: 328 func_graph = getattr(op, cached_attr_name, None) 329 inputs = op.inputs[1:] # First input is pred. 330 if func_graph is None: 331 input_shapes = [t.shape for t in inputs] 332 func_graph = util.get_func_graph(op, input_shapes, name_attr_list.name) 333 for external_t, internal_t in zip(inputs, func_graph.inputs): 334 handle_data_util.copy_handle_data(external_t, internal_t) 335 func_graph.reset_captures(zip(inputs, func_graph.inputs)) 336 # Link the op so that the gradient code can use it. 337 func_graph._forward_cond = op 338 return func_graph 339 340 if op.type in ["If", "StatelessIf"]: 341 return (_get_func_graph_for_branch( 342 op.get_attr("then_branch"), "_true_graph"), 343 _get_func_graph_for_branch( 344 op.get_attr("else_branch"), "_false_graph")) 345 elif op.type in ["Case", "StatelessCase"]: 346 return [_get_func_graph_for_branch(branch_fn, "_branch_graph_{}".format(i)) 347 for i, branch_fn in enumerate(op.get_attr("branches"))] 348 else: 349 raise ValueError("Unsupported op type: {}".format(op.type)) 350 351 352def _get_compatible_structured_output_specs(true_graph, false_graph): 353 """Returns the most specific compatible specs of graph structured outputs.""" 354 return nest.map_structure(_get_compatible_spec, 355 true_graph.structured_outputs, 356 false_graph.structured_outputs) 357 358 359def _get_compatible_spec(value_or_spec1, value_or_spec2): 360 """Returns the most specific compatible spec. 361 362 Args: 363 value_or_spec1: A TypeSpecs or a value that has a defined TypeSpec. 364 value_or_spec2: A TypeSpecs or a value that has a defined TypeSpec. 365 366 Returns: 367 The most specific compatible TypeSpecs of the input. 368 369 Raises: 370 ValueError: If value_or_spec1 is not compatible with value_or_spec2. 371 """ 372 spec1 = _get_spec_for(value_or_spec1) 373 spec2 = _get_spec_for(value_or_spec2) 374 375 # pylint: disable=protected-access 376 common = spec1._without_tensor_names().most_specific_common_supertype( 377 [spec2._without_tensor_names()]) 378 if common is None: 379 raise TypeError(f"No common supertype of {spec1} and {spec2}.") 380 return common 381 382 383def _get_spec_for(value_or_spec): 384 """Returns TypeSpec of a value or itself if it is a TypeSpec already.""" 385 if isinstance(value_or_spec, type_spec.TypeSpec): 386 return value_or_spec 387 return type_spec.type_spec_from_value(value_or_spec) 388 389 390def _grad_fn(func_graph, grads): 391 """The gradient function for each conditional branch. 392 393 This function builds the gradient graph of the corresponding forward-pass 394 conditional branch in `func_graph`. This is done by differentiating 395 func_graph's outputs w.r.t. its inputs. 396 397 Args: 398 func_graph: FuncGraph. The corresponding forward-pass function. 399 grads: The list of input gradient Tensors. 400 401 Returns: 402 The output gradient Tensors. 403 """ 404 # Filter out untrainable function outputs. 405 # NOTE(skyewm): If we don't do this, the untrainable tensors can sometimes 406 # cause _GradientsHelper to raise an exception (e.g. the implementation 407 # doesn't expect 'ys' to contain boolean tensors). 408 assert len(func_graph.outputs) == len(grads) 409 ys = [] 410 grad_ys = [] 411 for y, grad_y in zip(func_graph.outputs, grads): 412 if not backprop_util.IsTrainable(y): 413 continue 414 ys.append(y) 415 grad_ys.append(grad_y) 416 417 # Build the gradient graph. Note that this builds the gradient computation of 418 # func_graph in the current graph, which requires capturing tensors from 419 # func_graph. The captured func_graph tensors are resolved to external tensors 420 # in _resolve_grad_inputs. 421 result = gradients_util._GradientsHelper( 422 ys, func_graph.inputs, grad_ys=grad_ys, 423 src_graph=func_graph) 424 425 return result 426 427 428def _create_grad_func(func_graph, grads, name): 429 """Returns the FuncGraph representation of _grad_fn.""" 430 return func_graph_module.func_graph_from_py_func( 431 name, 432 lambda: _grad_fn(func_graph, grads), [], {}, 433 func_graph=_CondGradFuncGraph(name, func_graph)) 434 435 436def _resolve_grad_inputs(cond_graph, grad_graph): 437 """Returns the tensors to pass as inputs to `grad_graph`. 438 439 The `grad_graph` may have external references to 440 1. Its outer graph containing the input gradients. These references are kept 441 as is. 442 2. Tensors in the forward pass graph. These tensors may not be "live" 443 when the gradient is being computed. We replace such references by their 444 corresponding tensor in `cond_graph.outer_graph`. In the case of nested 445 control flow or functions, the gradient logic handling 446 `grad_graph.outer_graph` will make sure the tensor from 447 `cond_graph.outer_graph` is also correctly captured. 448 449 Args: 450 cond_graph: FuncGraph. The forward-pass function. 451 grad_graph: FuncGraph. The gradients function. 452 453 Returns: 454 A list of inputs tensors to be passed to grad_graph. 455 """ 456 new_inputs = [] 457 458 for t in grad_graph.external_captures: 459 # `t` must either be in `grad_graph.outer_graph` or in the forward 460 # `cond_graph`. 461 if t.graph != grad_graph.outer_graph: 462 assert t.graph == cond_graph 463 # `internal_captures` are not treated as intermediates and hence not added 464 # to If op outputs. So we get the outer tensor corresponding to those 465 # from the list of `external_captures`. 466 for i, output in enumerate(t.graph.outputs): 467 if output is t: 468 t = t.graph._forward_cond.outputs[i] 469 break 470 else: 471 for i, output in enumerate(t.graph.internal_captures): 472 if output is t: 473 t = t.graph.external_captures[i] 474 break 475 else: 476 raise ValueError("Could not find external tensor capture {tensor} in " 477 "captures or outputs".format(tensor=t)) 478 479 # Note: We rely on the capturing logic of the gradient If op graph to 480 # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2 481 # and while_v2 handle this while building their gradient functions. 482 assert t.graph == cond_graph.outer_graph 483 new_inputs.append(t) 484 485 return new_inputs 486 487 488def _get_intermediates(func_graph): 489 """Returns intermediate tensors of `func_graph` for gradient computation.""" 490 intermediates = [] 491 for op in func_graph.get_operations(): 492 for t in op.outputs: 493 if t in func_graph.inputs: continue 494 if t in func_graph.outputs: continue 495 if t.dtype is dtypes.resource: 496 continue 497 # Accumulating mutexes can cause deadlock. 498 if op.type == "MutexLock": 499 continue 500 intermediates.append(t) 501 return intermediates 502 503 504def _make_intermediates_match(branch_graphs, branch_optionals): 505 """Returns new optionals lists that have matching signatures. 506 507 This is done by mirroring each list in the other using none optionals. 508 There is no merging of like optionals. 509 510 Args: 511 branch_graphs: `list` of `FuncGraph`. 512 branch_optionals: `list` of `list`s of optional `Tensor`s from other 513 branch_graphs 514 515 Returns: 516 A `list` of `list`s of `Tensor`s for each branch_graph. Each list has the 517 same number of `Tensor`s, all of which will be optionals of the same 518 shape/type. 519 """ 520 new_branch_optionals = [] 521 # Since the intermediates are optionals with dtype variant, we only need 522 # enough room for the longest list of intermediates. 523 intermediates_size = max(len(o) for o in branch_optionals) 524 for i, branch_graph in enumerate(branch_graphs): 525 other_optionals = _create_none_optionals( 526 branch_graph, intermediates_size - len(branch_optionals[i])) 527 new_branch_optionals.append(branch_optionals[i] + other_optionals) 528 return new_branch_optionals 529 530 531def _make_intermediates_match_xla(branch_graphs, branch_intermediates): 532 """Like _make_intermediates_match but for the XLA case.""" 533 new_branch_intermediates = [] 534 for i, branch_graph in enumerate(branch_graphs): 535 other_fakeparams = _create_fakeparams( 536 branch_graph, 537 sum((bi for bi in branch_intermediates 538 if bi is not branch_intermediates[i]), [])) 539 num_preceding = sum(len(bi) for bi in branch_intermediates[:i]) 540 new_branch_intermediates.append(other_fakeparams[:num_preceding] + 541 branch_intermediates[i] + 542 other_fakeparams[num_preceding:]) 543 return new_branch_intermediates 544 545 546def _make_inputs_match(branch_graphs, branch_inputs): 547 """Modifies branch_graphs so they have the same input signature. 548 549 This method reorders and/or adds parameters to each graph in branch_graphs so 550 they have the same input signature, and updates the 'inputs' and 'captured' 551 fields of each graph accordingly. It uses the input tensors from the outer 552 graph to avoid duplicating shared arguments. 553 554 Args: 555 branch_graphs: a `list` of `FuncGraph` 556 branch_inputs: a `list` of `list`s of `Tensor`s in the outer graph. The 557 inputs for the corresponding graph in `branch_graphs`. 558 559 Returns: 560 A new list of Tensors from the outer graph that are the new inputs for each 561 branch_graph. This is a deduped version of `sum(branch_inputs)`. 562 """ 563 assert len(branch_graphs) == len(branch_inputs) 564 added_inputs = set() 565 new_inputs = [] 566 for branch_in in branch_inputs: 567 for tensor in branch_in: 568 tensor_id = ops.tensor_id(tensor) 569 if tensor_id not in added_inputs: 570 added_inputs.add(tensor_id) 571 new_inputs.append(tensor) 572 573 for branch_graph, branch_in in zip(branch_graphs, branch_inputs): 574 input_ids = [ops.tensor_id(t) for t in branch_in] 575 branch_input_to_param = dict(zip(input_ids, branch_graph.inputs)) 576 input_list = [] 577 for in_t in new_inputs: 578 param = branch_input_to_param.get(ops.tensor_id(in_t)) 579 if param is None: 580 param = _create_dummy_input(branch_graph, in_t) 581 input_list.append(param) 582 583 branch_graph.inputs = input_list 584 585 # Rewrite the FuncGraphs' state to reflect the new inputs. 586 branch_graph.reset_captures(zip(new_inputs, branch_graph.inputs)) 587 588 return new_inputs 589 590 591def _create_zeros_for_none_grads(forward_graphs, grad_graphs): 592 """Creates zeros for None out grads if at least one branch has non-None grad. 593 594 Args: 595 forward_graphs: List of forward FuncGraphs. 596 grad_graphs: List of grad FuncGraphs. 597 """ 598 assert len(forward_graphs) == len(grad_graphs) 599 branch_outputs = [g.structured_outputs for g in grad_graphs] 600 num_outputs_per_branch = [len(outs) for outs in branch_outputs] 601 assert len(set(num_outputs_per_branch)) == 1, num_outputs_per_branch 602 for output_idx, branch_outs in enumerate(zip(*branch_outputs)): 603 if (any(t is None for t in branch_outs) and 604 any(t is not None for t in branch_outs)): 605 for branch_index, t in enumerate(branch_outs): 606 if t is None: 607 with grad_graphs[branch_index].as_default(): 608 zeros = default_gradient.zeros_like( 609 forward_graphs[branch_index].inputs[output_idx]) 610 grad_graphs[branch_index].structured_outputs[output_idx] = zeros 611 612 for grad_graph in grad_graphs: 613 grad_graph.outputs = [ 614 t for t in func_graph_module.flatten(grad_graph.structured_outputs) 615 if t is not None 616 ] 617 618 619def _make_output_composite_tensors_match(op_type, branch_graphs): 620 """Modifies each branch_graph's outputs to have the same output signature. 621 622 Currently the only transformation implemented is turning a Tensor into an 623 equivalent IndexedSlices if the other branch returns an IndexedSlices. 624 Updates branch_graph.{outputs,structured_outputs} for each branch_graph in 625 branch_graphs. 626 627 Args: 628 op_type: _COND or _CASE 629 branch_graphs: `list` of `FuncGraph` 630 631 Raises: 632 TypeError: if a set of outputs cannot be rewritten. 633 """ 634 # Note: since this is only used for gradient graphs, we do not expect the 635 # outputs to be structured (e.g. nested lists), and thus do not need to use 636 # nest.flatten, etc. 637 assert branch_graphs 638 branch_outputs = [g.structured_outputs for g in branch_graphs] 639 outputs_per_branch = list(len(outs) for outs in branch_outputs) 640 assert len(set(outputs_per_branch)) == 1, outputs_per_branch 641 642 for output_idx, branch_outs in enumerate(zip(*branch_outputs)): 643 if len(set(type(out) for out in branch_outs)) == 1: 644 continue 645 if not any( 646 isinstance(out, indexed_slices.IndexedSlices) for out in branch_outs): 647 continue 648 for branch_idx, branch_out in enumerate(branch_outs): 649 if isinstance(branch_out, indexed_slices.IndexedSlices): 650 continue 651 elif isinstance(branch_out, ops.Tensor): 652 with branch_graphs[branch_idx].as_default(): 653 branch_outputs[branch_idx][output_idx] = math_ops._as_indexed_slices( 654 branch_out) 655 else: 656 raise TypeError( 657 "Cannot reconcile {op_name} {output_idx}-th outputs:\n" 658 " outputs from all branches: {outputs}".format( 659 op_name="tf.cond" if op_type == _COND else "tf.switch_case", 660 output_idx=output_idx, 661 outputs=branch_outs)) 662 663 for branch_graph, branch_outs in zip(branch_graphs, branch_outputs): 664 branch_graph.structured_outputs = branch_outs 665 branch_graph.outputs = [ 666 t for t in func_graph_module.flatten(branch_outs) if t is not None 667 ] 668 669 670def _make_indexed_slices_indices_types_match(op_type, branch_graphs): 671 """Match dtype of IndexedSlices.indices in outputs of branch_graphs.""" 672 assert branch_graphs 673 # Indices of `IndexedSlices.indices` tensors in `branch_graphs[i].outputs`. 674 indexed_slice_indices = [] 675 current_index = 0 676 # Note that this still contains Nones. We leave those in so that error 677 # messages contain the correct indices. We handle the Nones later when 678 # updating `current_index`. 679 branch_outputs_flat_with_composites = [ 680 nest.flatten(branch_graph.structured_outputs, expand_composites=False) 681 for branch_graph in branch_graphs 682 ] 683 outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites] 684 assert len(set(outs_per_branch)) == 1, outs_per_branch 685 # Store indices of IndexedSlices.indices in `indexed_slice_indices`. 686 for output_idx, branch_outs in enumerate( 687 zip(*branch_outputs_flat_with_composites)): 688 if len( 689 set( 690 isinstance(out, indexed_slices.IndexedSlices) 691 for out in branch_outs)) != 1: 692 raise TypeError("Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n" 693 " branches returned: {outputs}".format( 694 op_name="cond" if op_type == _COND else "switch_case", 695 output_idx=output_idx, 696 outputs=branch_outs)) 697 if isinstance(branch_outs[0], indexed_slices.IndexedSlices): 698 # indices is the second component of the composite tensor. 699 indexed_slice_indices.append(current_index + 1) 700 if nest.is_nested_or_composite(branch_outs[0]): 701 current_index += len(nest.flatten(branch_outs[0], expand_composites=True)) 702 elif branch_outs[0] is not None: 703 # `FuncGraph.outputs` does not contain Nones so no need to update the 704 # counter in that case. 705 current_index += 1 706 707 if not indexed_slice_indices: 708 return 709 710 # `FuncGraph.outputs` is the flattened `FuncGraph.structured_outputs` minus 711 # the Nones. 712 if current_index != len(branch_graphs[0].outputs): 713 raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n" 714 "Expected: %i\n" 715 "Actual: %i" % 716 (current_index, len(branch_graphs[0].outputs))) 717 718 # Cast indices with mismatching types to int64. 719 for index in indexed_slice_indices: 720 if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64) 721 for bg in branch_graphs): 722 raise TypeError("Type of IndexedSlices.indices must be int32 or int64. " 723 "Found: %s" % 724 str([bg.outputs[index].dtype for bg in branch_graphs])) 725 if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1: 726 for branch_graph in branch_graphs: 727 if branch_graph.outputs[index].dtype == dtypes.int32: 728 with branch_graph.as_default(): 729 branch_graph.outputs[index] = math_ops.cast( 730 branch_graph.outputs[index], dtypes.int64) 731 732 for branch_graph in branch_graphs: 733 branch_graph.structured_outputs = _pack_sequence_as( 734 branch_graph.structured_outputs, branch_graph.outputs) 735 736 737def _pack_sequence_as(structured_outputs, op_outputs): 738 """Packs the outputs of the gradient If/Case op. 739 740 The branch functions may contain None's in the list of `structured_outputs`. 741 `op_outputs` has those outputs missing. So we need to add those Nones to the 742 list of `op_outputs` and then pack it in the same structure as 743 `structured_outputs`. 744 745 Args: 746 structured_outputs: structured_outputs from one of the branch functions. 747 op_outputs: List of output tensors of the op. 748 749 Returns: 750 `op_outputs` packed like `structured_outputs`. 751 """ 752 outputs_with_nones = [] 753 counter = 0 754 for output in nest.flatten(structured_outputs, expand_composites=True): 755 if output is None: 756 outputs_with_nones.append(None) 757 else: 758 outputs_with_nones.append(op_outputs[counter]) 759 counter += 1 760 return func_graph_module.pack_sequence_as(structured_outputs, 761 outputs_with_nones) 762 763 764def _wrap_intermediates(func_graph, intermediates): 765 with func_graph.as_default(): 766 return [gen_dataset_ops.optional_from_value([t]) for t in intermediates] 767 768 769def _create_dummy_input(func_graph, template_tensor): 770 """Creates tensors in func_graph to represent template_tensors. 771 772 Args: 773 func_graph: FuncGraph. 774 template_tensor: a tensor in the outer graph. 775 776 Returns: 777 A tensor in func_graph. 778 """ 779 with func_graph.as_default(): 780 return array_ops.placeholder( 781 template_tensor.dtype, shape=template_tensor.shape) 782 783 784def _create_none_optionals(func_graph, n): 785 """Creates `n` `None` optionals in func_graph. 786 787 Args: 788 func_graph: FuncGraph. 789 n: `int` the number of `None` optionals to make. 790 791 Returns: 792 A list of tensors in func_graph. 793 """ 794 with func_graph.as_default(): 795 return [gen_dataset_ops.optional_none() for _ in range(n)] 796 797 798def _create_fakeparams(func_graph, template_tensors): 799 """Create FakeParams for the XLA case.""" 800 with func_graph.as_default(): 801 return [gen_functional_ops.fake_param(dtype=t.dtype, shape=t.shape) 802 for t in template_tensors] 803 804 805def _check_same_outputs(op_type, graphs): 806 """Raises an error if `graphs` have different outputs.""" 807 808 def error(branch_idx, error_detail): 809 raise TypeError( 810 "{b0_name} and {bn_name} arguments to {op_name} must have the same " 811 "number, type, and overall structure of return values.\n" 812 "\n" 813 "{b0_name} output: {b0_out}\n" 814 "{bn_name} output: {bn_out}\n" 815 "\n" 816 "Error details:\n" 817 "{detail}".format( 818 b0_name="true_fn" if op_type == _COND else "branches[0]", 819 bn_name=("false_fn" if op_type == _COND else 820 "branches[{}]".format(branch_idx)), 821 op_name="tf.cond" if op_type == _COND else "tf.switch_case", 822 b0_out=graphs[0].structured_outputs, 823 bn_out=graphs[branch_idx].structured_outputs, 824 detail=error_detail)) 825 826 for b in range(1, len(graphs)): 827 try: 828 nest.assert_same_structure( 829 graphs[0].structured_outputs, 830 graphs[b].structured_outputs, 831 expand_composites=True) 832 except (ValueError, TypeError) as e: 833 error(b, str(e)) 834 835 op_type_str = "cond" if op_type == _COND else "case" 836 if len(graphs[0].outputs) != len(graphs[b].outputs): 837 raise ValueError("Lengths of branch outputs of {op_type} must match.\n" 838 "len(graphs[0].outputs): {len_0}\n" 839 "len(graphs[{b}].outputs): {len_b}\n".format( 840 op_type=op_type_str, 841 len_0=len(graphs[0].outputs), 842 b=b, 843 len_b=len(graphs[b].outputs))) 844 for b0_out, bn_out in zip(graphs[0].outputs, graphs[b].outputs): 845 if b0_out.dtype != bn_out.dtype: 846 error(b, "%s and %s have different types" % (b0_out, bn_out)) 847 848 849def _get_output_shapes(*branch_graph_outputs): 850 output_shapes = [] 851 for out_by_branch in zip(*branch_graph_outputs): 852 shape = out_by_branch[0].shape 853 for other_out in out_by_branch[1:]: 854 shape = shape.most_specific_compatible_shape(other_out.shape) 855 output_shapes.append(shape) 856 return output_shapes 857 858 859def _copy_handle_data(external_tensors, *branch_graph_outputs): 860 """Combines shapes in handle data and sets metadata on `external_tensors`.""" 861 for tensors in zip(external_tensors, *branch_graph_outputs): 862 external = tensors[0] 863 internal = tensors[1:] 864 internal_handle_data = [] 865 for tensor in internal: 866 handle_data = handle_data_util.get_resource_handle_data(tensor) 867 # NOTE: Assumes handle data has only one ShapeAndType entry. It's 868 # unclear how to combine different lengths across branches. 869 if not handle_data.is_set or len(handle_data.shape_and_type) != 1: 870 break 871 internal_handle_data.append(handle_data) 872 else: # There is handle data, so we need to combine it. 873 combined_shape = tensor_shape.TensorShape(None) 874 combined_dtype = None 875 for handle_data in internal_handle_data: 876 handle_shape = tensor_shape.TensorShape( 877 handle_data.shape_and_type[0].shape) 878 combined_shape = combined_shape.most_specific_compatible_shape( 879 handle_shape) 880 if combined_dtype is None: 881 combined_dtype = handle_data.shape_and_type[0].dtype 882 elif handle_data.shape_and_type[0].dtype != combined_dtype: 883 # Variants from different branches have different dtypes. The 884 # combined variant has no static dtype. 885 combined_dtype = types_pb2.DT_INVALID 886 combined_handle_data = internal_handle_data[0] 887 combined_handle_data.shape_and_type[0].shape.CopyFrom( 888 combined_shape.as_proto()) 889 combined_handle_data.shape_and_type[0].dtype = combined_dtype 890 handle_data_util.set_handle_data(external, combined_handle_data) 891 892 893def verify_captures(op_type, branch_graphs): 894 """Verify that a branch's tensor is not accessed in another branch fn.""" 895 # Note: It is technically not possible for lower-branch_index branches to 896 # capture tensors from higher-branch_index branches, because of the order of 897 # branch graph construction, but we check all for completeness and to 898 # guard against potential future changes. 899 other_branch_graphs = {g: i for i, g in enumerate(branch_graphs)} 900 for i, branch_graph in enumerate(branch_graphs): 901 for t in branch_graph.external_captures: 902 if not isinstance(t, ops.EagerTensor) and t.graph in other_branch_graphs: 903 branch_names = ["true_fn", "false_fn"] if op_type == _COND else [ 904 "branch {}".format(bi) for bi in range(len(branch_graphs))] 905 raise ValueError( 906 "Tensor {tname} in {b0name} is accessed from {b1name}.".format( 907 tname=t.name, 908 b0name=branch_names[other_branch_graphs[t.graph]], 909 b1name=branch_names[i])) 910 911 912class _CondGradFuncGraph(util.CondBranchFuncGraph): 913 """FuncGraph for the gradient function of the branch of an If op. 914 915 Handles wrapping and unwrapping intermediate values that are captured by the 916 gradient computation in optionals. 917 918 Attributes: 919 op_needs_rewrite: True if any intermediates were captured, meaning the 920 forward If op needs to be written to output the wrapped intermediates. 921 """ 922 923 def __init__(self, name, forward_graph): 924 super(_CondGradFuncGraph, self).__init__( 925 name, collections=ops.get_default_graph()._collections) # pylint: disable=protected-access 926 self.op_needs_rewrite = False 927 self._forward_graph = forward_graph 928 # Maps from forward intermediate tensor -> the unwrapped captured 929 # intermediate. 930 self._indirect_captures = {} 931 # Maps unwrapped intermediate -> optional-wrapped intermediate in the 932 # forward graph. 933 self._wrapped_intermediates = collections.OrderedDict() 934 # Raw intermediates captured from the forward graph. Populated iff we're in 935 # an XLA context. 936 self._xla_intermediates = [] 937 # Maps forward intermediate constant valued tensor's id to the constant 938 # created in this graph for that tensor. 939 self._captured_constants = {} 940 941 @property 942 def wrapped_intermediates(self): 943 """The optional-wrapped intermediates captured from the forward graph.""" 944 return list(self._wrapped_intermediates.values()) 945 946 @property 947 def xla_intermediates(self): 948 """Raw intermediates captured from the forward graph if XLA is enabled.""" 949 return self._xla_intermediates 950 951 def _capture_helper(self, tensor, name): 952 if (tensor.graph is not self._forward_graph or 953 any(tensor is t for t in self._forward_graph.inputs) or 954 any(tensor is t for t in self._forward_graph.outputs)): 955 return super(_CondGradFuncGraph, self)._capture_helper(tensor, name) 956 957 tensor_id = ops.tensor_id(tensor) 958 959 # If `tensor` is a graph-building time constant, we create a constant with 960 # the same value in the backward graph instead of capturing it. 961 if tensor_id in self._captured_constants: 962 return self._captured_constants[tensor_id] 963 elif constant_op.is_constant(tensor): 964 self._captured_constants[tensor_id] = constant_op.constant( 965 tensor_util.constant_value(tensor), dtype=tensor.dtype) 966 return self._captured_constants[tensor_id] 967 968 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): 969 # XLA does not yet support optionals, so capture intermediates directly. 970 # TODO(skyewm,jpienaar): can XLA support optionals? 971 if all(tensor is not capture for capture in self.external_captures): 972 self.xla_intermediates.append(tensor) 973 self.op_needs_rewrite = True 974 return super(_CondGradFuncGraph, self)._capture_helper(tensor, name) 975 976 captured_tensor = self._indirect_captures.get(tensor_id) 977 if captured_tensor is not None: 978 return captured_tensor 979 980 # 'tensor' is an uncaptured intermediate in the forward graph. 981 # If it is not a resource, we wrap it in an optional in the forward graph 982 # and capture the optional normally. We then unwrap the captured optional 983 # value in the gradient graph to get the raw intermediate value. 984 # If it is a resource, we trace the resource up to the input in the forward 985 # graph and capture that. 986 987 if tensor.dtype == dtypes.resource: 988 # Index of the forward graph input corresponding to the resource tensor. 989 index = util.resource_input_index( 990 tensor.name, [t.name for t in self._forward_graph.inputs], 991 {op.name: op.node_def for op in self._forward_graph.get_operations()}, 992 self._forward_graph._functions) 993 # This gets mapped to the corresponding If op input in 994 # `_resolve_grad_inputs`. 995 captured_tensor = super(_CondGradFuncGraph, self)._capture_helper( 996 self._forward_graph.inputs[index], name) 997 else: 998 if tensor_id not in self._wrapped_intermediates: 999 # If the gradient has already been computed for this If op, 'tensor' may 1000 # already be wrapped. 1001 for consumer in tensor.consumers(): 1002 if (consumer.type == "OptionalFromValue" and 1003 any(consumer.outputs[0] is output 1004 for output in self._forward_graph.outputs)): 1005 optional = consumer.outputs[0] 1006 break 1007 else: 1008 # 'tensor' hasn't been wrapped, do it now. 1009 with self._forward_graph.as_default(): 1010 optional = gen_dataset_ops.optional_from_value([tensor]) 1011 self.op_needs_rewrite = True 1012 self._wrapped_intermediates[tensor_id] = optional 1013 1014 optional = self._wrapped_intermediates[tensor_id] 1015 captured_optional = super(_CondGradFuncGraph, 1016 self)._capture_helper(optional, name) 1017 captured_tensor = gen_dataset_ops.optional_get_value( 1018 captured_optional, [tensor.dtype], [tensor.shape])[0] 1019 1020 self._indirect_captures[tensor_id] = captured_tensor 1021 return captured_tensor 1022 1023 1024def indexed_case(branch_index, 1025 branch_fns, 1026 name="indexed_case", 1027 lower_using_switch_merge=None): 1028 """Like conv_v2, except emits a Case op instead of an If.""" 1029 if isinstance(branch_index, int): 1030 raise TypeError("branch_index must not be a Python int", branch_index) 1031 1032 with ops.name_scope(name) as scope: 1033 branch_names = [ 1034 util.unique_fn_name(scope, "branch{}".format(b)) 1035 for b in range(len(branch_fns)) 1036 ] 1037 1038 # Automatic control dependencies are added in defuns, but not in v1 1039 # graphs. Propagate that behavior here. 1040 add_control_dependencies = ops.get_default_graph()._add_control_dependencies 1041 branch_index = ops.convert_to_tensor(branch_index, name="branch_index") 1042 1043 branch_graphs = [] 1044 for branch_name, branch_fn in zip(branch_names, branch_fns): 1045 branch_graphs.append( 1046 func_graph_module.func_graph_from_py_func( 1047 branch_name, 1048 branch_fn, 1049 [], 1050 {}, 1051 func_graph=util.CondBranchFuncGraph( 1052 branch_name, 1053 collections=ops.get_default_graph()._collections), # pylint: disable=protected-access 1054 add_control_dependencies=add_control_dependencies, 1055 op_return_value=branch_index)) 1056 1057 verify_captures(_CASE, branch_graphs) 1058 return _build_case( 1059 branch_index, 1060 branch_graphs, [g.external_captures for g in branch_graphs], 1061 name=scope, 1062 lower_using_switch_merge=lower_using_switch_merge) 1063 1064 1065@ops.RegisterGradient("Case") 1066@ops.RegisterGradient("StatelessCase") 1067def _CaseGrad(op, *grads): # pylint: disable=invalid-name 1068 """The gradient of a Case op produced by tf.switch_case.""" 1069 # Get the Case operator (this logic handles the case where op is a MockOp) 1070 case_op = op.outputs[0].op 1071 branch_graphs = get_func_graphs(case_op) 1072 assert branch_graphs 1073 # Note: op.graph != ops.get_default_graph() when we are computing the gradient 1074 # of a nested cond. 1075 for branch_graph in branch_graphs: 1076 assert branch_graph.outer_graph == case_op.graph 1077 1078 # Create grad functions that compute the gradient of the branch forward 1079 # graphs. These functions will capture tensors from the forward pass 1080 # functions. 1081 branch_grad_graphs = [] 1082 for branch_graph in branch_graphs: 1083 branch_grad_graphs.append( 1084 _create_grad_func(branch_graph, grads, 1085 util.unique_grad_fn_name(branch_graph.name))) 1086 # Replaces output None grads with zeros if at least one branch has non-None 1087 # grad at that index. 1088 _create_zeros_for_none_grads(branch_graphs, branch_grad_graphs) 1089 1090 if any(g.op_needs_rewrite for g in branch_grad_graphs): 1091 # Modify 'op' to output the intermediates needed by the grad functions. Note 1092 # that all needed intermediates are wrapped in optionals. Each optional 1093 # intermediate output will have a value iff its corresponding branch is 1094 # taken. 1095 # NOTE(bjp): if there are any active sessions, this modification to `op` 1096 # may make them unrunnable! 1097 1098 if control_flow_util.GraphOrParentsInXlaContext(ops.get_default_graph()): 1099 # XLA does not yet support optionals, so output intermediates directly and 1100 # make them match via FakeParams, which can be converted to zeros in XLA. 1101 # TODO(bjp,jpienaar): can XLA support optionals? 1102 branches_intermediates = [ 1103 branch_grad_graph.xla_intermediates 1104 for branch_grad_graph in branch_grad_graphs 1105 ] 1106 extra_branch_outputs = _make_intermediates_match_xla( 1107 branch_graphs, branches_intermediates) 1108 else: 1109 branch_intermediates = [ 1110 g.wrapped_intermediates for g in branch_grad_graphs 1111 ] 1112 # Make outputs match by adding none optionals. 1113 extra_branch_outputs = _make_intermediates_match(branch_graphs, 1114 branch_intermediates) 1115 1116 for branch_graph, extra_outputs in zip(branch_graphs, extra_branch_outputs): 1117 branch_graph.outputs.extend(extra_outputs) 1118 # TODO(bjp): indicate it's an internal bug if this fails. 1119 _check_same_outputs(_CASE, branch_graphs) 1120 1121 for branch_graph in branch_graphs: 1122 branch_graph.name += "_rewritten" 1123 1124 case_op._set_func_list_attr("branches", [ 1125 util.create_new_tf_function(branch_graph) 1126 for branch_graph in branch_graphs 1127 ]) 1128 case_op._set_type_list_attr("Tout", branch_graphs[0].output_types) 1129 case_op._set_shape_list_attr("output_shapes", 1130 branch_graphs[0].output_shapes) 1131 case_op._add_outputs([t.dtype for t in extra_branch_outputs[0]], 1132 [t.shape for t in extra_branch_outputs[0]]) 1133 1134 # Resolve references to forward graph tensors in grad graphs and ensure 1135 # they are in-scope, i.e., belong to one of outer graphs of the grad graph. 1136 branches_grad_inputs = [ 1137 _resolve_grad_inputs(branch_graph, branch_grad_graph) for branch_graph, 1138 branch_grad_graph in zip(branch_graphs, branch_grad_graphs) 1139 ] 1140 1141 # This modifies the graphs in branch_grad_graphs. 1142 _make_output_composite_tensors_match(_CASE, branch_grad_graphs) 1143 1144 try: 1145 lowering = case_op._get_attr_bool("_lower_using_switch_merge") 1146 except errors_impl.NotFoundError: 1147 lowering = None 1148 1149 outputs = _build_case( 1150 case_op.inputs[0], 1151 branch_grad_graphs, 1152 branches_grad_inputs, 1153 name="gradient", 1154 lower_using_switch_merge=lowering) 1155 1156 # The predicate has no gradient. 1157 return [None] + outputs 1158 1159 1160def _build_case(branch_index, 1161 branch_graphs, 1162 branch_inputs, 1163 name=None, 1164 lower_using_switch_merge=None): 1165 """Creates an `Case` op from `branch_index`, branch graphs and inputs. 1166 1167 Note that this modifies `branch_graphs` to make the inputs match, and to 1168 output all intermediates values so they're available for the gradient 1169 computation. 1170 1171 `branch_graphs` need not have the same input types, but they must 1172 have the same output types. 1173 1174 Args: 1175 branch_index: integer Tensor 1176 branch_graphs: List of FuncGraph 1177 branch_inputs: List of lists of Tensors to be passed to corresponding 1178 branch_graph as input. 1179 name: the name for the Case op. 1180 lower_using_switch_merge: Lower this op using switch merge ops (optional). 1181 1182 Returns: 1183 A list of Tensors which are the outputs of the Case op. Does not include 1184 added intermediate outputs. 1185 """ 1186 _make_indexed_slices_indices_types_match(_CASE, branch_graphs) 1187 _check_same_outputs(_CASE, branch_graphs) 1188 1189 # Add inputs to branch_graphs to make them match. Note that this modifies the 1190 # graphs in `branch_graphs`. 1191 case_inputs = _make_inputs_match(branch_graphs, branch_inputs) 1192 1193 stateful_ops = [] 1194 for bg in branch_graphs: 1195 stateful_ops.extend([ 1196 op for op in bg.get_operations() if auto_control_deps.op_is_stateful(op) 1197 ]) 1198 1199 if stateful_ops: 1200 op_fn = gen_functional_ops.case 1201 else: 1202 op_fn = gen_functional_ops.stateless_case 1203 1204 # Create the Case op. 1205 with ops.control_dependencies( 1206 sum((list(bg.control_captures) for bg in branch_graphs), [])): 1207 1208 def _make_op(inputs): 1209 case_op, tensors = util.get_op_and_outputs(op_fn( 1210 branch_index, 1211 inputs, [t.dtype for t in branch_graphs[0].outputs], 1212 [util.create_new_tf_function(g) for g in branch_graphs], 1213 output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]), 1214 name=name)) 1215 _copy_handle_data(tensors, *[g.outputs for g in branch_graphs]) 1216 if case_op is not None: 1217 util.maybe_set_lowering_attr(case_op, lower_using_switch_merge) 1218 util.maybe_propagate_compile_time_consts_in_xla(case_op) 1219 _set_read_only_resource_inputs_attr(case_op, branch_graphs) 1220 # Prevent fetching since the variant outputs can't be fetched directly. 1221 case_op.graph.prevent_fetching(case_op) 1222 1223 # Store the branch graphs so they can be reused during the gradient 1224 # pass. 1225 for i, bg in enumerate(branch_graphs): 1226 bg.outer_graph = ops.get_default_graph() 1227 setattr(case_op, "_branch_graph_{}".format(i), bg) 1228 1229 return tensors 1230 tensors = util.run_as_function_for_tape_gradients(_make_op, case_inputs) 1231 1232 # Return identities for each output of the Case op, rather than the output of 1233 # the Case op directly. This makes pruning work if the output of switch_case() 1234 # is fetched: the lowering pass converts the Case outputs into IdentityN 1235 # outputs, which if fetched will cause all ops in the taken branch to be run 1236 # (since it takes all merge ops as input). After lowering, each output 1237 # identity op will end up with only the appropriate merge op as input. 1238 # TODO(b/79984175): this doesn't have to be a tuple once we covert to the 1239 # correct output structure 1240 tensors = [array_ops.identity(t) for t in tensors] 1241 1242 return _pack_sequence_as(branch_graphs[0].structured_outputs, tensors) 1243 1244 1245def _set_read_only_resource_inputs_attr(op, branch_graphs): 1246 """Sets the list of resource inputs which are read-only. 1247 1248 This is used by AutomaticControlDependencies. 1249 1250 Args: 1251 op: If or Case Operation. 1252 branch_graphs: List of branch FuncGraphs. 1253 """ 1254 # The first entry in `op.inputs` is the predicate which is not passed to 1255 # branch graphs so len(branch_graph[i].inputs) == len(op.inputs) - 1. 1256 read_only_indices = set(range(len(op.inputs) - 1)) 1257 for branch_graph in branch_graphs: 1258 assert len(branch_graph.inputs) == len(op.inputs) - 1, "should never happen" 1259 if not read_only_indices: 1260 break 1261 branch_read_only_indices = acd.get_read_only_resource_input_indices_graph( 1262 branch_graph) 1263 read_only_indices = read_only_indices.intersection(branch_read_only_indices) 1264 # Convert indices in `branch_graphs[i].inputs` to `op.inputs`. 1265 read_only_indices = [i + 1 for i in read_only_indices] 1266 ops.set_int_list_attr(op, acd.READ_ONLY_RESOURCE_INPUTS_ATTR, 1267 sorted(read_only_indices)) 1268