1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Control Flow Operations. 16 17See the [autograph](https://www.tensorflow.org/guide/autograph) guide. 18""" 19# pylint: disable=g-bad-name 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import abc 25import collections 26import functools 27 28import six 29 30from tensorflow.core.framework import attr_value_pb2 31from tensorflow.core.protobuf import control_flow_pb2 32from tensorflow.python.eager import context 33from tensorflow.python.framework import composite_tensor 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import tensor_shape 39from tensorflow.python.framework import tensor_spec 40from tensorflow.python.framework import tensor_util 41from tensorflow.python.framework import type_spec 42from tensorflow.python.ops import array_ops 43from tensorflow.python.ops import control_flow_util as util 44from tensorflow.python.ops import gen_array_ops 45from tensorflow.python.ops import gen_control_flow_ops 46from tensorflow.python.ops import gen_functional_ops 47from tensorflow.python.ops import gen_logging_ops 48from tensorflow.python.ops import gen_math_ops 49from tensorflow.python.ops import math_ops 50from tensorflow.python.ops import tensor_array_ops 51# go/tf-wildcard-import 52# pylint: disable=wildcard-import,undefined-variable 53from tensorflow.python.ops.gen_control_flow_ops import * 54# pylint: enable=wildcard-import 55from tensorflow.python.platform import tf_logging as logging 56from tensorflow.python.util import compat 57from tensorflow.python.util import deprecation 58from tensorflow.python.util import dispatch 59from tensorflow.python.util import nest 60from tensorflow.python.util import tf_should_use 61from tensorflow.python.util.lazy_loader import LazyLoader 62from tensorflow.python.util.tf_export import tf_export 63 64# This is to avoid a circular dependency: 65# cond_v2 -> gradients_util -> control_flow_ops 66cond_v2 = LazyLoader("cond_v2", globals(), 67 "tensorflow.python.ops.cond_v2") 68 69# This is to avoid circular dependencies: 70# while_v2 -> control_flow_ops 71# while_v2 -> gradients_util -> control_flow_ops 72while_v2 = LazyLoader("while_v2", globals(), 73 "tensorflow.python.ops.while_v2") 74 75# def_function also uses cond 76def_function = LazyLoader( 77 "def_function", globals(), 78 "tensorflow.python.eager.def_function") 79 80 81# We override the 'tuple' for a control flow op, so we keep python's 82# existing 'tuple' for later use in this module. 83_basetuple = tuple 84 85 86def _summarize_eager(tensor, summarize=None): 87 """Returns a summarized string representation of eager `tensor`. 88 89 Args: 90 tensor: EagerTensor to summarize 91 summarize: Include these many first elements of `array` 92 """ 93 # Emulate the behavior of Tensor::SummarizeValue() 94 if summarize is None: 95 summarize = 3 96 elif summarize < 0: 97 summarize = array_ops.size(tensor) 98 99 # reshape((-1,)) is the fastest way to get a flat array view 100 if tensor._rank(): # pylint: disable=protected-access 101 flat = tensor.numpy().reshape((-1,)) 102 lst = [str(x) for x in flat[:summarize]] 103 if len(lst) < flat.size: 104 lst.append("...") 105 else: 106 # tensor.numpy() returns a scalar for zero dimensional arrays 107 if gen_math_ops.not_equal(summarize, 0): 108 lst = [str(tensor.numpy())] 109 else: 110 lst = [] 111 112 return ", ".join(lst) 113 114 115# pylint: disable=protected-access 116 117 118# Assert and Print are special symbols in python, so we must 119# use an upper-case version of them. 120@tf_export("debugging.Assert", "Assert") 121@dispatch.add_dispatch_support 122@tf_should_use.should_use_result 123def Assert(condition, data, summarize=None, name=None): 124 """Asserts that the given condition is true. 125 126 If `condition` evaluates to false, print the list of tensors in `data`. 127 `summarize` determines how many entries of the tensors to print. 128 129 Args: 130 condition: The condition to evaluate. 131 data: The tensors to print out when condition is false. 132 summarize: Print this many entries of each tensor. 133 name: A name for this operation (optional). 134 135 Returns: 136 assert_op: An `Operation` that, when executed, raises a 137 `tf.errors.InvalidArgumentError` if `condition` is not true. 138 @compatibility(eager) 139 returns None 140 @end_compatibility 141 142 Raises: 143 @compatibility(TF1) 144 When in TF V1 mode (that is, outside `tf.function`) Assert needs a control 145 dependency on the output to ensure the assertion executes: 146 147 ```python 148 # Ensure maximum element of x is smaller or equal to 1 149 assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x]) 150 with tf.control_dependencies([assert_op]): 151 ... code using x ... 152 ``` 153 154 @end_compatibility 155 """ 156 if context.executing_eagerly(): 157 if not condition: 158 xs = ops.convert_n_to_tensor(data) 159 data_str = [_summarize_eager(x, summarize) for x in xs] 160 raise errors.InvalidArgumentError( 161 node_def=None, 162 op=None, 163 message="Expected '%s' to be true. Summarized data: %s" % 164 (condition, "\n".join(data_str))) 165 return 166 167 with ops.name_scope(name, "Assert", [condition, data]) as name: 168 xs = ops.convert_n_to_tensor(data) 169 if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs): 170 # As a simple heuristic, we assume that string and int32 are 171 # on host to avoid the need to use cond. If it is not case, 172 # we will pay the price copying the tensor to host memory. 173 return gen_logging_ops._assert(condition, data, summarize, name="Assert") 174 else: 175 condition = ops.convert_to_tensor(condition, name="Condition") 176 177 def true_assert(): 178 return gen_logging_ops._assert( 179 condition, data, summarize, name="Assert") 180 181 guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard") 182 if context.executing_eagerly(): 183 return 184 return guarded_assert.op 185 186 187def _Identity(tensor, name=None): 188 """Return a tensor with the same shape and contents as the input tensor. 189 190 Args: 191 tensor: A Tensor. 192 name: A name for this operation (optional). 193 194 Returns: 195 A Tensor with the same type and value as the input Tensor. 196 """ 197 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) 198 if isinstance(tensor, ops.Tensor): 199 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access 200 return gen_array_ops.ref_identity(tensor, name=name) 201 else: 202 return array_ops.identity(tensor, name=name) 203 elif isinstance(tensor, composite_tensor.CompositeTensor): 204 return nest.map_structure(_Identity, tensor, expand_composites=True) 205 else: 206 raise TypeError("'tensor' must be a Tensor or CompositeTensor. " 207 f"Received: {type(tensor)}.") 208 209 210def _NextIteration(tensor, name=None): 211 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) 212 if isinstance(tensor, ops.Tensor): 213 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access 214 return ref_next_iteration(tensor, name=name) 215 else: 216 return next_iteration(tensor, name=name) 217 elif isinstance(tensor, composite_tensor.CompositeTensor): 218 return nest.map_structure(_NextIteration, tensor, expand_composites=True) 219 else: 220 raise TypeError("'tensor' must be a Tensor or CompositeTensor. " 221 f"Received: {type(tensor)}.") 222 223 224def _Enter(tensor, 225 frame_name, 226 is_constant=False, 227 parallel_iterations=10, 228 use_ref=True, 229 use_input_shape=True, 230 name=None): 231 """Creates or finds a child frame, and makes `tensor` available to it. 232 233 The unique `frame_name` is used by the `Executor` to identify frames. If 234 `is_constant` is true, `tensor` is a constant in the child frame; otherwise 235 it may be changed in the child frame. At most `parallel_iterations` 236 iterations are run in parallel in the child frame. 237 238 Args: 239 tensor: The tensor to be made available to the child frame. 240 frame_name: The name of the child frame. 241 is_constant: If true, the output is constant within the child frame. 242 parallel_iterations: The number of iterations allowed to run in parallel. 243 use_ref: If true, use ref_enter if tensor is of ref type. 244 use_input_shape: If true, set the result's shape based on tensor's shape. 245 name: A name for this operation (optional). 246 247 Returns: 248 The same tensor as `tensor`. 249 """ 250 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) 251 if isinstance(tensor, ops.Tensor): 252 if tensor.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access 253 result = gen_control_flow_ops.ref_enter( 254 tensor, frame_name, is_constant, parallel_iterations, name=name) 255 else: 256 result = gen_control_flow_ops.enter( 257 tensor, frame_name, is_constant, parallel_iterations, name=name) 258 if use_input_shape: 259 result.set_shape(tensor.get_shape()) 260 return result 261 elif isinstance(tensor, composite_tensor.CompositeTensor): 262 263 def enter_component(t): 264 return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref, 265 use_input_shape) 266 267 return nest.map_structure(enter_component, tensor, expand_composites=True) 268 else: 269 raise TypeError("'tensor' must be a Tensor or CompositeTensor. " 270 f"Received: {type(tensor)}.") 271 272 273def exit(tensor, name=None): # pylint: disable=redefined-builtin 274 """Exits the current frame to its parent frame. 275 276 Exit makes its input `tensor` available to the parent frame. 277 278 Args: 279 tensor: The tensor to be made available to the parent frame. 280 name: A name for this operation (optional). 281 282 Returns: 283 The same tensor as `tensor`. 284 """ 285 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) 286 if isinstance(tensor, ops.Tensor): 287 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access 288 return gen_control_flow_ops.ref_exit(tensor, name) 289 else: 290 return gen_control_flow_ops._exit(tensor, name) 291 elif isinstance(tensor, composite_tensor.CompositeTensor): 292 return nest.map_structure(exit, tensor, expand_composites=True) 293 else: 294 raise TypeError("'tensor' must be a Tensor or CompositeTensor. " 295 f"Received: {type(tensor)}.") 296 297 298def switch(data, pred, dtype=None, name=None): 299 """Forwards `data` to an output determined by `pred`. 300 301 If `pred` is false, the `data` input is forwarded to the first output. 302 Otherwise, the data goes to the second output. 303 304 This op handles `Tensor`s and `IndexedSlices`. 305 306 Args: 307 data: The tensor to be forwarded to the appropriate output. 308 pred: A scalar that specifies which output port will receive data. 309 dtype: Optional element type for the returned tensor. If missing, the type 310 is inferred from the type of `value`. 311 name: A name for this operation (optional). 312 313 Returns: 314 `(output_false, output_true)`: If `pred` is true, data will be forwarded 315 to `output_true`, otherwise it goes to `output_false`. 316 """ 317 with ops.name_scope(name, "Switch", [data, pred]) as name: 318 data = ops.internal_convert_to_tensor_or_composite( 319 data, dtype=dtype, name="data", as_ref=True) 320 pred = ops.convert_to_tensor(pred, name="pred") 321 if isinstance(data, ops.Tensor): 322 return gen_control_flow_ops.switch(data, pred, name=name) 323 else: 324 if not isinstance(data, composite_tensor.CompositeTensor): 325 raise TypeError( 326 "'data' must be a Tensor or CompositeTensor. " 327 f"Received: {type(data)}.") 328 tensors = nest.flatten(data, expand_composites=True) 329 mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors] 330 mapped_f, mapped_t = zip(*mapped) 331 return (nest.pack_sequence_as(data, mapped_f, expand_composites=True), 332 nest.pack_sequence_as(data, mapped_t, expand_composites=True)) 333 334 335def _SwitchRefOrTensor(data, pred, name="Switch"): 336 """Forwards `data` to an output determined by `pred`. 337 338 If `pred` is false, the `data` input is forwarded to the first output. 339 Otherwise, the data goes to the second output. 340 341 This op handles `Tensor`s and `IndexedSlices`. 342 343 Args: 344 data: The tensor to be forwarded to the appropriate output. 345 pred: A scalar that specifies which output port will receive data. 346 name: A name for this operation (optional). 347 348 Returns: 349 `(output_false, output_true)`: If `pred` is true, data will be forwarded to 350 `output_true`, otherwise it goes to `output_false`. 351 352 Raises: 353 TypeError: if data is not a Tensor or IndexedSlices 354 """ 355 data = ops.convert_to_tensor_or_composite(data, name="data") 356 # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below 357 # addresses the following scenario. 358 # 359 # Assume you execute Optimizer.apply_gradients() in a branch of a cond(). 360 # 361 # 1. The update op is created inside a `with ops.colocate(var):` block 362 # 363 # 2. Some tensor `data` is captured and a switch is created in a 364 # `with ops.colocate_with(data):` block. 365 # 366 # with ops.colocate_with(var): 367 # with ops.colocate_with(data): 368 # op = ... 369 # 370 # var and data may be pinned to different devices, so we want to ops 371 # created within ops.colocate_with(data) to ignore the existing stack. 372 with ops.colocate_with(data, ignore_existing=True): 373 if isinstance(data, ops.Tensor): 374 if data.dtype._is_ref_dtype: # pylint: disable=protected-access 375 return ref_switch(data, pred, name=name) 376 return switch(data, pred, name=name) 377 378 379def merge(inputs, name=None): 380 """Returns the value of an available element of `inputs`. 381 382 This op tests each of the tensors in `inputs` in turn to determine if any of 383 them is available. If it finds an available tensor, it returns it and its 384 index in `inputs`. 385 386 It is an error if more than one tensor in `inputs` is available. If no tensor 387 in `inputs` is available, the returned tensor and index are not set. 388 389 This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of 390 `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices 391 before merging. 392 393 Args: 394 inputs: The input tensors, at most one of which is available. 395 name: A name for this operation (optional). 396 397 Returns: 398 A tuple containing the chosen input tensor and its index in `inputs`. 399 400 Raises: 401 ValueError: If any of the inputs is None, or inputs are IndexedSlices and 402 some but not all have a dense_shape property. 403 """ 404 if any(inp is None for inp in inputs): 405 raise ValueError("At least one of the merge inputs is None: %s" % inputs) 406 with ops.name_scope(name, "Merge", inputs) as name: 407 inputs = [ 408 ops.internal_convert_to_tensor_or_composite(inp, as_ref=True) 409 for inp in inputs 410 ] 411 if all(isinstance(v, ops.Tensor) for v in inputs): 412 if all(v.dtype._is_ref_dtype for v in inputs): # pylint: disable=protected-access 413 return gen_control_flow_ops.ref_merge(inputs, name) 414 else: 415 return gen_control_flow_ops.merge(inputs, name) 416 else: 417 # If there is a mix of tensors and indexed slices, then convert the 418 # tensors to indexed slices. 419 if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs): 420 inputs = math_ops._as_indexed_slices_list(inputs, optimize=False) 421 422 for v in inputs: 423 if not isinstance(v, composite_tensor.CompositeTensor): 424 raise TypeError("Type %s not supported" % type(v)) 425 426 for v in inputs[1:]: 427 nest.assert_same_structure(inputs[0], v, expand_composites=True) 428 429 flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs] 430 merged_results = [ 431 gen_control_flow_ops.merge(component) 432 for component in zip(*flat_inputs) 433 ] 434 flat_merged = [tensor for (tensor, _) in merged_results] 435 chosen_index = merged_results[0][1] 436 merged_inputs = nest.pack_sequence_as( 437 inputs[0], flat_merged, expand_composites=True) 438 return (merged_inputs, chosen_index) 439 440 441# pylint: enable=protected-access 442 443 444def _convert_tensorarray_to_flow(tensor_or_tensor_array): 445 if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray): 446 return tensor_or_tensor_array.flow 447 else: 448 return tensor_or_tensor_array 449 450 451def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows): 452 if len(tensors_or_tensorarrays) != len(tensors_or_flows): 453 raise ValueError( 454 "Lengths of original Tensor list and new list do not match: %d vs. %d" % 455 (len(tensors_or_tensorarrays), len(tensors_or_flows))) 456 return [ 457 tensor_array_ops.build_ta_with_new_flow(ta, t_or_flow) if isinstance( 458 ta, tensor_array_ops.TensorArray) else t_or_flow 459 for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows) 460 ] 461 462 463def _ShapeLessThanOrEqual(shape1, shape2): 464 if shape2.dims is None: 465 return True 466 if shape1.ndims != shape2.ndims: 467 return False 468 for dim1, dim2 in zip(shape1.dims, shape2.dims): 469 if dim2.value is not None and dim1.value != dim2.value: 470 return False 471 return True 472 473 474def _get_shape_invariant(var, shape=None): 475 """Returns shape invariant(s) for the given variable. 476 477 Args: 478 var: The tensor whose shape is described. 479 shape: The shape invariant for the tensor. If not specified, then a default 480 shape invariant for `var` is returned. 481 482 Returns: 483 `TensorShape` or `list` of `TensorShape`: The shape invariant for `var` (if 484 it is a `Tensor`), or the shape invariants for the components that comprise 485 `var` (if it is a `CompositeTensor`). 486 """ 487 if isinstance(var, composite_tensor.CompositeTensor): 488 # Get a TypeSpec for `var`. 489 if shape is None: 490 spec = var._type_spec # pylint: disable=protected-access 491 else: 492 spec = _shape_invariant_to_type_spec(var, shape) 493 494 tensor_specs = nest.flatten(spec, expand_composites=True) 495 return [tspec.shape for tspec in tensor_specs] 496 497 elif shape is None: 498 return var.shape 499 elif isinstance(shape, tensor_spec.TensorSpec): 500 if var.dtype != shape.dtype: 501 raise TypeError("TensorSpec %r is not compatible with %r" % (shape, var)) 502 return shape.shape 503 elif isinstance(shape, type_spec.TypeSpec): 504 raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var)) 505 else: 506 return shape 507 508 509def _shape_invariant_to_type_spec(var, shape): 510 """Converts a shape invariant to a TypeSpec. 511 512 Args: 513 var: The tensor whose shape is described by the shape invariant. 514 shape: A `TypeSpec` or `TensorShape`. If `shape` is already a `TypeSpec`, 515 then it is simply returned as-is. 516 517 Returns: 518 A `TypeSpec` for `var`, consistent with the given shape. 519 """ 520 if shape is None: 521 return type_spec.type_spec_from_value(var) 522 elif isinstance(shape, type_spec.TypeSpec): 523 if not shape.is_compatible_with(var): 524 raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var)) 525 return shape 526 elif not isinstance(shape, tensor_shape.TensorShape): 527 raise TypeError( 528 "'shape' must be one of TypeSpec, TensorShape or None. " 529 f"Received: {type(shape)}") 530 531 if isinstance(var, ops.Tensor): 532 return tensor_spec.TensorSpec(shape, var.dtype) 533 534 elif isinstance(var, composite_tensor.CompositeTensor): 535 try: 536 return var._shape_invariant_to_type_spec(shape) # pylint: disable=protected-access 537 except NotImplementedError: 538 raise TypeError( 539 "To describe or constrain a %s, use a %s instead of a TensorShape." % 540 (type(var).__name__, type(var._type_spec).__name__)) # pylint: disable=protected-access 541 542 else: 543 raise TypeError("Expected var to be a Tensor or CompositeTensor, got %s" 544 % var) 545 546 547def _SetShapeInvariants(input_vars, enter_vars, shapes): 548 """Set the shapes of the tensors in `enter_vars` to `shapes`. 549 550 Args: 551 input_vars: A list of tensors that are inputs to `enter_vars`. 552 enter_vars: A list of tensors whose shapes will be set. 553 shapes: A (possibly nested) list of shapes. 554 555 Raises: 556 ValueError: If any tensor in `enter_vars` has a less specific shape 557 than its corresponding shape in `shapes`. 558 """ 559 if shapes is None: 560 return 561 flat_shapes = nest.flatten(shapes) 562 if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes): 563 raise ValueError("'shapes' must be a (possibly nested) list of " 564 "TensorShapes.") 565 # Check that the shapes of the inputs are less than the shape invariants, 566 # and set the shapes of `enter_vars` to the shape invariants. 567 for inp, var, shape in zip(input_vars, enter_vars, flat_shapes): 568 if isinstance(var, ops.Tensor): 569 if not _ShapeLessThanOrEqual(inp.get_shape(), shape): 570 raise ValueError( 571 "The shape invariant specified for %s is not compatible with " 572 "the initial shape of the loop variable. It enters the loop " 573 "with shape %s, but the specified shape invariant is %s." % 574 (inp.name, inp.get_shape(), shape)) 575 var.set_shape(shape) 576 else: 577 raise TypeError("'enter_vars' must be a list of Tensors." 578 f"Received: {type(var)}.") 579 580 581def _EnforceShapeInvariant(merge_var, next_var): 582 """Check if the shapes of the loops variables are invariants. 583 584 Args: 585 merge_var: The tensor representing the initial values of the loop 586 variables. 587 next_var: The tensor representing the values of the loop variables 588 after one loop iteration. 589 590 Raises: 591 ValueError: If any tensor in `merge_var` has a more specific shape than 592 its corresponding tensor in `next_var`. 593 """ 594 if isinstance(merge_var, ops.Tensor): 595 m_shape = merge_var.get_shape() 596 n_shape = next_var.get_shape() 597 if not _ShapeLessThanOrEqual(n_shape, m_shape): 598 enter = merge_var.op.inputs[0].op 599 assert util.IsLoopEnter(enter) 600 input_t = enter.inputs[0] 601 raise ValueError( 602 "Input tensor '%s' enters the loop with shape %s, but has shape %s " 603 "after one iteration. To allow the shape to vary across iterations, " 604 "use the `shape_invariants` argument of tf.while_loop to specify a " 605 "less-specific shape." % (input_t.name, input_t.shape, n_shape)) 606 else: 607 raise TypeError("'merge_var' must be a Tensor. " 608 f"Received: {type(merge_var)}.") 609 610 611def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True): 612 """Add NextIteration and back edge from v to m.""" 613 if isinstance(m, ops.Tensor): 614 v = ops.convert_to_tensor(v) 615 v = _NextIteration(v) 616 if enforce_shape_invariant: 617 # Make sure the shapes of loop outputs are correct. We do this before 618 # calling _update_input, which will raise a less-helpful error message if 619 # the types don't match. 620 # TODO(skyewm): call this for other cases below (needs testing) 621 _EnforceShapeInvariant(m, v) 622 m.op._update_input(1, v) # pylint: disable=protected-access 623 elif isinstance(m, composite_tensor.CompositeTensor): 624 # pylint: disable=protected-access 625 def update_component(m_component, v_component): 626 m_component.op._update_input(1, v_component) 627 628 if isinstance(m, ops.IndexedSlices): 629 v = math_ops._as_indexed_slices(v, optimize=False) 630 # pylint: enable=protected-access 631 v = _NextIteration(v) 632 return nest.map_structure(update_component, m, v, expand_composites=True) 633 else: 634 raise TypeError("'m' must be a Tensor or CompositeTensor. " 635 f"Received: {type(m)}.") 636 return v 637 638 639@six.add_metaclass(abc.ABCMeta) 640class ControlFlowContext(object): 641 """The base class for control flow context. 642 643 The usage pattern is a sequence of (Enter, Exit) followed by a final 644 ExitResult. 645 646 We maintain the following state for control flow contexts during graph 647 construction: 648 1. graph has _control_flow_context: the current context used to 649 construct new nodes. Changed by ctxt.Enter() and ctxt.Exit() 650 2. op has _control_flow_context: the context to which the op belongs. 651 Set at the time the op is created. Immutable. 652 3. A ControlFlowContext has _outer_context: the context in which this 653 context is created. Set at the time a context is created. Immutable. 654 4. A ControlFlowContext has _context_stack. 655 Pushed and popped by ctxt.Enter() and ctxt.Exit() 656 """ 657 658 def __init__(self, values_def=None, import_scope=None): 659 self._nested_contexts = [] 660 self._outer_context = ops.get_default_graph()._get_control_flow_context() 661 if self._outer_context: 662 self._outer_context._nested_contexts.append(self) # pylint: disable=protected-access 663 self._context_stack = [] 664 if values_def: 665 self._init_values_from_proto(values_def, import_scope=import_scope) 666 else: 667 # The names of tensors that have been already seen in this context. 668 self._values = set() 669 # The keys are the names of tensors referenced by but external to this 670 # context. Each value is the Tensor that should be used by this context to 671 # access the key value (e.g. a switch output guarding a cond input value). 672 self._external_values = {} 673 674 def _init_values_from_proto(self, values_def, import_scope=None): 675 """Initializes values and external_values from `ValuesDef` protocol buffer. 676 677 Args: 678 values_def: `ValuesDef` protocol buffer. 679 import_scope: Optional `string`. Name scope to add. 680 """ 681 assert isinstance(values_def, control_flow_pb2.ValuesDef) 682 self._values = set( 683 ops.prepend_name_scope(value, import_scope) 684 for value in values_def.values) 685 g = ops.get_default_graph() 686 self._external_values = {} 687 for k, v in values_def.external_values.items(): 688 k = ops.prepend_name_scope(k, import_scope) 689 self._external_values[k] = g.as_graph_element( 690 ops.prepend_name_scope(v, import_scope)) 691 op_names = set([ 692 op.split(":")[0] 693 for op in self._values - set(self._external_values.keys()) 694 ]) 695 for op in op_names: 696 # pylint: disable=protected-access 697 g.as_graph_element(op)._set_control_flow_context(self) 698 # pylint: enable=protected-access 699 700 @property 701 def name(self): 702 return self._name 703 704 @property 705 def outer_context(self): 706 """Return the context containing this context.""" 707 return self._outer_context 708 709 @property 710 def grad_state(self): 711 raise NotImplementedError("Abstract method") 712 713 @property 714 def back_prop(self): 715 raise NotImplementedError("Abstract method") 716 717 @abc.abstractmethod 718 def to_control_flow_context_def(self, context_def, export_scope=None): 719 """Serializes this into `context_def`. 720 721 Args: 722 context_def: a `ControlFlowContextDef` protocol buffer. 723 export_scope: Optional `string`. Name scope to remove. 724 """ 725 raise NotImplementedError("Abstract method") 726 727 def _to_values_def(self, export_scope=None): 728 """Converts the values to a `ValuesDef` protocol buffer. 729 730 Args: 731 export_scope: Optional `string`. Name scope to remove. 732 733 Returns: 734 A `ValuesDef` protocol buffer. 735 """ 736 values_def = control_flow_pb2.ValuesDef() 737 values_def.values.extend( 738 [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)]) 739 for k, v in self._external_values.items(): 740 k = ops.strip_name_scope(k, export_scope) 741 values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope) 742 return values_def 743 744 def AddName(self, name): 745 self._values.add(name) 746 747 # pylint: disable=protected-access 748 def Enter(self): 749 """Enter this control flow context.""" 750 graph = ops.get_default_graph() 751 self._context_stack.append(graph._get_control_flow_context()) 752 graph._set_control_flow_context(self) 753 754 def Exit(self): 755 """Exit this control flow context.""" 756 graph = ops.get_default_graph() 757 last_context = self._context_stack.pop() 758 graph._set_control_flow_context(last_context) 759 760 def EnterGradientColocation(self, op, gradient_uid): 761 """Start building a gradient colocated with an op.""" 762 if self._outer_context: 763 self._outer_context.EnterGradientColocation(op, gradient_uid) 764 765 def ExitGradientColocation(self, op, gradient_uid): 766 """Start building a gradient colocated with an op.""" 767 if self._outer_context: 768 self._outer_context.ExitGradientColocation(op, gradient_uid) 769 770 def ExitResult(self, result): 771 """Make a list of tensors available in the outer context.""" 772 if self._outer_context: 773 def fn(x): 774 self._outer_context.AddName(x.name) 775 return x 776 nest.map_structure(fn, result, expand_composites=True) 777 778 def GetWhileContext(self): 779 """Return the while context containing this context.""" 780 if self._outer_context: 781 return self._outer_context.GetWhileContext() 782 return None 783 784 def _RemoveExternalControlEdges(self, op): 785 """Remove any external control dependency on this op.""" 786 while_ctxt = self.GetWhileContext() 787 # A control input of `op` is internal if it is in the same while 788 # loop context as the enclosing while loop context of self. 789 if while_ctxt is None: 790 internal_control_inputs = op.control_inputs 791 else: 792 internal_control_inputs = [] 793 for x in op.control_inputs: 794 ctxt = util.GetOutputContext(x) 795 if ctxt is not None and ctxt.GetWhileContext() == while_ctxt: 796 internal_control_inputs.append(x) 797 external_control_inputs = [] 798 if len(internal_control_inputs) != len(op.control_inputs): 799 external_control_inputs = list( 800 set(op.control_inputs) - set(internal_control_inputs)) 801 op._remove_all_control_inputs() 802 op._add_control_inputs(internal_control_inputs) 803 return internal_control_inputs, external_control_inputs 804 805 # pylint: enable=protected-access 806 807 def AddInnerOp(self, op): 808 """Notifies a scope about an operator added to an inner scope.""" 809 if self._outer_context: 810 self._outer_context.AddInnerOp(op) 811 812 def GetControlPivot(self): 813 """Returns the pivot node for this context, or None.""" 814 return None 815 816 def IsWhileContext(self): 817 return False 818 819 def IsCondContext(self): 820 return False 821 822 def IsXLAContext(self): 823 return False 824 825 def __str__(self): 826 return self.name 827 828 829class CondContext(ControlFlowContext): 830 """The context for the conditional construct.""" 831 832 def __init__(self, 833 pred=None, 834 pivot=None, 835 branch=None, 836 name="cond_text", 837 context_def=None, 838 import_scope=None): 839 """Creates a `CondContext`. 840 841 Args: 842 pred: The `boolean` tensor for the conditional predicate. 843 pivot: The predicate tensor in this branch. 844 branch: 0 or 1 representing this branch. 845 name: Name of the `CondContext` python object. 846 context_def: Optional `ContextDef` protocol buffer to initialize the 847 `CondContext` object from. 848 import_scope: Optional `string`. Name scope to add. Only used when 849 initialing from protocol buffer. 850 """ 851 self._name = ops.get_default_graph().unique_name(name) 852 853 if context_def: 854 self._init_from_proto(context_def, import_scope=import_scope) 855 else: 856 # Initializes the default fields. 857 ControlFlowContext.__init__(self) 858 self._pred = pred # The boolean tensor for the cond predicate 859 self._pivot = pivot # The predicate tensor in this branch 860 self._branch = branch # 0 or 1 representing this branch 861 862 # Values considered to have been already seen in this context. pred is not 863 # included in this context. 864 self._values.add(pred.name) 865 self._external_values[pred.name] = pred 866 self._values.add(pivot.name) 867 pivot.op._set_control_flow_context(self) # pylint: disable=protected-access 868 869 def _init_from_proto(self, context_def, import_scope=None): 870 """Creates a new `CondContext` from protocol buffer. 871 872 Args: 873 context_def: `CondContextDef` protocol buffer. 874 import_scope: Optional `string`. Name scope to add. 875 """ 876 assert isinstance(context_def, control_flow_pb2.CondContextDef) 877 # Create from context_def. 878 g = ops.get_default_graph() 879 self._name = ops.prepend_name_scope(context_def.context_name, import_scope) 880 self._pred = g.as_graph_element( 881 ops.prepend_name_scope(context_def.pred_name, import_scope)) 882 self._pivot = g.as_graph_element( 883 ops.prepend_name_scope(context_def.pivot_name, import_scope)) 884 self._branch = context_def.branch 885 super(CondContext, self).__init__( 886 values_def=context_def.values_def, import_scope=import_scope) 887 888 @property 889 def pred(self): 890 return self._pred 891 892 @property 893 def pivot(self): 894 return self._pivot 895 896 @property 897 def branch(self): 898 return self._branch 899 900 @property 901 def grad_state(self): 902 if self.GetWhileContext(): 903 return self.GetWhileContext().grad_state 904 return None 905 906 @property 907 def back_prop(self): 908 if self.GetWhileContext(): 909 self.GetWhileContext().back_prop 910 return False 911 912 def GetControlPivot(self): 913 return self._pivot 914 915 def to_proto(self, export_scope=None): 916 """Converts a `CondContext` to a `CondContextDef` protocol buffer. 917 918 Args: 919 export_scope: Optional `string`. Name scope to remove. 920 921 Returns: 922 A `CondContextDef` protocol buffer. 923 """ 924 if (export_scope is None or self.name.startswith(export_scope)): 925 context_def = control_flow_pb2.CondContextDef() 926 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 927 context_def.pred_name = ops.strip_name_scope(self._pred.name, 928 export_scope) 929 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 930 export_scope) 931 context_def.branch = self._branch 932 context_def.values_def.MergeFrom( 933 super(CondContext, self)._to_values_def(export_scope)) 934 for nested in self._nested_contexts: 935 nested_def = context_def.nested_contexts.add() 936 nested.to_control_flow_context_def(nested_def) 937 938 return context_def 939 else: 940 return None 941 942 @staticmethod 943 def from_proto(context_def, import_scope=None): 944 """Returns a `CondContext` object created from `context_def`.""" 945 ret = CondContext(context_def=context_def, import_scope=import_scope) 946 947 ret.Enter() 948 for nested_def in context_def.nested_contexts: 949 from_control_flow_context_def(nested_def, import_scope=import_scope) 950 ret.Exit() 951 return ret 952 953 def to_control_flow_context_def(self, context_def, export_scope=None): 954 context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 955 956 def AddValue(self, val): 957 """Add `val` to the current context and its outer context recursively.""" 958 if val.name in self._values: 959 # Use the real value if it comes from outer context. This is needed in 960 # particular for nested conds. 961 result = self._external_values.get(val.name) 962 result = val if result is None else result 963 else: 964 result = val 965 self._values.add(val.name) 966 if self._outer_context: 967 result = self._outer_context.AddValue(val) 968 self._values.add(result.name) 969 self._external_values[result.name] = result 970 with ops.control_dependencies(None): 971 result = _SwitchRefOrTensor(result, self._pred)[self._branch] 972 if self._outer_context: 973 self._outer_context.AddInnerOp(result.op) 974 975 result.op.graph.prevent_fetching(result.op) 976 # pylint: disable=protected-access 977 result.op._set_control_flow_context(self) 978 # pylint: enable=protected-access 979 980 # Mark Switch output as seen by this context and any outer contexts, 981 # just like what we do for normal op outputs in _AddOpInternal() below. 982 ctxt = self 983 while ctxt is not None: 984 # pylint: disable=protected-access 985 ctxt._values.add(result.name) 986 ctxt = ctxt._outer_context 987 # pylint: enable=protected-access 988 989 self._external_values[val.name] = result 990 return result 991 992 def AddOp(self, op): 993 self._AddOpInternal(op) 994 995 def _AddOpInternal(self, op): 996 """Add `op` to the current context.""" 997 if not op.inputs: 998 # If we're in a while loop, remove any control inputs from outside the 999 # loop. 1000 self._RemoveExternalControlEdges(op) 1001 1002 if not any( 1003 util.OpInContext(input_op, self) for input_op in op.control_inputs): 1004 # pylint: disable=protected-access 1005 op._add_control_input(self._pivot.op) 1006 # pylint: enable=protected-access 1007 else: 1008 # Make each input to 'op' available in this CondContext. If an input is 1009 # already part of this context there's nothing to do, but if it's 1010 # external, AddValue() will handle adding the appropriate Switch node and 1011 # other bookkeeping. 1012 for index in range(len(op.inputs)): 1013 x = op.inputs[index] 1014 if op.type == "Merge" and x.op.type == "NextIteration": 1015 # Edge case: if we're importing a while loop inside this CondContext, 1016 # AddValue() will not correctly handle the NextIteration inputs to 1017 # Merge node. The problem is that the NextIteration should also be 1018 # part of this context, but if we're importing it won't have been 1019 # processed and added to the context yet, so AddValue() will try to 1020 # add a Switch which results in an invalid graph. Instead, we use the 1021 # NextIteration input as-is here, and it will eventually be added to 1022 # the context via AddOp(). 1023 real_x = x 1024 else: 1025 real_x = self.AddValue(x) 1026 if real_x != x: 1027 # pylint: disable=protected-access 1028 op._update_input(index, real_x) 1029 # pylint: enable=protected-access 1030 # Remove any external control dependency on this op. 1031 self._RemoveExternalControlEdges(op) 1032 # pylint: disable=protected-access 1033 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 1034 op._add_control_input(self._pivot.op) 1035 # pylint: enable=protected-access 1036 1037 # Mark op's outputs as seen by this context and any outer contexts. 1038 output_names = [x.name for x in op.outputs] 1039 ctxt = self 1040 while ctxt is not None: 1041 # pylint: disable=protected-access 1042 ctxt._values.update(output_names) 1043 ctxt = ctxt._outer_context 1044 # pylint: enable=protected-access 1045 1046 if self._outer_context or not util.IsLoopExit(op): 1047 op.graph.prevent_fetching(op) 1048 1049 if self._outer_context: 1050 self._outer_context.AddInnerOp(op) 1051 1052 def _ProcessOutputTensor(self, val): 1053 """Process an output tensor of a conditional branch.""" 1054 real_val = val 1055 if val.name not in self._values: 1056 # Handle the special case of lambda: x 1057 self._values.add(val.name) 1058 if self._outer_context: 1059 real_val = self._outer_context.AddValue(val) 1060 self._values.add(real_val.name) 1061 self._external_values[real_val.name] = real_val 1062 real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch] 1063 self._external_values[val.name] = real_val 1064 else: 1065 external_val = self._external_values.get(val.name) 1066 if external_val is not None: 1067 real_val = external_val 1068 return real_val 1069 1070 def _BuildCondTensor(self, v): 1071 if isinstance(v, ops.Operation): 1072 # Use pivot as the proxy for this op. 1073 return with_dependencies([v], self._pivot) 1074 else: 1075 v = nest.map_structure( 1076 _convert_tensorarray_to_flow, v, expand_composites=True) 1077 return self._ProcessOutputTensor(ops.convert_to_tensor(v)) 1078 1079 def BuildCondBranch(self, fn): 1080 """Add the subgraph defined by fn() to the graph.""" 1081 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1082 original_result = fn() 1083 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1084 if len(post_summaries) > len(pre_summaries): 1085 new_summaries = post_summaries[len(pre_summaries):] 1086 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1087 summary_ref[:] = pre_summaries 1088 with ops.control_dependencies(new_summaries): 1089 if original_result is None: 1090 return no_op(), None 1091 elif not isinstance(original_result, ops.Operation): 1092 original_result = nest.map_structure( 1093 array_ops.identity, original_result, expand_composites=True) 1094 if original_result is None: 1095 return None, None 1096 1097 result = nest.map_structure( 1098 self._BuildCondTensor, original_result, expand_composites=True) 1099 if not isinstance(result, (list, _basetuple)): 1100 result = [result] 1101 return original_result, result 1102 1103 def IsCondContext(self): 1104 return True 1105 1106 1107def _UnpackIfSingleton(res): 1108 if isinstance(res, (list, _basetuple)) and len(res) == 1: 1109 return res[0] 1110 else: 1111 return res 1112 1113 1114def _eager_cond_implementation(pred, true_fn, false_fn, strict, name): 1115 """Special cases for `cond` when executing eagerly.""" 1116 pred = ops.convert_to_tensor(pred) 1117 pred_constant_value = tensor_util.constant_value(pred) 1118 if pred_constant_value is None: 1119 # Eager tensors from a parallel device may not have a constant 1120 # value. Running the cond op itself would work, but we don't have logic to 1121 # build cond ops without wrapping in a function first. 1122 if (not isinstance(true_fn, def_function.Function) 1123 or not isinstance(false_fn, def_function.Function)): 1124 raise TypeError("When running tf.cond on a parallel device, 'true_fn' " 1125 "and 'false_fn' must be decorated with `tf.function`.") 1126 @def_function.function 1127 def _parallel_device_cond_wrapper(): 1128 return cond_v2.cond_v2(pred, true_fn, false_fn, name) 1129 functions_run_eagerly = def_function.functions_run_eagerly() 1130 if functions_run_eagerly: 1131 # We need to use tf.function to deal with variable creation inside the 1132 # cond, and skipping it because of run_functions_eagerly would just 1133 # crash immediately. 1134 logging.warning( 1135 "It looks like tf.function behavior was disabled, perhaps using " 1136 "tf.config.run_functions_eagerly. Parallelized tf.cond requires " 1137 "tf.function to work. This primitive will override the disable.") 1138 def_function.run_functions_eagerly(False) 1139 try: 1140 return _parallel_device_cond_wrapper() 1141 finally: 1142 if functions_run_eagerly is not None: 1143 def_function.run_functions_eagerly(functions_run_eagerly) 1144 else: 1145 # For conditions which are eager tensors with a constant value (most of 1146 # them), we only call the relevant branch function and execute it eagerly. 1147 with ops.name_scope(name, "cond", [pred]): 1148 if pred_constant_value: 1149 result = true_fn() 1150 else: 1151 result = false_fn() 1152 if not strict: 1153 result = _UnpackIfSingleton(result) 1154 return result 1155 1156 1157# pylint: disable=redefined-outer-name 1158# pylint: disable=g-doc-args 1159@tf_export(v1=["cond"]) 1160@dispatch.add_dispatch_support 1161@deprecation.deprecated_args( 1162 None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", 1163 "fn1", "fn2") 1164def cond(pred, 1165 true_fn=None, 1166 false_fn=None, 1167 strict=False, 1168 name=None, 1169 fn1=None, 1170 fn2=None): 1171 """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. 1172 1173 `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and 1174 `false_fn` must have the same non-zero number and type of outputs. 1175 1176 **WARNING**: Any Tensors or Operations created outside of `true_fn` and 1177 `false_fn` will be executed regardless of which branch is selected at runtime. 1178 1179 Although this behavior is consistent with the dataflow model of TensorFlow, 1180 it has frequently surprised users who expected a lazier semantics. 1181 Consider the following simple program: 1182 1183 ```python 1184 z = tf.multiply(a, b) 1185 result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)) 1186 ``` 1187 1188 If `x < y`, the `tf.add` operation will be executed and `tf.square` 1189 operation will not be executed. Since `z` is needed for at least one 1190 branch of the `cond`, the `tf.multiply` operation is always executed, 1191 unconditionally. 1192 1193 Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the 1194 call to `cond`, and not at all during `Session.run()`). `cond` 1195 stitches together the graph fragments created during the `true_fn` and 1196 `false_fn` calls with some additional graph nodes to ensure that the right 1197 branch gets executed depending on the value of `pred`. 1198 1199 `tf.cond` supports nested structures as implemented in 1200 `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the 1201 same (possibly nested) value structure of lists, tuples, and/or named tuples. 1202 Singleton lists and tuples form the only exceptions to this: when returned by 1203 `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. 1204 This behavior is disabled by passing `strict=True`. 1205 1206 Args: 1207 pred: A scalar determining whether to return the result of `true_fn` or 1208 `false_fn`. 1209 true_fn: The callable to be performed if pred is true. 1210 false_fn: The callable to be performed if pred is false. 1211 strict: A boolean that enables/disables 'strict' mode; see above. 1212 name: Optional name prefix for the returned tensors. 1213 1214 Returns: 1215 Tensors returned by the call to either `true_fn` or `false_fn`. If the 1216 callables return a singleton list, the element is extracted from the list. 1217 1218 Raises: 1219 TypeError: if `true_fn` or `false_fn` is not callable. 1220 ValueError: if `true_fn` and `false_fn` do not return the same number of 1221 tensors, or return tensors of different types. 1222 1223 Example: 1224 1225 ```python 1226 x = tf.constant(2) 1227 y = tf.constant(5) 1228 def f1(): return tf.multiply(x, 17) 1229 def f2(): return tf.add(y, 23) 1230 r = tf.cond(tf.less(x, y), f1, f2) 1231 # r is set to f1(). 1232 # Operations in f2 (e.g., tf.add) are not executed. 1233 ``` 1234 1235 """ 1236 # We needed to make true_fn/false_fn keyword arguments for 1237 # backwards-compatibility. This check exists so that we can convert back to 1238 # having them be positional arguments. 1239 # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after 1240 # `fn1` and `fn2` are deleted. 1241 if fn1 is not None: 1242 if true_fn is not None: 1243 raise TypeError( 1244 "cond(): 'true_fn' and 'fn1' may not be set simultaneously.") 1245 true_fn = fn1 1246 elif true_fn is None: 1247 raise TypeError("cond(): 'true_fn' argument required") 1248 if fn2 is not None: 1249 if false_fn is not None: 1250 raise TypeError( 1251 "cond(): 'false_fn' and 'fn2' may not be set simultaneously.") 1252 false_fn = fn2 1253 elif false_fn is None: 1254 raise TypeError("cond(): 'false_fn' argument required") 1255 1256 if not callable(true_fn): 1257 raise TypeError("'true_fn' must be callable.") 1258 if not callable(false_fn): 1259 raise TypeError("'false_fn' must be callable.") 1260 1261 if context.executing_eagerly(): 1262 return _eager_cond_implementation(pred, true_fn, false_fn, strict, name) 1263 1264 # Always enable control flow v2 if building a function, regardless of toggle. 1265 if util.EnableControlFlowV2(ops.get_default_graph()): 1266 return cond_v2.cond_v2(pred, true_fn, false_fn, name) 1267 1268 with ops.name_scope(name, "cond", [pred]): 1269 # Add the Switch to the graph. 1270 if isinstance(pred, bool): 1271 raise TypeError("'pred' must not be a Python bool.") 1272 p_2, p_1 = switch(pred, pred) 1273 pivot_1 = array_ops.identity(p_1, name="switch_t") 1274 pivot_2 = array_ops.identity(p_2, name="switch_f") 1275 pred = array_ops.identity(pred, name="pred_id") 1276 # Disable the fetching of tensors that are only on one branch of cond. 1277 for tensor in [p_1, p_2, pivot_1, pivot_2, pred]: 1278 tensor.op.graph.prevent_fetching(tensor.op) 1279 1280 # Build the graph for the true branch in a new context. 1281 context_t = CondContext(pred, pivot_1, branch=1) 1282 try: 1283 context_t.Enter() 1284 orig_res_t, res_t = context_t.BuildCondBranch(true_fn) 1285 if orig_res_t is None: 1286 raise ValueError("'true_fn' must have a return value.") 1287 context_t.ExitResult(res_t) 1288 finally: 1289 context_t.Exit() 1290 1291 # Build the graph for the false branch in a new context. 1292 context_f = CondContext(pred, pivot_2, branch=0) 1293 try: 1294 context_f.Enter() 1295 orig_res_f, res_f = context_f.BuildCondBranch(false_fn) 1296 if orig_res_f is None: 1297 raise ValueError("'false_fn' must have a return value.") 1298 context_f.ExitResult(res_f) 1299 finally: 1300 context_f.Exit() 1301 1302 if not strict: 1303 orig_res_t = _UnpackIfSingleton(orig_res_t) 1304 orig_res_f = _UnpackIfSingleton(orig_res_f) 1305 1306 # Check that the return values of the two branches have the same structure. 1307 try: 1308 nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True) 1309 except (TypeError, ValueError): 1310 nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f) 1311 nest.map_structure(_cast_indexed_slice_indices, res_t, res_f) 1312 try: 1313 nest.assert_same_structure(orig_res_t, orig_res_f, 1314 expand_composites=True) 1315 except TypeError as e: 1316 raise TypeError( 1317 f"Incompatible return types of 'true_fn' and 'false_fn': {e}") 1318 except ValueError as e: 1319 raise ValueError( 1320 f"Incompatible return values of 'true_fn' and 'false_fn': {e}") 1321 1322 # Add the final merge to the graph. 1323 if not res_t: 1324 raise ValueError( 1325 "'true_fn' and 'false_fn' must return at least one result.") 1326 1327 res_t_flat = nest.flatten(res_t, expand_composites=True) 1328 res_f_flat = nest.flatten(res_f, expand_composites=True) 1329 1330 for (x, y) in zip(res_t_flat, res_f_flat): 1331 assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor) 1332 if x.dtype.base_dtype != y.dtype.base_dtype: 1333 raise ValueError( 1334 "Outputs of 'true_fn' and 'false_fn' must have the same type(s). " 1335 f"Received {x.dtype.name} from 'true_fn' " 1336 f"and {y.dtype.name} from 'false_fn'.") 1337 1338 merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)] 1339 merges = _convert_flows_to_tensorarrays( 1340 nest.flatten(orig_res_t, expand_composites=True), merges) 1341 1342 # Only add non-nested conds to the collection. Any nested control flow will 1343 # be encapsulated in the root context. 1344 assert context_t.outer_context == context_f.outer_context 1345 if context_t.outer_context is None: 1346 ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) 1347 ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) 1348 1349 merges = nest.pack_sequence_as( 1350 structure=orig_res_t, flat_sequence=merges, expand_composites=True) 1351 1352 # Singleton lists and tuples are automatically unpacked if strict == False. 1353 if not strict: 1354 merges = _UnpackIfSingleton(merges) 1355 return merges 1356 1357 1358def _cast_indexed_slice_indices(a, b): 1359 """Cast IndexedSlice.indices from int32 to int64 where necessary. 1360 1361 If `a` and `b` are both IndexedSlices, and their indices have different 1362 dtypes, then cast both their dtypes to `int64` (modifies `a` and `b` 1363 in-place). Otherwise, does nothing. 1364 1365 Args: 1366 a: A value, which may be an IndexedSlices. 1367 b: A value, which may be an IndexedSlices. 1368 """ 1369 if (isinstance(a, ops.IndexedSlices) and isinstance(b, ops.IndexedSlices) 1370 and a.indices.dtype != b.indices.dtype): 1371 # pylint: disable=protected-access 1372 a._indices = math_ops.cast(a.indices, dtypes.int64) 1373 b._indices = math_ops.cast(b.indices, dtypes.int64) 1374 1375 1376# pylint: enable=g-doc-args 1377# pylint: enable=redefined-outer-name 1378 1379 1380@tf_export("cond", v1=[]) 1381@dispatch.add_dispatch_support 1382def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None): 1383 """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. 1384 1385 `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and 1386 `false_fn` must have the same non-zero number and type of outputs. 1387 1388 **WARNING**: Any Tensors or Operations created outside of `true_fn` and 1389 `false_fn` will be executed regardless of which branch is selected at runtime. 1390 1391 Although this behavior is consistent with the dataflow model of TensorFlow, 1392 it has frequently surprised users who expected a lazier semantics. 1393 Consider the following simple program: 1394 1395 ```python 1396 z = tf.multiply(a, b) 1397 result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)) 1398 ``` 1399 1400 If `x < y`, the `tf.add` operation will be executed and `tf.square` 1401 operation will not be executed. Since `z` is needed for at least one 1402 branch of the `cond`, the `tf.multiply` operation is always executed, 1403 unconditionally. 1404 1405 Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the 1406 call to `cond`, and not at all during `Session.run()`). `cond` 1407 stitches together the graph fragments created during the `true_fn` and 1408 `false_fn` calls with some additional graph nodes to ensure that the right 1409 branch gets executed depending on the value of `pred`. 1410 1411 `tf.cond` supports nested structures as implemented in 1412 `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the 1413 same (possibly nested) value structure of lists, tuples, and/or named tuples. 1414 Singleton lists and tuples form the only exceptions to this: when returned by 1415 `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. 1416 1417 Note: It is illegal to "directly" use tensors created inside a cond branch 1418 outside it, e.g. by storing a reference to a branch tensor in the python 1419 state. If you need to use a tensor created in a branch function you should 1420 return it as an output of the branch function and use the output from 1421 `tf.cond` instead. 1422 1423 Args: 1424 pred: A scalar determining whether to return the result of `true_fn` or 1425 `false_fn`. 1426 true_fn: The callable to be performed if pred is true. 1427 false_fn: The callable to be performed if pred is false. 1428 name: Optional name prefix for the returned tensors. 1429 1430 Returns: 1431 Tensors returned by the call to either `true_fn` or `false_fn`. If the 1432 callables return a singleton list, the element is extracted from the list. 1433 1434 Raises: 1435 TypeError: if `true_fn` or `false_fn` is not callable. 1436 ValueError: if `true_fn` and `false_fn` do not return the same number of 1437 tensors, or return tensors of different types. 1438 1439 Example: 1440 1441 ```python 1442 x = tf.constant(2) 1443 y = tf.constant(5) 1444 def f1(): return tf.multiply(x, 17) 1445 def f2(): return tf.add(y, 23) 1446 r = tf.cond(tf.less(x, y), f1, f2) 1447 # r is set to f1(). 1448 # Operations in f2 (e.g., tf.add) are not executed. 1449 ``` 1450 1451 """ 1452 return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name) 1453 1454 1455def _resource_safe_shape(t): 1456 """Returns the shape of t or the variable it points to.""" 1457 if t.dtype == dtypes.resource: 1458 while t.op.inputs: 1459 t = t.op.inputs[0] 1460 return tensor_shape.TensorShape(t.op.get_attr("shape")) 1461 return array_ops.shape_internal(t, optimize=False) 1462 1463 1464# TODO(yuanbyu): Consider having a unified notion of context for 1465# not only conditionals and loops but also control dependency and 1466# subgraphs. 1467class WhileContext(ControlFlowContext): 1468 """The context for the loop construct.""" 1469 1470 def __init__(self, 1471 maximum_iterations=None, 1472 parallel_iterations=10, 1473 back_prop=True, 1474 swap_memory=False, 1475 name="while_context", 1476 grad_state=None, 1477 context_def=None, 1478 import_scope=None): 1479 """"Creates a `WhileContext`. 1480 1481 Args: 1482 maximum_iterations: Optional upper bound on number of loop iterations. 1483 parallel_iterations: The number of iterations allowed to run in parallel. 1484 back_prop: Whether backprop is enabled for this while loop. 1485 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 1486 name: Optional name prefix for the returned tensors. 1487 grad_state: The gradient loop state. 1488 context_def: Optional `WhileContextDef` protocol buffer to initialize the 1489 `Whilecontext` python object from. 1490 import_scope: Optional `string`. Name scope to add. Only used when 1491 initialing from protocol buffer. 1492 """ 1493 if context_def: 1494 self._init_from_proto(context_def, import_scope=import_scope) 1495 else: 1496 ControlFlowContext.__init__(self) 1497 self._init_from_args(maximum_iterations, parallel_iterations, back_prop, 1498 swap_memory, name) 1499 # The gradient loop state. 1500 self._grad_state = grad_state 1501 1502 def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop, 1503 swap_memory, name): 1504 """Creates a new `WhileContext` from arguments. 1505 1506 Args: 1507 maximum_iterations: Optional upper bound on number of loop iterations. 1508 parallel_iterations: The number of iterations allowed to run in parallel. 1509 back_prop: Whether backprop is enabled for this while loop. 1510 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 1511 name: Optional name prefix for the returned tensors. 1512 1513 Raises: 1514 ValueError: If `parallel_iterations` has invalid value. 1515 """ 1516 if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0): 1517 raise ValueError("'parallel_iterations' must be a positive integer: " 1518 "%s" % parallel_iterations) 1519 self._name = ops.get_default_graph().unique_name(name) 1520 self._maximum_iterations = maximum_iterations 1521 self._parallel_iterations = parallel_iterations 1522 self._back_prop = back_prop 1523 self._swap_memory = swap_memory 1524 # We use this node to control constants created by the pred lambda. 1525 self._pivot_for_pred = None 1526 # We use this node to control constants created by the body lambda. 1527 self._pivot_for_body = None 1528 # The boolean tensor for loop termination condition. Used in code 1529 # generation for gradient computation 1530 self._pivot = None 1531 # The list of exit tensors for loop variables. 1532 self._loop_exits = [] 1533 # The list of enter tensors for loop variables. 1534 self._loop_enters = [] 1535 self._graph = ops.get_default_graph() 1536 1537 def _init_from_proto(self, context_def, import_scope=None): 1538 """Creates a new `WhileContext` from protocol buffer. 1539 1540 Args: 1541 context_def: `WhileContextDef` protocol buffer. 1542 import_scope: Optional `string`. Name scope to add. 1543 """ 1544 assert isinstance(context_def, control_flow_pb2.WhileContextDef) 1545 # Create from context_def. 1546 g = ops.get_default_graph() 1547 self._name = ops.prepend_name_scope(context_def.context_name, import_scope) 1548 if context_def.maximum_iterations_name: 1549 self._maximum_iterations = g.as_graph_element( 1550 ops.prepend_name_scope(context_def.maximum_iterations_name, 1551 import_scope)) 1552 else: 1553 self._maximum_iterations = None 1554 self._parallel_iterations = context_def.parallel_iterations 1555 self._back_prop = context_def.back_prop 1556 self._swap_memory = context_def.swap_memory 1557 self._pivot_for_pred = g.as_graph_element( 1558 ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope)) 1559 # We use this node to control constants created by the body lambda. 1560 self._pivot_for_body = g.as_graph_element( 1561 ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope)) 1562 # The boolean tensor for loop termination condition. Used in code 1563 # generation for gradient computation. 1564 self._pivot = g.as_graph_element( 1565 ops.prepend_name_scope(context_def.pivot_name, import_scope)) 1566 # The list of exit tensors for loop variables. 1567 self._loop_exits = [ 1568 g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) 1569 for exit_name in context_def.loop_exit_names 1570 ] 1571 # The list of enter tensors for loop variables. 1572 self._loop_enters = [ 1573 g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) 1574 for enter_name in context_def.loop_enter_names 1575 ] 1576 super(WhileContext, self).__init__( 1577 values_def=context_def.values_def, import_scope=import_scope) 1578 1579 # import_scope causes self.name to be different from the original serialized 1580 # context's name. Rewrite "frame_name" attrs with the new name. 1581 if import_scope: 1582 for tensor_name in self._values: 1583 op = g.as_graph_element(tensor_name).op 1584 if util.IsLoopEnter(op): 1585 # pylint: disable=protected-access 1586 op._set_attr("frame_name", 1587 attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) 1588 # pylint: enable=protected-access 1589 self._graph = ops.get_default_graph() 1590 1591 @property 1592 def maximum_iterations(self): 1593 """The maximum number of iterations that will be executed.""" 1594 return self._maximum_iterations 1595 1596 @property 1597 def parallel_iterations(self): 1598 """The number of iterations allowed to run in parallel.""" 1599 return self._parallel_iterations 1600 1601 @property 1602 def back_prop(self): 1603 """True iff backprop is enabled for this while loop.""" 1604 return self._back_prop 1605 1606 @property 1607 def swap_memory(self): 1608 """True iff GPU-CPU memory swap is enabled for this while loop.""" 1609 return self._swap_memory 1610 1611 @property 1612 def pivot(self): 1613 """The boolean tensor representing the loop termination condition.""" 1614 return self._pivot 1615 1616 @property 1617 def loop_enters(self): 1618 """The list of enter tensors for loop variables.""" 1619 return self._loop_enters 1620 1621 @property 1622 def loop_exits(self): 1623 """The list of exit tensors for loop variables.""" 1624 return self._loop_exits 1625 1626 @property 1627 def grad_state(self): 1628 """The gradient loop state.""" 1629 return self._grad_state 1630 1631 def to_proto(self, export_scope=None): 1632 """Converts a `WhileContext` to a `WhileContextDef` protocol buffer. 1633 1634 Args: 1635 export_scope: Optional `string`. Name scope to remove. 1636 1637 Returns: 1638 A `WhileContextDef` protocol buffer. 1639 """ 1640 if (export_scope is None or self.name.startswith(export_scope)): 1641 context_def = control_flow_pb2.WhileContextDef() 1642 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 1643 context_def.parallel_iterations = self._parallel_iterations 1644 if self._maximum_iterations is not None: 1645 context_def.maximum_iterations_name = ops.strip_name_scope( 1646 self._maximum_iterations.name, export_scope) 1647 context_def.back_prop = self._back_prop 1648 context_def.swap_memory = self._swap_memory 1649 context_def.pivot_for_pred_name = ops.strip_name_scope( 1650 self._pivot_for_pred.name, export_scope) 1651 context_def.pivot_for_body_name = ops.strip_name_scope( 1652 self._pivot_for_body.name, export_scope) 1653 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 1654 export_scope) 1655 context_def.loop_exit_names.extend([ 1656 ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits 1657 ]) 1658 context_def.loop_enter_names.extend([ 1659 ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters 1660 ]) 1661 context_def.values_def.MergeFrom( 1662 super(WhileContext, self)._to_values_def(export_scope=export_scope)) 1663 for nested in self._nested_contexts: 1664 nested_def = context_def.nested_contexts.add() 1665 nested.to_control_flow_context_def(nested_def) 1666 1667 return context_def 1668 else: 1669 return None 1670 1671 def to_control_flow_context_def(self, context_def, export_scope=None): 1672 context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 1673 1674 @staticmethod 1675 def from_proto(context_def, import_scope=None): 1676 """Returns a `WhileContext` object created from `context_def`. 1677 1678 Args: 1679 context_def: A `WhileContextDef` protocol buffer. 1680 import_scope: Optional `string`. Name scope to add. 1681 1682 Returns: 1683 A `WhileContext` Python object. 1684 """ 1685 ret = WhileContext(context_def=context_def, import_scope=import_scope) 1686 ret.Enter() 1687 for nested_def in context_def.nested_contexts: 1688 from_control_flow_context_def(nested_def, import_scope=import_scope) 1689 ret.Exit() 1690 return ret 1691 1692 def GetWhileContext(self): 1693 return self 1694 1695 def GetControlPivot(self): 1696 if self._pivot_for_body is not None: 1697 return self._pivot_for_body 1698 return self._pivot_for_pred 1699 1700 def AddValue(self, val): 1701 """Add `val` to the current context and its outer context recursively.""" 1702 result = val 1703 new_value = val.name not in self._values 1704 # Don't treat ops in this context as new values. Usually all known values 1705 # are in self._values, except when we're importing a while loop inside this 1706 # WhileContext. Since there's a cycle in this case, `val` may be part of the 1707 # imported while loop but not yet processed by this context and added to 1708 # self._values in _AddOpInternal. We only want to process external input 1709 # tensors to the while loop here. 1710 new_value &= val.op._control_flow_context is not self # pylint: disable=protected-access 1711 if new_value: 1712 self._values.add(val.name) 1713 1714 # If we are in a grad context and val is from its forward context, 1715 # use GetRealValue(), which adds the logic to save the history of 1716 # val in forward. 1717 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 1718 if grad_ctxt: 1719 grad_ctxt = grad_ctxt.GetWhileContext() 1720 if grad_ctxt.grad_state: 1721 forward_ctxt = util.GetWhileContext(val.op) 1722 if util.IsLoopExit(val.op): 1723 forward_ctxt = forward_ctxt.outer_context 1724 if forward_ctxt: 1725 forward_ctxt = forward_ctxt.GetWhileContext() 1726 if forward_ctxt == grad_ctxt.grad_state.forward_context: 1727 real_val = grad_ctxt.grad_state.GetRealValue(val) 1728 self._external_values[val.name] = real_val 1729 return real_val 1730 1731 if self._outer_context is not None: 1732 result = self._outer_context.AddValue(val) 1733 # Create an Enter to make `result` known to this loop context. 1734 with ops.control_dependencies(None): 1735 enter = _Enter( 1736 result, 1737 self._name, 1738 is_constant=True, 1739 parallel_iterations=self._parallel_iterations) 1740 enter.graph.prevent_feeding(enter) 1741 if self._outer_context: 1742 self._outer_context.AddInnerOp(enter.op) 1743 # Fix the control inputs and control flow context of these enter ops. 1744 self._FixControlInputsAndContext([enter]) 1745 1746 # Add `enter` in this context. 1747 self._values.add(enter.name) 1748 self._external_values[val.name] = enter 1749 result = enter 1750 else: 1751 actual_val = self._external_values.get(val.name) 1752 if actual_val is not None: 1753 result = actual_val 1754 return result 1755 1756 def AddOp(self, op): 1757 """Add `op` to the current context.""" 1758 # For a reduction op, if op is in a grad context and its input is from 1759 # its forward context, moving op to the forward context means we would 1760 # store the tensor after the reduction as opposed to the tensor before 1761 # reduction, and therefore could significantly reduce memory consumption. 1762 # For now, we do this only for a few ops. 1763 # 1764 # If in XLA context, do not move constant ops to forward pass as pushing to 1765 # and popping from a stack removes the constant property of an op and breaks 1766 # XLA compilation, which requires certain inputs to be constant for certain 1767 # ops. 1768 if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}: 1769 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 1770 if grad_ctxt: 1771 grad_ctxt = grad_ctxt.GetWhileContext() 1772 if grad_ctxt.grad_state: 1773 op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op) 1774 if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context: 1775 op_input_ctxt = op.inputs[0].op._get_control_flow_context() 1776 op._set_control_flow_context(op_input_ctxt) 1777 op_input_ctxt._AddOpInternal(op) 1778 return 1779 self._AddOpInternal(op) 1780 1781 def _AddOpInternal(self, op): 1782 """Add `op` to the current context. 1783 1784 We move any external control dependencies of the op to the loop pivot, to 1785 ensure they get executed. 1786 """ 1787 # This is needed to prevent frame mismatch errors where there are Const 1788 # nodes inside tf.function in v1 while_loop and inlining is turned on. 1789 if op.type in ["PartitionedCall", "StatefulPartitionedCall"]: 1790 op._add_control_input(self.GetControlPivot().op) # pylint: disable=protected-access 1791 if not op.inputs: 1792 # Remove any external control dependency on this op 1793 control_inputs, external_inputs = self._RemoveExternalControlEdges(op) 1794 # Add a control edge from the control pivot to this op. 1795 if not control_inputs: 1796 # pylint: disable=protected-access 1797 op._add_control_input(self.GetControlPivot().op) 1798 # pylint: enable=protected-access 1799 for x in op.outputs: 1800 self._values.add(x.name) 1801 else: 1802 for index in range(len(op.inputs)): 1803 x = op.inputs[index] 1804 real_x = self.AddValue(x) 1805 if real_x != x: 1806 op._update_input(index, real_x) # pylint: disable=protected-access 1807 # Remove any external control dependency on this op. 1808 _, external_inputs = self._RemoveExternalControlEdges(op) 1809 # Add a control dependency to prevent loop invariants from 1810 # enabling ops that should not be executed. 1811 self._MaybeAddControlDependency(op) 1812 for x in op.outputs: 1813 self._values.add(x.name) 1814 if external_inputs: 1815 # Use an identity to pull control inputs as data inputs. Note that we 1816 # ignore ops which don't have outputs. TODO(apassos): fix that 1817 with ops.control_dependencies(None): 1818 self.Enter() 1819 external_inputs = [ 1820 array_ops.identity(x.outputs[0]).op 1821 for x in external_inputs 1822 if x.outputs 1823 ] 1824 self.Exit() 1825 op._add_control_inputs(external_inputs) # pylint: disable=protected-access 1826 if self._outer_context or not util.IsLoopExit(op): 1827 op.graph.prevent_fetching(op) 1828 for x in op.outputs: 1829 op.graph.prevent_feeding(x) 1830 1831 if self._outer_context: 1832 self._outer_context.AddInnerOp(op) 1833 1834 def _MaybeAddControlDependency(self, op): 1835 """Add a control input to the op if it only depends on loop invariants.""" 1836 1837 def _IsOpFree(op): 1838 """Determines if `op` needs a control dependency.""" 1839 if op.control_inputs: 1840 return False 1841 # pylint: disable=protected-access 1842 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 1843 return True 1844 # pylint: enable=protected-access 1845 for x in op.inputs: 1846 if not util.IsLoopConstantEnter(x.op): 1847 return False 1848 return True 1849 1850 if _IsOpFree(op): 1851 # pylint: disable=protected-access 1852 op._add_control_input(self.GetControlPivot().op) 1853 # pylint: enable=protected-access 1854 1855 def AddForwardLoopCounter(self, outer_grad_state): 1856 """Adds a loop that counts the number of iterations. 1857 1858 This is added to the forward loop at the time when we start to 1859 create the loop for backprop gradient computation. Called in 1860 the outer context of this forward context. 1861 1862 The pseudocode is: 1863 `n = 0; while (_pivot) { n++; }` 1864 1865 Note that a control dependency is added to `n` to ensure the correct 1866 execution order of stack push ops. 1867 1868 Args: 1869 outer_grad_state: The outer grad state. None if not nested. 1870 1871 Returns: 1872 The number of iterations taken by the forward loop and the loop index. 1873 """ 1874 n = constant_op.constant(0, name="f_count") 1875 if outer_grad_state is not None: 1876 # Force the stack pushes of i-th execution of an inner loop to be ordered 1877 # before the pushes of (i+1)-th execution of the same inner loop. 1878 outer_add_op = outer_grad_state.forward_index.op.inputs[0].op 1879 n.op._add_control_input(outer_add_op) # pylint: disable=protected-access 1880 1881 self.Enter() 1882 self.AddName(n.name) 1883 enter_n = _Enter( 1884 n, 1885 self._name, 1886 is_constant=False, 1887 parallel_iterations=self._parallel_iterations, 1888 name="f_count") 1889 self.loop_enters.append(enter_n) 1890 1891 merge_n = merge([enter_n, enter_n])[0] 1892 switch_n = switch(merge_n, self._pivot) 1893 1894 index = math_ops.add(switch_n[1], 1) 1895 next_n = _NextIteration(index) 1896 merge_n.op._update_input(1, next_n) 1897 1898 total_iterations = exit(switch_n[0], name="f_count") 1899 self.loop_exits.append(total_iterations) 1900 self.ExitResult([total_iterations]) 1901 self.Exit() 1902 return total_iterations, next_n 1903 1904 def AddBackpropLoopCounter(self, count, outer_grad_state): 1905 """Add the backprop loop that controls the iterations. 1906 1907 This is added to the backprop loop. It is used to control the loop 1908 termination of the backprop loop. Called in the outer context of 1909 this grad context. 1910 1911 The pseudocode is: 1912 `n = count; while (n >= 1) { n--; }` 1913 1914 Note that a control dependency is added to `final_zero` to ensure the 1915 correct execution order of stack pop ops. 1916 1917 Args: 1918 count: The number of iterations for backprop. 1919 outer_grad_state: The outer grad state. None if not nested. 1920 1921 Returns: 1922 The loop index. 1923 """ 1924 in_separate_functions = count.graph is not ops.get_default_graph() 1925 if in_separate_functions: 1926 # Brings the count into this graph 1927 count = array_ops.identity(count) 1928 else: 1929 # TODO(apassos) XLA expects this constant to be created outside the loop, 1930 # so doing that for now. 1931 one = constant_op.constant(1, name="b_count") 1932 1933 self.Enter() 1934 self.AddName(count.name) 1935 enter_count = _Enter( 1936 count, 1937 self._name, 1938 is_constant=False, 1939 parallel_iterations=self._parallel_iterations, 1940 name="b_count") 1941 self.loop_enters.append(enter_count) 1942 1943 merge_count = merge([enter_count, enter_count])[0] 1944 self._pivot_for_pred = merge_count 1945 1946 if in_separate_functions: 1947 one = constant_op.constant(1, name="b_count") 1948 pred = math_ops.greater_equal(merge_count, one) 1949 self._pivot = loop_cond(pred, name="b_count") 1950 switch_count = switch(merge_count, self._pivot) 1951 1952 index = math_ops.subtract(switch_count[1], one) 1953 self._pivot_for_body = index 1954 next_count = _NextIteration(index) 1955 merge_count.op._update_input(1, next_count) 1956 1957 final_zero = exit(switch_count[0], name="b_count") 1958 self.loop_exits.append(final_zero) 1959 if outer_grad_state is not None: 1960 # Force the stack pops of i-th execution of an inner loop to be ordered 1961 # before the pops of (i+1)-th execution of the same inner loop. 1962 # pylint: disable=protected-access 1963 outer_grad_state.grad_sync._add_control_input(final_zero.op) 1964 # pylint: enable=protected-access 1965 1966 self.ExitResult([final_zero]) 1967 self.Exit() 1968 return next_count 1969 1970 def AddBackpropAccumulator(self, op, grad): 1971 """Add an accumulation loop for every loop invariant. 1972 1973 This is added to the backprop loop. It is used to accumulate partial 1974 gradients within each loop iteration. Called when in the gradient while 1975 context. 1976 1977 The pseudocode is: 1978 ``` 1979 acc = 0.0; 1980 while (_pivot) { 1981 acc += grad; 1982 } 1983 ``` 1984 1985 Args: 1986 op: The Enter op for a loop invariant. 1987 grad: The partial gradient of an iteration for a loop invariant. 1988 1989 Returns: 1990 The gradient for a loop invariant. 1991 """ 1992 self.Exit() 1993 # Create a zeros tensor with the right shape for acc. If we don't 1994 # know the full shape statically, we will have to get the shape 1995 # dynamically from the forward inference. Getting the shape right 1996 # for the zeros is only needed for the base case when the loop exits 1997 # without running any iterations. 1998 shape = grad.get_shape() 1999 if shape.is_fully_defined(): 2000 if self.outer_context: 2001 self.outer_context.Enter() 2002 acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc") 2003 if self.outer_context: 2004 self.outer_context.Exit() 2005 else: 2006 value = op.inputs[0] 2007 if (isinstance(self.outer_context, WhileContext) and 2008 self.outer_context.grad_state is not None): 2009 # We are in a nested while loop. 2010 forward_ctxt = self.grad_state.forward_context 2011 forward_ctxt.outer_context.Enter() 2012 zeros_shape = array_ops.shape_internal(value, optimize=False) 2013 forward_ctxt.outer_context.Exit() 2014 outer_grad_state = self.grad_state.outer_grad_state 2015 history_zeros_shape = outer_grad_state.AddForwardAccumulator( 2016 zeros_shape) 2017 self.outer_context.Enter() 2018 real_shape = outer_grad_state.AddBackpropAccumulatedValue( 2019 history_zeros_shape, zeros_shape) 2020 acc = array_ops.zeros(real_shape, grad.dtype) 2021 self.outer_context.Exit() 2022 else: 2023 if self.outer_context: 2024 self.outer_context.Enter() 2025 zeros_shape = array_ops.shape_internal(value, optimize=False) 2026 acc = array_ops.zeros(zeros_shape, grad.dtype) 2027 if self.outer_context: 2028 self.outer_context.Exit() 2029 2030 self.Enter() 2031 self.AddName(acc.name) 2032 enter_acc = _Enter( 2033 acc, 2034 self._name, 2035 is_constant=False, 2036 parallel_iterations=self._parallel_iterations, 2037 name="b_acc") 2038 self.loop_enters.append(enter_acc) 2039 2040 merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0] 2041 switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot) 2042 2043 add_acc = math_ops.add(switch_acc_true, grad) 2044 next_acc = _NextIteration(add_acc) 2045 merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access 2046 2047 result_acc = exit(switch_acc_false, name="b_acc") 2048 self.loop_exits.append(result_acc) 2049 self.ExitResult([result_acc]) 2050 return result_acc 2051 2052 def AddBackpropIndexedSlicesAccumulator(self, op, grad): 2053 """This is used for accumulating gradients that are IndexedSlices. 2054 2055 This is essentially the equivalent of AddBackpropAccumulator but optimized 2056 for things like updating embeddings from within a while loop. 2057 2058 Args: 2059 op: The Enter op for a loop invariant. 2060 grad: The partial gradients represented as an IndexedSlices. 2061 2062 Returns: 2063 The accumulated IndexedSlices gradient of the loop invariant. 2064 """ 2065 values = grad.values 2066 indices = grad.indices 2067 dense_shape = grad.dense_shape 2068 2069 self.Exit() 2070 if self.outer_context: 2071 self.outer_context.Enter() 2072 if values.get_shape().is_fully_defined(): 2073 values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] + 2074 values.get_shape().dims[1:]) 2075 if self.outer_context: 2076 self.outer_context.Enter() 2077 values_acc = constant_op.constant( 2078 0, values.dtype, shape=values_shape, name="b_acc") 2079 if self.outer_context: 2080 self.outer_context.Exit() 2081 else: 2082 values_shape = _resource_safe_shape(op.inputs[0])[1:] 2083 values_shape = array_ops.concat([[1], values_shape], 0) 2084 values_acc = array_ops.zeros(values_shape, dtype=values.dtype) 2085 indices_acc = constant_op.constant([0], indices.dtype) 2086 shape_acc = None 2087 if dense_shape is not None: 2088 if dense_shape.get_shape().is_fully_defined(): 2089 if self.outer_context: 2090 self.outer_context.Enter() 2091 shape_acc = constant_op.constant( 2092 0, dense_shape.dtype, shape=dense_shape.get_shape()) 2093 if self.outer_context: 2094 self.outer_context.Exit() 2095 else: 2096 shape_acc = array_ops.zeros_like( 2097 array_ops.shape_internal( 2098 op.inputs[0], optimize=False, out_type=dense_shape.dtype), 2099 optimize=False) 2100 2101 if self.outer_context: 2102 self.outer_context.Exit() 2103 2104 self.Enter() 2105 self.AddName(values_acc.name) 2106 self.AddName(indices_acc.name) 2107 init_acc = [indices_acc, values_acc] 2108 if shape_acc is not None: 2109 self.AddName(shape_acc.name) 2110 init_acc.append(shape_acc) 2111 2112 # Set use_input_shape=False since the accumulator tensors will grow in 2113 # size. If use_input_shape=True, the _update_input call below will result in 2114 # incompatible shapes. 2115 enter_acc = [ 2116 _Enter( 2117 x, 2118 self._name, 2119 is_constant=False, 2120 parallel_iterations=self._parallel_iterations, 2121 use_input_shape=False, 2122 name="b_acc") for x in init_acc 2123 ] 2124 # Manually set appropriate partial shapes. 2125 enter_acc[0].set_shape([None]) 2126 if values_acc.shape.dims is not None: 2127 enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:]) 2128 self.loop_enters.extend(enter_acc) 2129 2130 merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc] 2131 switch_acc = [switch(x, self._pivot) for x in merge_acc] 2132 2133 # The actual accumulation. 2134 acc_indexed_slices = [ 2135 array_ops.concat([xa[1], xv], 0) 2136 for xa, xv in zip(switch_acc[:2], [indices, values]) 2137 ] 2138 if shape_acc is not None: 2139 # For the shape we just keep the maximum 2140 acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1])) 2141 2142 next_acc = [_NextIteration(x) for x in acc_indexed_slices] 2143 for xm, xn in zip(merge_acc, next_acc): 2144 xm.op._update_input(1, xn) # pylint: disable=protected-access 2145 2146 exit_acc = [exit(x[0], name="b_acc") for x in switch_acc] 2147 self.loop_exits.extend(exit_acc) 2148 2149 self.ExitResult(exit_acc) 2150 return ops.IndexedSlices( 2151 indices=exit_acc[0], 2152 values=exit_acc[1], 2153 dense_shape=exit_acc[2] if shape_acc is not None else None) 2154 2155 def _InitializeValues(self, values): 2156 """Makes the values known to this context.""" 2157 self._values = set() 2158 for x in values: 2159 if isinstance(x, ops.Tensor): 2160 self._values.add(x.name) 2161 else: 2162 raise TypeError("'values' must be a list of Tensors. " 2163 f"Received: {type(x)}.") 2164 2165 def _BuildLoop(self, pred, body, original_loop_vars, loop_vars, 2166 shape_invariants): 2167 """Core: Add the loop termination condition and body to the graph.""" 2168 flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True) 2169 2170 # Let the context know the loop variables so the loop variables 2171 # would be added in the outer contexts properly. 2172 self._InitializeValues(loop_vars) 2173 real_vars = loop_vars 2174 if self._outer_context: 2175 real_vars = [self._outer_context.AddValue(x) for x in loop_vars] 2176 with ops.control_dependencies(None): 2177 enter_vars = [ 2178 _Enter( 2179 x, 2180 self._name, 2181 is_constant=False, 2182 parallel_iterations=self._parallel_iterations, 2183 use_input_shape=(shape_invariants is None)) for x in real_vars 2184 ] 2185 for x in enter_vars: 2186 x.graph.prevent_feeding(x) 2187 if self._outer_context: 2188 self._outer_context.AddInnerOp(x.op) 2189 2190 # Finds the closest enclosing non-None control pivot. 2191 outer_context = self._outer_context 2192 control_pivot = None 2193 while outer_context is not None and control_pivot is None: 2194 control_pivot = outer_context.GetControlPivot() 2195 # pylint: disable=protected-access 2196 outer_context = outer_context._outer_context 2197 # pylint: enable=protected-access 2198 2199 if control_pivot is not None: 2200 for var in enter_vars: 2201 if util.IsLoopConstantEnter(var.op.inputs[0].op): 2202 # pylint: disable=protected-access 2203 var.op._add_control_input(control_pivot.op) 2204 # pylint: enable=protected-access 2205 _SetShapeInvariants(real_vars, enter_vars, shape_invariants) 2206 2207 # Fix the control inputs and control flow context of these enter ops. 2208 self._FixControlInputsAndContext(enter_vars) 2209 self._InitializeValues(enter_vars) 2210 self._loop_enters = enter_vars 2211 2212 merge_vars = [merge([x, x])[0] for x in enter_vars] 2213 self._pivot_for_pred = merge_vars[0] 2214 2215 # Build the graph for pred. 2216 merge_vars_with_tensor_arrays = ( 2217 _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars)) 2218 packed_vars = nest.pack_sequence_as( 2219 structure=original_loop_vars, 2220 flat_sequence=merge_vars_with_tensor_arrays, 2221 expand_composites=True) 2222 c = ops.convert_to_tensor(pred(*packed_vars)) 2223 self._pivot = loop_cond(c, name="LoopCond") 2224 switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars] 2225 2226 # Build the graph for body. 2227 vars_for_body = [_Identity(x[1]) for x in switch_vars] 2228 self._pivot_for_body = vars_for_body[0] 2229 # Convert TensorArray flow variables inside the context back into 2230 # their associated TensorArrays for calling the body. 2231 vars_for_body_with_tensor_arrays = ( 2232 _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body)) 2233 packed_vars_for_body = nest.pack_sequence_as( 2234 structure=original_loop_vars, 2235 flat_sequence=vars_for_body_with_tensor_arrays, 2236 expand_composites=True) 2237 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2238 body_result = body(*packed_vars_for_body) 2239 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2240 if not nest.is_sequence_or_composite(body_result): 2241 body_result = [body_result] 2242 if len(post_summaries) > len(pre_summaries): 2243 new_summaries = post_summaries[len(pre_summaries):] 2244 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2245 summary_ref[:] = pre_summaries 2246 with ops.control_dependencies(new_summaries): 2247 2248 def map_fn(x): 2249 # TODO(apassos) figure out how to trigger with tensor arrays as well 2250 if isinstance(x, tensor_array_ops.TensorArray): 2251 return x 2252 return array_ops.identity(x) 2253 2254 body_result = nest.map_structure( 2255 map_fn, body_result, expand_composites=True) 2256 2257 # Compare the structure types of input and output of body. 2258 # For backwards compatibility, the first layer is forced to a list 2259 # during this comparison, because inputs are typically lists and 2260 # outputs of the body are typically tuples. 2261 nest.assert_same_structure( 2262 list(packed_vars_for_body), list(body_result), expand_composites=True) 2263 2264 # Store body_result to keep track of TensorArrays returned by body 2265 original_body_result = body_result 2266 # Convert TensorArrays returned by body into their flow variables 2267 result = nest.map_structure( 2268 _convert_tensorarray_to_flow, 2269 nest.flatten(body_result, expand_composites=True), 2270 expand_composites=True) 2271 result = ops.convert_n_to_tensor_or_composite(result) 2272 2273 # Add NextIteration and the back edges to complete the loop. 2274 if len(merge_vars) != len(result): 2275 raise ValueError("Number of inputs and outputs of 'body' must match " 2276 f"'loop_vars'. Got {len(merge_vars)} for the number of " 2277 f"inputs/outputs, and {len(result)} for 'loop_vars'.") 2278 next_vars = [] 2279 for m, v in zip(merge_vars, result): 2280 next_vars.append(_AddNextAndBackEdge(m, v)) 2281 2282 # Add the exit ops. 2283 exit_vars = [exit(x[0]) for x in switch_vars] 2284 self._loop_exits = exit_vars 2285 2286 # Exit the loop. 2287 self.ExitResult(exit_vars) 2288 2289 return original_body_result, exit_vars 2290 2291 def BuildLoop(self, pred, body, loop_vars, shape_invariants, 2292 return_same_structure): 2293 """Add the loop termination condition and body to the graph.""" 2294 2295 # Keep original_loop_vars to identify which are TensorArrays 2296 original_loop_vars = loop_vars 2297 # Convert TensorArrays to their flow variables 2298 loop_vars = nest.map_structure( 2299 _convert_tensorarray_to_flow, 2300 nest.flatten(loop_vars, expand_composites=False), 2301 expand_composites=True) 2302 loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars) 2303 if shape_invariants is None: 2304 shape_invariants = nest.map_structure( 2305 _get_shape_invariant, loop_vars, expand_composites=False) 2306 loop_vars = nest.flatten(loop_vars, expand_composites=True) 2307 try: 2308 self.Enter() 2309 # _BuildLoop calls _update_input in several places. _mutation_lock() 2310 # ensures a Session.run call cannot occur between creating and mutating 2311 # new ops. 2312 with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access 2313 original_body_result, exit_vars = self._BuildLoop( 2314 pred, body, original_loop_vars, loop_vars, shape_invariants) 2315 finally: 2316 self.Exit() 2317 2318 flat_result = nest.flatten(original_body_result, expand_composites=True) 2319 # Convert TensorArray flow variables outside the context back into 2320 # their associated TensorArrays for returning to caller. 2321 exit_vars_with_tensor_arrays = ( 2322 _convert_flows_to_tensorarrays(flat_result, exit_vars)) 2323 packed_exit_vars = nest.pack_sequence_as( 2324 structure=original_body_result, 2325 flat_sequence=exit_vars_with_tensor_arrays, 2326 expand_composites=True) 2327 2328 if return_same_structure: 2329 return packed_exit_vars 2330 else: 2331 return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars 2332 2333 def _FixControlInputsAndContext(self, enters): 2334 graph = ops.get_default_graph() 2335 # pylint: disable=protected-access 2336 for e in enters: 2337 if isinstance(e, ops.Tensor): 2338 xs = [e] 2339 else: 2340 raise TypeError("'enters' must be a list of Tensors. " 2341 f"Received: {type(e)}.") 2342 for x in xs: 2343 inp_op = x.op.inputs[0].op 2344 control_inputs = graph._control_dependencies_for_inputs([inp_op]) 2345 outer_control_inputs = [] 2346 for op in control_inputs: 2347 # We need to keep control inputs that are in any ancestor 2348 # ControlFlowContext, and within outer WhileContext. 2349 keep_as_control_input = True 2350 op_ctxt = util.GetOutputContext(op) 2351 outer_ctxt = self.outer_context 2352 outer_while_context = (None if outer_ctxt is None else 2353 outer_ctxt.GetWhileContext()) 2354 while outer_ctxt != op_ctxt: 2355 if outer_ctxt is None or outer_ctxt == outer_while_context: 2356 keep_as_control_input = False 2357 break 2358 outer_ctxt = outer_ctxt.outer_context 2359 if keep_as_control_input: 2360 outer_control_inputs.append(op) 2361 x.op._set_control_flow_context(self) 2362 x.op._add_control_inputs(outer_control_inputs) 2363 graph._record_op_seen_by_control_dependencies(x.op) 2364 # pylint: enable=protected-access 2365 2366 def IsWhileContext(self): 2367 return True 2368 2369 2370# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature". 2371# pylint: disable=redefined-outer-name 2372@tf_export("while_loop", v1=[]) 2373@deprecation.deprecated_arg_values( 2374 None, 2375 """back_prop=False is deprecated. Consider using tf.stop_gradient instead. 2376Instead of: 2377results = tf.while_loop(c, b, vars, back_prop=False) 2378Use: 2379results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""", 2380 warn_once=True, 2381 back_prop=False) 2382def while_loop_v2(cond, 2383 body, 2384 loop_vars, 2385 shape_invariants=None, 2386 parallel_iterations=10, 2387 back_prop=True, 2388 swap_memory=False, 2389 maximum_iterations=None, 2390 name=None): 2391 """Repeat `body` while the condition `cond` is true. 2392 2393 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 2394 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 2395 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 2396 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 2397 `cond` and `body`. `cond` and `body` both take as many arguments as there are 2398 `loop_vars`. 2399 2400 In addition to regular Tensors or IndexedSlices, the body may accept and 2401 return TensorArray objects. The flows of the TensorArray objects will 2402 be appropriately forwarded between loops and during gradient calculations. 2403 2404 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 2405 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 2406 stitches together the graph fragments created during the `cond` and `body` 2407 calls with some additional graph nodes to create the graph flow that 2408 repeats `body` until `cond` returns false. 2409 2410 For correctness, `tf.while_loop()` strictly enforces shape invariants for 2411 the loop variables. A shape invariant is a (possibly partial) shape that 2412 is unchanged across the iterations of the loop. An error will be raised 2413 if the shape of a loop variable after an iteration is determined to be more 2414 general than or incompatible with its shape invariant. For example, a shape 2415 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 2416 compatible with [11, 17]. By default (if the argument `shape_invariants` is 2417 not specified), it is assumed that the initial shape of each tensor in 2418 `loop_vars` is the same in every iteration. The `shape_invariants` argument 2419 allows the caller to specify a less specific shape invariant for each loop 2420 variable, which is needed if the shape varies between iterations. The 2421 `tf.Tensor.set_shape` 2422 function may also be used in the `body` function to indicate that 2423 the output loop variable has a particular shape. The shape invariant for 2424 SparseTensor and IndexedSlices are treated specially as follows: 2425 2426 a) If a loop variable is a SparseTensor, the shape invariant must be 2427 TensorShape([r]) where r is the rank of the dense tensor represented 2428 by the sparse tensor. It means the shapes of the three tensors of the 2429 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 2430 is the shape of the SparseTensor.dense_shape property. It must be the shape of 2431 a vector. 2432 2433 b) If a loop variable is an IndexedSlices, the shape invariant must be 2434 a shape invariant of the values tensor of the IndexedSlices. It means 2435 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 2436 [shape.ndims]). 2437 2438 `while_loop` implements non-strict semantics, enabling multiple iterations 2439 to run in parallel. The maximum number of parallel iterations can be 2440 controlled by `parallel_iterations`, which gives users some control over 2441 memory consumption and execution order. For correct programs, `while_loop` 2442 should return the same result for any parallel_iterations > 0. 2443 2444 For training, TensorFlow stores the tensors that are produced in the 2445 forward inference and are needed in back propagation. These tensors are a 2446 main source of memory consumption and often cause OOM errors when training 2447 on GPUs. When the flag swap_memory is true, we swap out these tensors from 2448 GPU to CPU. This for example allows us to train RNN models with very long 2449 sequences and large batches. 2450 2451 Args: 2452 cond: A callable that represents the termination condition of the loop. 2453 body: A callable that represents the loop body. 2454 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 2455 `Tensor`, and `TensorArray` objects. 2456 shape_invariants: The shape invariants for the loop variables. 2457 parallel_iterations: The number of iterations allowed to run in parallel. It 2458 must be a positive integer. 2459 back_prop: (optional) Deprecated. False disables support for back 2460 propagation. Prefer using `tf.stop_gradient` instead. 2461 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2462 maximum_iterations: Optional maximum number of iterations of the while loop 2463 to run. If provided, the `cond` output is AND-ed with an additional 2464 condition ensuring the number of iterations executed is no greater than 2465 `maximum_iterations`. 2466 name: Optional name prefix for the returned tensors. 2467 2468 Returns: 2469 The output tensors for the loop variables after the loop. The return value 2470 has the same structure as `loop_vars`. 2471 2472 Raises: 2473 TypeError: if `cond` or `body` is not callable. 2474 ValueError: if `loop_vars` is empty. 2475 2476 Example: 2477 2478 ```python 2479 i = tf.constant(0) 2480 c = lambda i: tf.less(i, 10) 2481 b = lambda i: (tf.add(i, 1), ) 2482 r = tf.while_loop(c, b, [i]) 2483 ``` 2484 2485 Example with nesting and a namedtuple: 2486 2487 ```python 2488 import collections 2489 Pair = collections.namedtuple('Pair', 'j, k') 2490 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 2491 c = lambda i, p: i < 10 2492 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 2493 ijk_final = tf.while_loop(c, b, ijk_0) 2494 ``` 2495 2496 Example using shape_invariants: 2497 2498 ```python 2499 i0 = tf.constant(0) 2500 m0 = tf.ones([2, 2]) 2501 c = lambda i, m: i < 10 2502 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 2503 tf.while_loop( 2504 c, b, loop_vars=[i0, m0], 2505 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 2506 ``` 2507 2508 Example which demonstrates non-strict semantics: In the following 2509 example, the final value of the counter `i` does not depend on `x`. So 2510 the `while_loop` can increment the counter parallel to updates of `x`. 2511 However, because the loop counter at one loop iteration depends 2512 on the value at the previous iteration, the loop counter itself cannot 2513 be incremented in parallel. Hence if we just want the final value of the 2514 counter (which we print on the line `print(sess.run(i))`), then 2515 `x` will never be incremented, but the counter will be updated on a 2516 single thread. Conversely, if we want the value of the output (which we 2517 print on the line `print(sess.run(out).shape)`), then the counter may be 2518 incremented on its own thread, while `x` can be incremented in 2519 parallel on a separate thread. In the extreme case, it is conceivable 2520 that the thread incrementing the counter runs until completion before 2521 `x` is incremented even a single time. The only thing that can never 2522 happen is that the thread updating `x` can never get ahead of the 2523 counter thread because the thread incrementing `x` depends on the value 2524 of the counter. 2525 2526 ```python 2527 import tensorflow as tf 2528 2529 n = 10000 2530 x = tf.constant(list(range(n))) 2531 c = lambda i, x: i < n 2532 b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1, 2533 [i], "x:")) 2534 i, out = tf.while_loop(c, b, (0, x)) 2535 with tf.compat.v1.Session() as sess: 2536 print(sess.run(i)) # prints [0] ... [9999] 2537 2538 # The following line may increment the counter and x in parallel. 2539 # The counter thread may get ahead of the other thread, but not the 2540 # other way around. So you may see things like 2541 # [9996] x:[9987] 2542 # meaning that the counter thread is on iteration 9996, 2543 # while the other thread is on iteration 9987 2544 print(sess.run(out).shape) 2545 ``` 2546 2547 """ 2548 return while_loop( 2549 cond=cond, 2550 body=body, 2551 loop_vars=loop_vars, 2552 shape_invariants=shape_invariants, 2553 parallel_iterations=parallel_iterations, 2554 back_prop=back_prop, 2555 swap_memory=swap_memory, 2556 name=name, 2557 maximum_iterations=maximum_iterations, 2558 return_same_structure=True) 2559 2560 2561# pylint: disable=redefined-outer-name 2562@tf_export(v1=["while_loop"]) 2563def while_loop(cond, 2564 body, 2565 loop_vars, 2566 shape_invariants=None, 2567 parallel_iterations=10, 2568 back_prop=True, 2569 swap_memory=False, 2570 name=None, 2571 maximum_iterations=None, 2572 return_same_structure=False): 2573 """Repeat `body` while the condition `cond` is true. 2574 2575 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 2576 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 2577 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 2578 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 2579 `cond` and `body`. `cond` and `body` both take as many arguments as there are 2580 `loop_vars`. 2581 2582 In addition to regular Tensors or IndexedSlices, the body may accept and 2583 return TensorArray objects. The flows of the TensorArray objects will 2584 be appropriately forwarded between loops and during gradient calculations. 2585 2586 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 2587 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 2588 stitches together the graph fragments created during the `cond` and `body` 2589 calls with some additional graph nodes to create the graph flow that 2590 repeats `body` until `cond` returns false. 2591 2592 For correctness, `tf.while_loop()` strictly enforces shape invariants for 2593 the loop variables. A shape invariant is a (possibly partial) shape that 2594 is unchanged across the iterations of the loop. An error will be raised 2595 if the shape of a loop variable after an iteration is determined to be more 2596 general than or incompatible with its shape invariant. For example, a shape 2597 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 2598 compatible with [11, 17]. By default (if the argument `shape_invariants` is 2599 not specified), it is assumed that the initial shape of each tensor in 2600 `loop_vars` is the same in every iteration. The `shape_invariants` argument 2601 allows the caller to specify a less specific shape invariant for each loop 2602 variable, which is needed if the shape varies between iterations. The 2603 `tf.Tensor.set_shape` 2604 function may also be used in the `body` function to indicate that 2605 the output loop variable has a particular shape. The shape invariant for 2606 SparseTensor and IndexedSlices are treated specially as follows: 2607 2608 a) If a loop variable is a SparseTensor, the shape invariant must be 2609 TensorShape([r]) where r is the rank of the dense tensor represented 2610 by the sparse tensor. It means the shapes of the three tensors of the 2611 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 2612 is the shape of the SparseTensor.dense_shape property. It must be the shape of 2613 a vector. 2614 2615 b) If a loop variable is an IndexedSlices, the shape invariant must be 2616 a shape invariant of the values tensor of the IndexedSlices. It means 2617 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 2618 [shape.ndims]). 2619 2620 `while_loop` implements non-strict semantics, enabling multiple iterations 2621 to run in parallel. The maximum number of parallel iterations can be 2622 controlled by `parallel_iterations`, which gives users some control over 2623 memory consumption and execution order. For correct programs, `while_loop` 2624 should return the same result for any parallel_iterations > 0. 2625 2626 For training, TensorFlow stores the tensors that are produced in the 2627 forward inference and are needed in back propagation. These tensors are a 2628 main source of memory consumption and often cause OOM errors when training 2629 on GPUs. When the flag swap_memory is true, we swap out these tensors from 2630 GPU to CPU. This for example allows us to train RNN models with very long 2631 sequences and large batches. 2632 2633 Args: 2634 cond: A callable that represents the termination condition of the loop. 2635 body: A callable that represents the loop body. 2636 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 2637 `Tensor`, and `TensorArray` objects. 2638 shape_invariants: The shape invariants for the loop variables. 2639 parallel_iterations: The number of iterations allowed to run in parallel. It 2640 must be a positive integer. 2641 back_prop: Whether backprop is enabled for this while loop. 2642 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2643 name: Optional name prefix for the returned tensors. 2644 maximum_iterations: Optional maximum number of iterations of the while loop 2645 to run. If provided, the `cond` output is AND-ed with an additional 2646 condition ensuring the number of iterations executed is no greater than 2647 `maximum_iterations`. 2648 return_same_structure: If True, output has same structure as `loop_vars`. If 2649 eager execution is enabled, this is ignored (and always treated as True). 2650 2651 Returns: 2652 The output tensors for the loop variables after the loop. 2653 If `return_same_structure` is True, the return value has the same 2654 structure as `loop_vars`. 2655 If `return_same_structure` is False, the return value is a Tensor, 2656 TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list 2657 otherwise. 2658 2659 Raises: 2660 TypeError: if `cond` or `body` is not callable. 2661 ValueError: if `loop_vars` is empty. 2662 2663 Example: 2664 2665 ```python 2666 i = tf.constant(0) 2667 c = lambda i: tf.less(i, 10) 2668 b = lambda i: tf.add(i, 1) 2669 r = tf.while_loop(c, b, [i]) 2670 ``` 2671 2672 Example with nesting and a namedtuple: 2673 2674 ```python 2675 import collections 2676 Pair = collections.namedtuple('Pair', 'j, k') 2677 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 2678 c = lambda i, p: i < 10 2679 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 2680 ijk_final = tf.while_loop(c, b, ijk_0) 2681 ``` 2682 2683 Example using shape_invariants: 2684 2685 ```python 2686 i0 = tf.constant(0) 2687 m0 = tf.ones([2, 2]) 2688 c = lambda i, m: i < 10 2689 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 2690 tf.while_loop( 2691 c, b, loop_vars=[i0, m0], 2692 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 2693 ``` 2694 2695 Example which demonstrates non-strict semantics: In the following 2696 example, the final value of the counter `i` does not depend on `x`. So 2697 the `while_loop` can increment the counter parallel to updates of `x`. 2698 However, because the loop counter at one loop iteration depends 2699 on the value at the previous iteration, the loop counter itself cannot 2700 be incremented in parallel. Hence if we just want the final value of the 2701 counter (which we print on the line `print(sess.run(i))`), then 2702 `x` will never be incremented, but the counter will be updated on a 2703 single thread. Conversely, if we want the value of the output (which we 2704 print on the line `print(sess.run(out).shape)`), then the counter may be 2705 incremented on its own thread, while `x` can be incremented in 2706 parallel on a separate thread. In the extreme case, it is conceivable 2707 that the thread incrementing the counter runs until completion before 2708 `x` is incremented even a single time. The only thing that can never 2709 happen is that the thread updating `x` can never get ahead of the 2710 counter thread because the thread incrementing `x` depends on the value 2711 of the counter. 2712 2713 ```python 2714 import tensorflow as tf 2715 2716 n = 10000 2717 x = tf.constant(list(range(n))) 2718 c = lambda i, x: i < n 2719 b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1, 2720 [i], "x:")) 2721 i, out = tf.while_loop(c, b, (0, x)) 2722 with tf.compat.v1.Session() as sess: 2723 print(sess.run(i)) # prints [0] ... [9999] 2724 2725 # The following line may increment the counter and x in parallel. 2726 # The counter thread may get ahead of the other thread, but not the 2727 # other way around. So you may see things like 2728 # [9996] x:[9987] 2729 # meaning that the counter thread is on iteration 9996, 2730 # while the other thread is on iteration 9987 2731 print(sess.run(out).shape) 2732 ``` 2733 2734 """ 2735 if not callable(cond): 2736 raise TypeError("'cond' must be callable.") 2737 if not callable(body): 2738 raise TypeError("'body' must be callable.") 2739 if parallel_iterations < 1: 2740 raise TypeError("'parallel_iterations' must be a positive integer.") 2741 2742 # Always enable control flow v2 if building a function, regardless of toggle. 2743 executing_eagerly = context.executing_eagerly() 2744 if (util.EnableControlFlowV2(ops.get_default_graph()) and 2745 not executing_eagerly): 2746 return while_v2.while_loop( 2747 cond, 2748 body, 2749 loop_vars, 2750 shape_invariants=shape_invariants, 2751 parallel_iterations=parallel_iterations, 2752 maximum_iterations=maximum_iterations, 2753 name=name, 2754 return_same_structure=return_same_structure, 2755 back_prop=back_prop) 2756 2757 with ops.name_scope(name, "while", loop_vars): 2758 if not loop_vars: 2759 raise ValueError("'loop_vars' must be provided.") 2760 try_to_pack = (len(loop_vars) == 1 and not return_same_structure) 2761 if maximum_iterations is not None: 2762 maximum_iterations = ops.convert_to_tensor( 2763 maximum_iterations, name="maximum_iterations") 2764 if maximum_iterations.shape.ndims != 0: 2765 raise ValueError( 2766 "'maximum_iterations' must be a scalar. " 2767 f"Received shape: {maximum_iterations.shape}") 2768 2769 if executing_eagerly: 2770 counter = 0 2771 maximum_iterations = int(maximum_iterations.numpy()) 2772 else: 2773 counter = constant_op.constant( 2774 0, dtype=maximum_iterations.dtype, name="iteration_counter") 2775 orig_cond = cond 2776 orig_body = body 2777 if try_to_pack: 2778 loop_vars = (counter, loop_vars[0]) 2779 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 2780 math_ops.logical_and(i < maximum_iterations, orig_cond(lv))) 2781 body = lambda i, lv: (i + 1, orig_body(lv)) 2782 else: 2783 loop_vars = (counter, loop_vars) 2784 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 2785 math_ops.logical_and(i < maximum_iterations, orig_cond(*lv))) 2786 body = lambda i, lv: (i + 1, orig_body(*lv)) 2787 try_to_pack = False 2788 2789 if executing_eagerly: 2790 packed = False # whether the body result was packed into a 1-item tuple 2791 2792 loop_var_structure = nest.map_structure(type_spec.type_spec_from_value, 2793 list(loop_vars)) 2794 while cond(*loop_vars): 2795 loop_vars = body(*loop_vars) 2796 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)): 2797 packed = True 2798 loop_vars = (loop_vars,) 2799 nest.assert_same_structure(loop_var_structure, list(loop_vars)) 2800 2801 def convert(x): 2802 if isinstance(x, tensor_array_ops.TensorArray): 2803 return x 2804 return ops.convert_to_tensor(x) 2805 2806 loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True) 2807 if maximum_iterations is not None: 2808 return loop_vars[1] 2809 else: 2810 return loop_vars[0] if packed else loop_vars 2811 2812 if shape_invariants is not None: 2813 if maximum_iterations is not None: 2814 shape_invariants = (tensor_shape.TensorShape([]), shape_invariants) 2815 2816 nest.assert_same_structure( 2817 loop_vars, shape_invariants, expand_composites=False) 2818 shape_invariants = nest.map_structure( 2819 _get_shape_invariant, 2820 loop_vars, 2821 shape_invariants, 2822 expand_composites=False) 2823 2824 loop_context = WhileContext( 2825 maximum_iterations=maximum_iterations, 2826 parallel_iterations=parallel_iterations, 2827 back_prop=back_prop, 2828 swap_memory=swap_memory) 2829 # Only add non-nested loops to the collection. Any nested control flow will 2830 # be encapsulated in the root context. 2831 if loop_context.outer_context is None: 2832 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) 2833 result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, 2834 return_same_structure) 2835 if maximum_iterations is not None: 2836 return result[1] 2837 else: 2838 return result 2839 2840 2841# pylint: enable=redefined-outer-name 2842 2843 2844def _AsTensorList(x, p): 2845 """Return x as a list of Tensors or IndexedSlices. 2846 2847 For entries of `x` that are Operations, this returns an Identity of `p` 2848 with a dependency on the operation. 2849 2850 Args: 2851 x: A Tensor/IndexedSlices/Operation or a list or tuple of them. 2852 p: A Tensor to return for entries in `x` that are Operations. 2853 2854 Returns: 2855 A list of Tensors or IndexedSlices. 2856 """ 2857 if not isinstance(x, (list, _basetuple)): 2858 x = [x] 2859 2860 l = [] 2861 for v in x: 2862 if isinstance(v, ops.Operation): 2863 v = with_dependencies([v], p) 2864 v = ops.convert_to_tensor_or_composite(v) 2865 if isinstance(v, ops.Tensor): 2866 l.append(array_ops.identity(v)) 2867 else: 2868 l.append( 2869 ops.IndexedSlices( 2870 array_ops.identity(v.values), array_ops.identity(v.indices))) 2871 return l 2872 2873 2874def _CheckResults(a, b): 2875 assert len(a) == len(b), ( 2876 "Values returned by a() and b() must have the same length.") 2877 for x, y in zip(a, b): 2878 assert x.dtype == y.dtype, ( 2879 "Values returned by a() [%s] and b() [%s] must have " 2880 "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name)) 2881 2882 2883def with_dependencies(dependencies, output_tensor, name=None): 2884 """Produces the content of `output_tensor` only after `dependencies`. 2885 2886 In some cases, a user may want the output of an operation to be 2887 consumed externally only after some other dependencies have run 2888 first. This function ensures returns `output_tensor`, but only after all 2889 operations in `dependencies` have run. Note that this means that there is 2890 no guarantee that `output_tensor` will be evaluated after any `dependencies` 2891 have run. 2892 2893 See also `tf.tuple` and `tf.group`. 2894 2895 Args: 2896 dependencies: Iterable of operations to run before this op finishes. 2897 output_tensor: A `Tensor` or `IndexedSlices` that will be returned. 2898 name: (Optional) A name for this operation. 2899 2900 Returns: 2901 Same as `output_tensor`. 2902 2903 Raises: 2904 TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`. 2905 """ 2906 if context.executing_eagerly(): 2907 return output_tensor 2908 with ops.name_scope(name, "control_dependency", 2909 list(dependencies) + [output_tensor]) as name: 2910 with ops.colocate_with(output_tensor): 2911 with ops.control_dependencies(dependencies): 2912 output_tensor = ops.convert_to_tensor_or_composite(output_tensor) 2913 if isinstance(output_tensor, ops.Tensor): 2914 return _Identity(output_tensor, name=name) 2915 else: 2916 return ops.IndexedSlices( 2917 _Identity(output_tensor.values, name=name), output_tensor.indices, 2918 output_tensor.dense_shape) 2919 2920 2921def _GroupControlDeps(dev, deps, name=None): 2922 with ops.control_dependencies(deps): 2923 if dev is None: 2924 return no_op(name=name) 2925 else: 2926 with ops.device(dev): 2927 return no_op(name=name) 2928 2929 2930# TODO(touts): Accept "inputs" as a list. 2931@tf_export("group") 2932def group(*inputs, **kwargs): 2933 """Create an op that groups multiple operations. 2934 2935 When this op finishes, all ops in `inputs` have finished. This op has no 2936 output. 2937 2938 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require 2939 this method, as ops execute in the expected order thanks to automatic control 2940 dependencies.* Only use `tf.group` when working with v1 2941 `tf.Graph` code. 2942 2943 When operating in a v1-style graph context, ops are not executed in the same 2944 order as specified in the code; TensorFlow will attempt to execute ops in 2945 parallel or in an order convenient to the result it is computing. `tf.group` 2946 allows you to request that one or more results finish before execution 2947 continues. 2948 2949 `tf.group` creates a single op (of type `NoOp`), and then adds appropriate 2950 control dependencies. Thus, `c = tf.group(a, b)` will compute the same graph 2951 as this: 2952 2953 with tf.control_dependencies([a, b]): 2954 c = tf.no_op() 2955 2956 See also `tf.tuple` and 2957 `tf.control_dependencies`. 2958 2959 Args: 2960 *inputs: Zero or more tensors to group. 2961 name: A name for this operation (optional). 2962 2963 Returns: 2964 An Operation that executes all its inputs. 2965 2966 Raises: 2967 ValueError: If an unknown keyword argument is provided. 2968 """ 2969 if context.executing_eagerly(): 2970 return None 2971 name = kwargs.pop("name", None) 2972 if kwargs: 2973 raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys())) 2974 with ops.name_scope(name, "group_deps", inputs) as name: 2975 # Grouping no inputs means do nothing 2976 if not inputs: 2977 return no_op(name=name) 2978 2979 # Sorts *inputs according to their devices. 2980 ops_on_device = {} # device -> operations specified on the device. 2981 for inp in nest.flatten(inputs, expand_composites=True): 2982 if not hasattr(inp, "device"): 2983 raise TypeError("'inputs' should be zero or more (nested) Tensors. " 2984 f"Received '{inp}' with type '{type(inp)}'.") 2985 dev = inp.device 2986 if dev in ops_on_device: 2987 ops_on_device[dev].append(inp) 2988 else: 2989 ops_on_device[dev] = [inp] 2990 if len(ops_on_device) == 1: 2991 # 1-level tree. The root node is the returned NoOp node. 2992 (dev, deps), = ops_on_device.items() 2993 return _GroupControlDeps(dev, deps, name=name) 2994 2995 # 2-level tree. The root node is the returned NoOp node. 2996 # deps contains 1 NoOp node for each device. 2997 deps = [] 2998 2999 def device_key(dev): 3000 """A sort key that allows None to be compared to strings.""" 3001 return "" if dev is None else dev 3002 3003 for dev in sorted(ops_on_device, key=device_key): 3004 deps.append(_GroupControlDeps(dev, ops_on_device[dev])) 3005 3006 with ops.control_dependencies(deps): 3007 return no_op(name=name) 3008 3009 3010@tf_export("tuple", v1=[]) 3011@dispatch.add_dispatch_support 3012def tuple_v2(tensors, control_inputs=None, name=None): 3013 """Groups tensors together. 3014 3015 The returned tensors have the same value as the input tensors, but they 3016 are computed only after all the input tensors have been computed. 3017 3018 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require 3019 this method, as ops execute in the expected order thanks to automatic control 3020 dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code. 3021 3022 See also `tf.group` and `tf.control_dependencies`. 3023 3024 Args: 3025 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. 3026 control_inputs: List of additional ops to finish before returning. 3027 name: (optional) A name to use as a `name_scope` for the operation. 3028 3029 Returns: 3030 Same as `tensors`. 3031 3032 Raises: 3033 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 3034 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 3035 objects. 3036 3037 """ 3038 return tuple(tensors=tensors, name=name, control_inputs=control_inputs) # pylint: disable=redefined-builtin 3039 3040 3041@tf_export(v1=["tuple"]) 3042@dispatch.add_dispatch_support 3043def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin 3044 """Group tensors together. 3045 3046 This creates a tuple of tensors with the same values as the `tensors` 3047 argument, except that the value of each tensor is only returned after the 3048 values of all tensors have been computed. 3049 3050 `control_inputs` contains additional ops that have to finish before this op 3051 finishes, but whose outputs are not returned. 3052 3053 This can be used as a "join" mechanism for parallel computations: all the 3054 argument tensors can be computed in parallel, but the values of any tensor 3055 returned by `tuple` are only available after all the parallel computations 3056 are done. 3057 3058 See also `tf.group` and 3059 `tf.control_dependencies`. 3060 3061 Args: 3062 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. 3063 name: (optional) A name to use as a `name_scope` for the operation. 3064 control_inputs: List of additional ops to finish before returning. 3065 3066 Returns: 3067 Same as `tensors`. 3068 3069 Raises: 3070 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 3071 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 3072 objects. 3073 3074 """ 3075 if context.executing_eagerly(): 3076 return tensors 3077 with ops.name_scope(name, "tuple", tensors) as name: 3078 tensors = [ 3079 t if (isinstance(t, ops.Operation) or tensor_util.is_tf_type(t) or 3080 t is None) else ops.convert_to_tensor(t) for t in tensors 3081 ] 3082 gating_ops = [ 3083 t if isinstance(t, ops.Operation) else t.op 3084 for t in tensors 3085 if t is not None 3086 ] 3087 if control_inputs: 3088 for c in control_inputs: 3089 if isinstance(c, ops.Tensor): 3090 c = c.op 3091 elif not isinstance(c, ops.Operation): 3092 raise TypeError( 3093 "'control_inputs' must only contain Operation or Tensor. " 3094 f"Received: {type(c)}") 3095 gating_ops.append(c) 3096 # Note that in order to ensure ordering in the pbtxt, we must take care to 3097 # ensure the order here. 3098 gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops. 3099 if not gating_ops: 3100 raise ValueError("'tensors' must have at least one Tensor. " 3101 f"Received: {tensors}.") 3102 gate = group(*gating_ops) 3103 tpl = [] 3104 for t in tensors: 3105 if tensor_util.is_tf_type(t): 3106 tpl.append(with_dependencies([gate], t)) 3107 elif isinstance(t, ops.Operation): 3108 with ops.control_dependencies([gate]): 3109 tpl.append(group(t)) 3110 else: 3111 tpl.append(None) 3112 return tpl 3113 3114 3115def _assert_at_most_n_true(predicates, n, msg): 3116 """Returns an Assert op that checks that at most n predicates are True. 3117 3118 Args: 3119 predicates: list of bool scalar tensors. 3120 n: maximum number of true predicates allowed. 3121 msg: Error message. 3122 """ 3123 preds_c = array_ops.stack(predicates, name="preds_c") 3124 num_true_conditions = math_ops.reduce_sum( 3125 math_ops.cast(preds_c, dtypes.int32), name="num_true_conds") 3126 condition = math_ops.less_equal(num_true_conditions, 3127 constant_op.constant(n, name="n_true_conds")) 3128 preds_names = ", ".join(getattr(p, "name", "?") for p in predicates) 3129 error_msg = [ 3130 "%s: more than %d conditions (%s) evaluated as True:" % 3131 (msg, n, preds_names), preds_c 3132 ] 3133 return Assert(condition, data=error_msg, summarize=len(predicates)) 3134 3135 3136def _case_create_default_action(predicates, actions): 3137 """Creates default action for a list of actions and their predicates. 3138 3139 It uses the input actions to select an arbitrary as default and makes sure 3140 that corresponding predicates have valid values. 3141 3142 Args: 3143 predicates: a list of bool scalar tensors 3144 actions: a list of callable objects which return tensors. 3145 3146 Returns: 3147 a callable 3148 """ 3149 k = len(predicates) - 1 # could pick any 3150 predicate, action = predicates[k], actions[k] 3151 other_predicates, other_actions = predicates[:k], actions[:k] 3152 3153 def default_action(): 3154 others_msg = ("Implementation error: " 3155 "selected default action #%d was called, but some of other " 3156 "predicates are True: " % k) 3157 default_msg = ("Input error: " 3158 "None of conditions evaluated as True:", 3159 array_ops.stack(predicates, name="preds_c")) 3160 with ops.control_dependencies([ 3161 _assert_at_most_n_true(other_predicates, n=0, msg=others_msg), 3162 Assert(predicate, data=default_msg) 3163 ]): 3164 return action() 3165 3166 return default_action, other_predicates, other_actions 3167 3168 3169def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name, 3170 allow_python_preds): 3171 """Verifies input arguments for the case function. 3172 3173 Args: 3174 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a 3175 callable which returns a list of tensors. 3176 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3177 name: A name for the case operation. 3178 allow_python_preds: if true, pred_fn_pairs may contain Python bools in 3179 addition to boolean Tensors 3180 3181 Raises: 3182 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3183 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3184 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3185 callable. 3186 3187 Returns: 3188 a tuple <list of scalar bool tensors, list of callables>. 3189 """ 3190 if not isinstance(pred_fn_pairs, (list, _basetuple, dict)): 3191 raise TypeError("'pred_fn_pairs' must be a list, tuple, or dict. " 3192 f"Received: {type(pred_fn_pairs)}") 3193 3194 if isinstance(pred_fn_pairs, collections.OrderedDict): 3195 pred_fn_pairs = pred_fn_pairs.items() 3196 elif isinstance(pred_fn_pairs, dict): 3197 if context.executing_eagerly(): 3198 # No name to sort on in eager mode. Use dictionary traversal order, 3199 # which is nondeterministic in versions of Python < 3.6 3200 if not exclusive: 3201 raise ValueError("Unordered dictionaries are not supported for the " 3202 "'pred_fn_pairs' argument when `exclusive=False` and " 3203 "eager mode is enabled.") 3204 pred_fn_pairs = list(pred_fn_pairs.items()) 3205 else: 3206 pred_fn_pairs = sorted( 3207 pred_fn_pairs.items(), key=lambda item: item[0].name) 3208 if not exclusive: 3209 logging.warn( 3210 "%s: An unordered dictionary of predicate/fn pairs was " 3211 "provided, but exclusive=False. The order of conditional " 3212 "tests is deterministic but not guaranteed.", name) 3213 for pred_fn_pair in pred_fn_pairs: 3214 if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2: 3215 raise TypeError("Each entry in 'pred_fn_pairs' must be a 2-tuple. " 3216 f"Received {pred_fn_pair}.") 3217 pred, fn = pred_fn_pair 3218 3219 if isinstance(pred, ops.Tensor): 3220 if pred.dtype != dtypes.bool: 3221 raise TypeError("pred must be Tensor of type bool: %s" % pred.name) 3222 elif not allow_python_preds: 3223 raise TypeError("pred must be a Tensor, got: %s" % pred) 3224 elif not isinstance(pred, bool): 3225 raise TypeError("pred must be a Tensor or bool, got: %s" % pred) 3226 3227 if not callable(fn): 3228 raise TypeError("fn for pred %s must be callable." % pred.name) 3229 3230 predicates, actions = zip(*pred_fn_pairs) 3231 return predicates, actions 3232 3233 3234def _case_helper(cond_fn, 3235 pred_fn_pairs, 3236 default, 3237 exclusive, 3238 name, 3239 allow_python_preds=False, 3240 **cond_kwargs): 3241 """Implementation of case that allows for different cond functions. 3242 3243 Args: 3244 cond_fn: method that has signature and semantics of `cond` above. 3245 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a 3246 callable which returns a list of tensors. 3247 default: Optional callable that returns a list of tensors. 3248 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3249 name: A name for this operation (optional). 3250 allow_python_preds: if true, pred_fn_pairs may contain Python bools in 3251 addition to boolean Tensors 3252 **cond_kwargs: keyword arguments that will be passed to `cond_fn`. 3253 3254 Returns: 3255 The tensors returned by the first pair whose predicate evaluated to True, or 3256 those returned by `default` if none does. 3257 3258 Raises: 3259 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3260 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3261 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3262 callable. 3263 """ 3264 predicates, actions = _case_verify_and_canonicalize_args( 3265 pred_fn_pairs, exclusive, name, allow_python_preds) 3266 with ops.name_scope(name, "case", [predicates]): 3267 if default is None: 3268 default, predicates, actions = _case_create_default_action( 3269 predicates, actions) 3270 fn = default 3271 # To eval conditions in direct order we create nested conditions in reverse: 3272 # cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...)) 3273 for predicate, action in reversed(list(zip(predicates, actions))): 3274 fn = functools.partial( 3275 cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs) 3276 if exclusive: 3277 with ops.control_dependencies([ 3278 _assert_at_most_n_true( 3279 predicates, n=1, msg="Input error: exclusive=True") 3280 ]): 3281 return fn() 3282 else: 3283 return fn() 3284 3285 3286def _indexed_case_verify_and_canonicalize_args(branch_fns, default, 3287 branch_index): 3288 """Verifies input arguments for the case function. 3289 3290 Args: 3291 branch_fns: Dict or list of pairs of an `int` and a callable which 3292 returns a list of tensors. 3293 default: Optional callable that returns a list of tensors. 3294 branch_index: Optional int `Tensor`, which selects for the corresponding 3295 pred_fn_pair. 3296 3297 Raises: 3298 TypeError: If `branch_fns` is not a list/dictionary. 3299 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3300 callables. 3301 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3302 callable. 3303 3304 Returns: 3305 branch_fns: validated list of callables for each branch (default last). 3306 """ 3307 if not isinstance(branch_index, ops.Tensor): 3308 raise TypeError("'branch_index' must be a Tensor, got {}".format( 3309 type(branch_index))) 3310 if not branch_index.dtype.is_integer: 3311 raise TypeError("'branch_index' must be an integer Tensor, got {}".format( 3312 branch_index.dtype)) 3313 3314 if not branch_fns: 3315 raise ValueError("Must provide at least one item in 'branch_fns'") 3316 if not isinstance(branch_fns, (list, _basetuple, dict)): 3317 raise TypeError("'branch_fns' must be a list, tuple, or dict") 3318 3319 if isinstance(branch_fns, dict): 3320 branch_fns = branch_fns.items() 3321 3322 if all(callable(fn) for fn in branch_fns): 3323 branch_fns = list(enumerate(branch_fns)) 3324 3325 for key_fn_pair in branch_fns: 3326 if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2: 3327 raise TypeError("Each entry in 'branch_fns' must be a 2-tuple. " 3328 f"Received {key_fn_pair}.") 3329 key, branch_fn = key_fn_pair 3330 3331 if not isinstance(key, int): 3332 raise TypeError("key must be a Python `int`, got {}".format(type(key))) 3333 3334 if not callable(branch_fn): 3335 raise TypeError("fn for key {} must be callable.".format(key)) 3336 3337 keys = [p[0] for p in branch_fns] 3338 if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys): 3339 raise ValueError( 3340 "branch indices (keys) must form contiguous range of [0 to {}) but " 3341 "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys))))) 3342 actions = [p[1] for p in sorted(branch_fns)] 3343 if default is not None: 3344 actions.append(default) 3345 return actions 3346 3347 3348def _indexed_case_helper(branch_fns, 3349 default, 3350 branch_index, 3351 name, 3352 lower_using_switch_merge=None): 3353 """Implementation of case that emits the n-way indexed Case op. 3354 3355 Args: 3356 branch_fns: Dict or list of pairs of a boolean scalar tensor, and a 3357 callable which returns a list of tensors. 3358 default: Optional callable that returns a list of tensors. 3359 branch_index: Optional int `Tensor`, which selects for the corresponding 3360 pred_fn_pair. 3361 name: A name for this operation (optional). 3362 lower_using_switch_merge: Lower this op using switch merge ops (optional). 3363 3364 Returns: 3365 The tensors returned by the pair whose key matched branch_index, or 3366 those returned by `default` if none does. 3367 3368 Raises: 3369 TypeError: If `branch_fns` is not a list/dictionary. 3370 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3371 callables. 3372 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3373 callable. 3374 """ 3375 branch_fns = _indexed_case_verify_and_canonicalize_args( 3376 branch_fns, default, branch_index) 3377 with ops.name_scope(name, "case", [branch_index]): 3378 if context.executing_eagerly() and not hasattr(branch_index, "graph"): 3379 branch_index = array_ops.where( 3380 math_ops.less(branch_index, 0) 3381 | math_ops.greater_equal(branch_index, len(branch_fns)), 3382 len(branch_fns) - 1, branch_index) 3383 return branch_fns[int(branch_index)]() 3384 return cond_v2.indexed_case( 3385 branch_index, 3386 branch_fns, 3387 lower_using_switch_merge=lower_using_switch_merge) 3388 3389 3390@tf_export("case", v1=[]) 3391@dispatch.add_dispatch_support 3392def case_v2(pred_fn_pairs, 3393 default=None, 3394 exclusive=False, 3395 strict=False, 3396 name="case"): 3397 """Create a case operation. 3398 3399 See also `tf.switch_case`. 3400 3401 The `pred_fn_pairs` parameter is a list of pairs of size N. 3402 Each pair contains a boolean scalar tensor and a python callable that 3403 creates the tensors to be returned if the boolean evaluates to True. 3404 `default` is a callable generating a list of tensors. All the callables 3405 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3406 number and types of tensors. 3407 3408 If `exclusive==True`, all predicates are evaluated, and an exception is 3409 thrown if more than one of the predicates evaluates to `True`. 3410 If `exclusive==False`, execution stops at the first predicate which 3411 evaluates to True, and the tensors generated by the corresponding function 3412 are returned immediately. If none of the predicates evaluate to True, this 3413 operation returns the tensors generated by `default`. 3414 3415 `tf.case` supports nested structures as implemented in 3416 `tf.nest`. All of the callables must return the same (possibly nested) value 3417 structure of lists, tuples, and/or named tuples. Singleton lists and tuples 3418 form the only exceptions to this: when returned by a callable, they are 3419 implicitly unpacked to single values. This behavior is disabled by passing 3420 `strict=True`. 3421 3422 @compatibility(v2) 3423 `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and 3424 tf.Variable are no longer hashable in v2, so cannot be used as a key for a 3425 dictionary. Please use a list or a tuple instead. 3426 @end_compatibility 3427 3428 3429 **Example 1:** 3430 3431 Pseudocode: 3432 3433 ``` 3434 if (x < y) return 17; 3435 else return 23; 3436 ``` 3437 3438 Expressions: 3439 3440 ```python 3441 f1 = lambda: tf.constant(17) 3442 f2 = lambda: tf.constant(23) 3443 r = tf.case([(tf.less(x, y), f1)], default=f2) 3444 ``` 3445 3446 **Example 2:** 3447 3448 Pseudocode: 3449 3450 ``` 3451 if (x < y && x > z) raise OpError("Only one predicate may evaluate to True"); 3452 if (x < y) return 17; 3453 else if (x > z) return 23; 3454 else return -1; 3455 ``` 3456 3457 Expressions: 3458 3459 ```python 3460 def f1(): return tf.constant(17) 3461 def f2(): return tf.constant(23) 3462 def f3(): return tf.constant(-1) 3463 r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)], 3464 default=f3, exclusive=True) 3465 ``` 3466 3467 Args: 3468 pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which 3469 returns a list of tensors. 3470 default: Optional callable that returns a list of tensors. 3471 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3472 strict: A boolean that enables/disables 'strict' mode; see above. 3473 name: A name for this operation (optional). 3474 3475 Returns: 3476 The tensors returned by the first pair whose predicate evaluated to True, or 3477 those returned by `default` if none does. 3478 3479 Raises: 3480 TypeError: If `pred_fn_pairs` is not a list/tuple. 3481 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3482 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3483 callable. 3484 """ 3485 return _case_helper( 3486 cond, 3487 pred_fn_pairs, 3488 default, 3489 exclusive, 3490 name, 3491 allow_python_preds=False, 3492 strict=strict) 3493 3494 3495@tf_export(v1=["case"]) 3496@dispatch.add_dispatch_support 3497def case(pred_fn_pairs, 3498 default=None, 3499 exclusive=False, 3500 strict=False, 3501 name="case"): 3502 """Create a case operation. 3503 3504 See also `tf.switch_case`. 3505 3506 The `pred_fn_pairs` parameter is a dict or list of pairs of size N. 3507 Each pair contains a boolean scalar tensor and a python callable that 3508 creates the tensors to be returned if the boolean evaluates to True. 3509 `default` is a callable generating a list of tensors. All the callables 3510 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3511 number and types of tensors. 3512 3513 If `exclusive==True`, all predicates are evaluated, and an exception is 3514 thrown if more than one of the predicates evaluates to `True`. 3515 If `exclusive==False`, execution stops at the first predicate which 3516 evaluates to True, and the tensors generated by the corresponding function 3517 are returned immediately. If none of the predicates evaluate to True, this 3518 operation returns the tensors generated by `default`. 3519 3520 `tf.case` supports nested structures as implemented in 3521 `tf.nest`. All of the callables must return the same (possibly nested) value 3522 structure of lists, tuples, and/or named tuples. Singleton lists and tuples 3523 form the only exceptions to this: when returned by a callable, they are 3524 implicitly unpacked to single values. This behavior is disabled by passing 3525 `strict=True`. 3526 3527 If an unordered dictionary is used for `pred_fn_pairs`, the order of the 3528 conditional tests is not guaranteed. However, the order is guaranteed to be 3529 deterministic, so that variables created in conditional branches are created 3530 in fixed order across runs. 3531 3532 @compatibility(eager) 3533 Unordered dictionaries are not supported in eager mode when `exclusive=False`. 3534 Use a list of tuples instead. 3535 @end_compatibility 3536 3537 3538 **Example 1:** 3539 3540 Pseudocode: 3541 3542 ``` 3543 if (x < y) return 17; 3544 else return 23; 3545 ``` 3546 3547 Expressions: 3548 3549 ```python 3550 f1 = lambda: tf.constant(17) 3551 f2 = lambda: tf.constant(23) 3552 r = tf.case([(tf.less(x, y), f1)], default=f2) 3553 ``` 3554 3555 **Example 2:** 3556 3557 Pseudocode: 3558 3559 ``` 3560 if (x < y && x > z) raise OpError("Only one predicate may evaluate to True"); 3561 if (x < y) return 17; 3562 else if (x > z) return 23; 3563 else return -1; 3564 ``` 3565 3566 Expressions: 3567 3568 ```python 3569 def f1(): return tf.constant(17) 3570 def f2(): return tf.constant(23) 3571 def f3(): return tf.constant(-1) 3572 r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2}, 3573 default=f3, exclusive=True) 3574 ``` 3575 3576 Args: 3577 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a 3578 callable which returns a list of tensors. 3579 default: Optional callable that returns a list of tensors. 3580 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3581 strict: A boolean that enables/disables 'strict' mode; see above. 3582 name: A name for this operation (optional). 3583 3584 Returns: 3585 The tensors returned by the first pair whose predicate evaluated to True, or 3586 those returned by `default` if none does. 3587 3588 Raises: 3589 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3590 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3591 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3592 callable. 3593 """ 3594 return _case_helper( 3595 cond, 3596 pred_fn_pairs, 3597 default, 3598 exclusive, 3599 name, 3600 allow_python_preds=False, 3601 strict=strict) 3602 3603 3604@tf_export("switch_case") 3605def switch_case(branch_index, 3606 branch_fns, 3607 default=None, 3608 name="switch_case"): 3609 """Create a switch/case operation, i.e. an integer-indexed conditional. 3610 3611 See also `tf.case`. 3612 3613 This op can be substantially more efficient than `tf.case` when exactly one 3614 branch will be selected. `tf.switch_case` is more like a C++ switch/case 3615 statement than `tf.case`, which is more like an if/elif/elif/else chain. 3616 3617 The `branch_fns` parameter is either a dict from `int` to callables, or list 3618 of (`int`, callable) pairs, or simply a list of callables (in which case the 3619 index is implicitly the key). The `branch_index` `Tensor` is used to select an 3620 element in `branch_fns` with matching `int` key, falling back to `default` 3621 if none match, or `max(keys)` if no `default` is provided. The keys must form 3622 a contiguous set from `0` to `len(branch_fns) - 1`. 3623 3624 `tf.switch_case` supports nested structures as implemented in `tf.nest`. All 3625 callables must return the same (possibly nested) value structure of lists, 3626 tuples, and/or named tuples. 3627 3628 **Example:** 3629 3630 Pseudocode: 3631 3632 ```c++ 3633 switch (branch_index) { // c-style switch 3634 case 0: return 17; 3635 case 1: return 31; 3636 default: return -1; 3637 } 3638 ``` 3639 or 3640 ```python 3641 branches = {0: lambda: 17, 1: lambda: 31} 3642 branches.get(branch_index, lambda: -1)() 3643 ``` 3644 3645 Expressions: 3646 3647 ```python 3648 def f1(): return tf.constant(17) 3649 def f2(): return tf.constant(31) 3650 def f3(): return tf.constant(-1) 3651 r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3) 3652 # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3}) 3653 ``` 3654 3655 Args: 3656 branch_index: An int Tensor specifying which of `branch_fns` should be 3657 executed. 3658 branch_fns: A `dict` mapping `int`s to callables, or a `list` of 3659 (`int`, callable) pairs, or simply a list of callables (in which case the 3660 index serves as the key). Each callable must return a matching structure 3661 of tensors. 3662 default: Optional callable that returns a structure of tensors. 3663 name: A name for this operation (optional). 3664 3665 Returns: 3666 The tensors returned by the callable identified by `branch_index`, or those 3667 returned by `default` if no key matches and `default` was provided, or those 3668 returned by the max-keyed `branch_fn` if no `default` is provided. 3669 3670 Raises: 3671 TypeError: If `branch_fns` is not a list/dictionary. 3672 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3673 callables. 3674 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3675 callable. 3676 """ 3677 return _indexed_case_helper(branch_fns, default, branch_index, name) 3678 3679 3680@tf_export("__internal__.execute_fn_for_device", v1=[]) 3681def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"): 3682 """Executes one of the provided callables based on the device placement. 3683 3684 This API is used when the implementations for high level function depend on 3685 the underlying device placement. It takes a dictionary of device type to 3686 callables. The device type includes "CPU", "GPU", "TPU", etc. When the type of 3687 the device where to run this op matches the key in 'device_branch_fns', 3688 the corresponding callable is executed, falling back to 'default_fn' if none 3689 matches. 3690 3691 **Example:** 3692 ```python 3693 def f1(): return tf.constant(1) 3694 def f2(): return tf.constant(2) 3695 r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1) 3696 ``` 3697 'r' is evaluated as 1 when it runs on CPU, 2 running on GPU, 1 running on 3698 any other device types. 3699 3700 3701 Args: 3702 device_branch_fns: a dictionary of device types to the callables. Each 3703 callable must return a matching structure of tensors. 3704 default_fn: fallback callable when the underlying device does not match any 3705 key in the 'device_branch_fns'. 3706 name: A name for this operation (optional). 3707 3708 Returns: 3709 The tensors returned by the callable identified by device type during 3710 execution, or those returned by 'default_fn' if no key matches. 3711 """ 3712 # Always execute the default fn for XLA to avoid complicated graph by case op. 3713 # see more discussions in b/167276293. 3714 is_in_xla = util.GraphOrParentsInXlaContext(ops.get_default_graph()) 3715 if is_in_xla: 3716 return default_fn() 3717 device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()} 3718 branch_fns = list(device_branch_fns_upper.values()) 3719 devices = list(device_branch_fns_upper.keys()) 3720 device_index = gen_functional_ops.device_index(device_names=devices) 3721 return _indexed_case_helper( 3722 branch_fns, 3723 default_fn, 3724 device_index, 3725 name, 3726 lower_using_switch_merge=False) 3727 3728 3729class XLAControlFlowContext(ControlFlowContext): 3730 """Base class for XLA and TPU control flow contexts.""" 3731 3732 def __init__(self): 3733 super(XLAControlFlowContext, self).__init__() 3734 self._name = "XLAControlFlowContext" 3735 3736 def to_control_flow_context_def(self, context_def, export_scope=None): 3737 # pylint: disable=useless-super-delegation 3738 # NOTE(slebedev): the method is required by `ControlFlowContext`. 3739 super(XLAControlFlowContext, 3740 self).to_control_flow_context_def(context_def, export_scope) 3741 3742 def IsXLAContext(self): 3743 return True 3744 3745 def AddOp(self, _): 3746 pass 3747 3748 def AddValue(self, x): 3749 return x 3750 3751 def RequiresUniqueFunctionRetracing(self): 3752 """Returns whether the tf.function should be retraced if the context changes. 3753 """ 3754 return False 3755 3756 3757@tf_export("__internal__.get_enclosing_xla_context", v1=[]) 3758def get_enclosing_xla_context(): 3759 """Recursively find and return the XLAControlFlowContext.""" 3760 graph = ops.get_default_graph() 3761 while graph is not None: 3762 # pylint: disable=protected-access 3763 context_ = graph._get_control_flow_context() 3764 # pylint: enable=protected-access 3765 while context_ is not None: 3766 if isinstance(context_, XLAControlFlowContext): 3767 return context_ 3768 context_ = context_.outer_context 3769 # This may be a FuncGraph due to defuns or v2 control flow. We need to 3770 # find the original graph with the XLAControlFlowContext. 3771 graph = getattr(graph, "outer_graph", None) 3772 return None 3773 3774 3775def from_control_flow_context_def(context_def, import_scope=None): 3776 """Deserializes `context_def` into the appropriate ControlFlowContext. 3777 3778 Args: 3779 context_def: ControlFlowContextDef proto 3780 import_scope: Optional `string`. Name scope to add. 3781 3782 Returns: 3783 A ControlFlowContext subclass 3784 """ 3785 if context_def.HasField("cond_ctxt"): 3786 return CondContext.from_proto( 3787 context_def.cond_ctxt, import_scope=import_scope) 3788 if context_def.HasField("while_ctxt"): 3789 return WhileContext.from_proto( 3790 context_def.while_ctxt, import_scope=import_scope) 3791 raise NotImplementedError("Unknown ControlFlowContextDef field: %s" % 3792 context_def.WhichOneof("ctxt")) 3793 3794 3795ops.register_proto_function( 3796 ops.GraphKeys.COND_CONTEXT, 3797 proto_type=control_flow_pb2.CondContextDef, 3798 to_proto=CondContext.to_proto, 3799 from_proto=CondContext.from_proto) 3800 3801ops.register_proto_function( 3802 ops.GraphKeys.WHILE_CONTEXT, 3803 proto_type=control_flow_pb2.WhileContextDef, 3804 to_proto=WhileContext.to_proto, 3805 from_proto=WhileContext.from_proto) 3806