# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control Flow Operations.

See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import functools

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# This is to avoid a circular dependency:
# cond_v2 -> gradients_util -> control_flow_ops
cond_v2 = LazyLoader("cond_v2", globals(),
                     "tensorflow.python.ops.cond_v2")

# This is to avoid circular dependencies:
# while_v2 -> control_flow_ops
# while_v2 -> gradients_util -> control_flow_ops
while_v2 = LazyLoader("while_v2", globals(),
                      "tensorflow.python.ops.while_v2")

# def_function also uses cond
def_function = LazyLoader(
    "def_function", globals(),
    "tensorflow.python.eager.def_function")


# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple


def _summarize_eager(tensor, summarize=None):
  """Returns a summarized string representation of eager `tensor`.
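
  For example, a rank-1 eager tensor with values [1, 2, 3, 4] summarized
  with `summarize=3` is rendered as "1, 2, 3, ...".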

  Args:
    tensor: EagerTensor to summarize.
    summarize: Include this many of the first elements of the tensor.
  """
  # Emulate the behavior of Tensor::SummarizeValue()
  if summarize is None:
    summarize = 3
  elif summarize < 0:
    summarize = array_ops.size(tensor)

  # reshape((-1,)) is the fastest way to get a flat array view
  if tensor._rank():  # pylint: disable=protected-access
    flat = tensor.numpy().reshape((-1,))
    lst = [str(x) for x in flat[:summarize]]
    if len(lst) < flat.size:
      lst.append("...")
  else:
    # tensor.numpy() returns a scalar for zero-dimensional arrays
    if gen_math_ops.not_equal(summarize, 0):
      lst = [str(tensor.numpy())]
    else:
      lst = []

  return ", ".join(lst)


# pylint: disable=protected-access


# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@dispatch.add_dispatch_support
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    @compatibility(TF1)
    When in TF V1 mode (that is, outside `tf.function`) Assert needs a control
    dependency on the output to ensure the assertion executes:

    ```python
    # Ensure maximum element of x is smaller or equal to 1
    assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
    with tf.control_dependencies([assert_op]):
      ... code using x ...
    ```

    @end_compatibility
  """
  if context.executing_eagerly():
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return

  with ops.name_scope(name, "Assert", [condition, data]) as name:
    xs = ops.convert_n_to_tensor(data)
    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If that is not the case,
      # we will pay the price copying the tensor to host memory.
      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
    else:
      condition = ops.convert_to_tensor(condition, name="Condition")

      def true_assert():
        return gen_logging_ops._assert(
            condition, data, summarize, name="Assert")

      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
      if context.executing_eagerly():
        return
      return guarded_assert.op


def _Identity(data, name=None):
  """Return a tensor with the same shape and contents as the input tensor.

  Args:
    data: A Tensor.
    name: A name for this operation (optional).

  Returns:
    A Tensor with the same type and value as the input Tensor.
196 """ 197 data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True) 198 if isinstance(data, ops.Tensor): 199 if data.dtype._is_ref_dtype: # pylint: disable=protected-access 200 return gen_array_ops.ref_identity(data, name=name) 201 else: 202 return array_ops.identity(data, name=name) 203 elif isinstance(data, composite_tensor.CompositeTensor): 204 return nest.map_structure(_Identity, data, expand_composites=True) 205 else: 206 raise TypeError("Type %s not supported" % type(data)) 207 208 209def _NextIteration(data, name=None): 210 data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True) 211 if isinstance(data, ops.Tensor): 212 if data.dtype._is_ref_dtype: # pylint: disable=protected-access 213 return ref_next_iteration(data, name=name) 214 else: 215 return next_iteration(data, name=name) 216 elif isinstance(data, composite_tensor.CompositeTensor): 217 return nest.map_structure(_NextIteration, data, expand_composites=True) 218 else: 219 raise TypeError("Type %s not supported" % type(data)) 220 221 222def _Enter(data, 223 frame_name, 224 is_constant=False, 225 parallel_iterations=10, 226 use_ref=True, 227 use_input_shape=True, 228 name=None): 229 """Creates or finds a child frame, and makes `data` available to it. 230 231 The unique `frame_name` is used by the `Executor` to identify frames. If 232 `is_constant` is true, `data` is a constant in the child frame; otherwise 233 it may be changed in the child frame. At most `parallel_iterations` 234 iterations are run in parallel in the child frame. 235 236 Args: 237 data: The tensor to be made available to the child frame. 238 frame_name: The name of the child frame. 239 is_constant: If true, the output is constant within the child frame. 240 parallel_iterations: The number of iterations allowed to run in parallel. 241 use_ref: If true, use ref_enter if data is of ref type. 242 use_input_shape: If true, set the result's shape based on data's shape. 243 name: A name for this operation (optional). 244 245 Returns: 246 The same tensor as `data`. 247 """ 248 data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True) 249 if isinstance(data, ops.Tensor): 250 if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access 251 result = gen_control_flow_ops.ref_enter( 252 data, frame_name, is_constant, parallel_iterations, name=name) 253 else: 254 result = gen_control_flow_ops.enter( 255 data, frame_name, is_constant, parallel_iterations, name=name) 256 if use_input_shape: 257 result.set_shape(data.get_shape()) 258 return result 259 elif isinstance(data, composite_tensor.CompositeTensor): 260 261 def enter_component(t): 262 return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref, 263 use_input_shape) 264 265 return nest.map_structure(enter_component, data, expand_composites=True) 266 else: 267 raise TypeError("Type %s not supported" % type(data)) 268 269 270def exit(data, name=None): # pylint: disable=redefined-builtin 271 """Exits the current frame to its parent frame. 272 273 Exit makes its input `data` available to the parent frame. 274 275 Args: 276 data: The tensor to be made available to the parent frame. 277 name: A name for this operation (optional). 278 279 Returns: 280 The same tensor as `data`. 
281 """ 282 data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True) 283 if isinstance(data, ops.Tensor): 284 if data.dtype._is_ref_dtype: # pylint: disable=protected-access 285 return gen_control_flow_ops.ref_exit(data, name) 286 else: 287 return gen_control_flow_ops._exit(data, name) 288 elif isinstance(data, composite_tensor.CompositeTensor): 289 return nest.map_structure(exit, data, expand_composites=True) 290 else: 291 raise TypeError("Type %s not supported" % type(data)) 292 293 294def switch(data, pred, dtype=None, name=None): 295 """Forwards `data` to an output determined by `pred`. 296 297 If `pred` is false, the `data` input is forwarded to the first output. 298 Otherwise, the data goes to the second output. 299 300 This op handles `Tensor`s and `IndexedSlices`. 301 302 Args: 303 data: The tensor to be forwarded to the appropriate output. 304 pred: A scalar that specifies which output port will receive data. 305 dtype: Optional element type for the returned tensor. If missing, the type 306 is inferred from the type of `value`. 307 name: A name for this operation (optional). 308 309 Returns: 310 `(output_false, output_true)`: If `pred` is true, data will be forwarded 311 to `output_true`, otherwise it goes to `output_false`. 312 """ 313 with ops.name_scope(name, "Switch", [data, pred]) as name: 314 data = ops.internal_convert_to_tensor_or_composite( 315 data, dtype=dtype, name="data", as_ref=True) 316 pred = ops.convert_to_tensor(pred, name="pred") 317 if isinstance(data, ops.Tensor): 318 return gen_control_flow_ops.switch(data, pred, name=name) 319 else: 320 if not isinstance(data, composite_tensor.CompositeTensor): 321 raise TypeError("Type %s not supported" % type(data)) 322 tensors = nest.flatten(data, expand_composites=True) 323 mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors] 324 mapped_f, mapped_t = zip(*mapped) 325 return (nest.pack_sequence_as(data, mapped_f, expand_composites=True), 326 nest.pack_sequence_as(data, mapped_t, expand_composites=True)) 327 328 329def _SwitchRefOrTensor(data, pred, name="Switch"): 330 """Forwards `data` to an output determined by `pred`. 331 332 If `pred` is false, the `data` input is forwarded to the first output. 333 Otherwise, the data goes to the second output. 334 335 This op handles `Tensor`s and `IndexedSlices`. 336 337 Args: 338 data: The tensor to be forwarded to the appropriate output. 339 pred: A scalar that specifies which output port will receive data. 340 name: A name for this operation (optional). 341 342 Returns: 343 `(output_false, output_true)`: If `pred` is true, data will be forwarded to 344 `output_true`, otherwise it goes to `output_false`. 345 346 Raises: 347 TypeError: if data is not a Tensor or IndexedSlices 348 """ 349 data = ops.convert_to_tensor_or_composite(data, name="data") 350 # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below 351 # addresses the following scenario. 352 # 353 # Assume you execute Optimizer.apply_gradients() in a branch of a cond(). 354 # 355 # 1. The update op is created inside a `with ops.colocate(var):` block 356 # 357 # 2. Some tensor `data` is captured and a switch is created in a 358 # `with ops.colocate_with(data):` block. 359 # 360 # with ops.colocate_with(var): 361 # with ops.colocate_with(data): 362 # op = ... 363 # 364 # var and data may be pinned to different devices, so we want to ops 365 # created within ops.colocate_with(data) to ignore the existing stack. 
  with ops.colocate_with(data, ignore_existing=True):
    if isinstance(data, ops.Tensor):
      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
        return ref_switch(data, pred, name=name)
    return switch(data, pred, name=name)


def merge(inputs, name=None):
  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of
  them is available. If it finds an available tensor, it returns it and its
  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor
  in `inputs` is available, the returned tensor and index are not set.

  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
  before merging.

  Args:
    inputs: The input tensors, at most one of which is available.
    name: A name for this operation (optional).

  Returns:
    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:
    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
      some but not all have a dense_shape property.
  """
  if any(inp is None for inp in inputs):
    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
  with ops.name_scope(name, "Merge", inputs) as name:
    inputs = [
        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
        for inp in inputs
    ]
    if all(isinstance(v, ops.Tensor) for v in inputs):
      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
        return gen_control_flow_ops.ref_merge(inputs, name)
      else:
        return gen_control_flow_ops.merge(inputs, name)
    else:
      # If there is a mix of tensors and indexed slices, then convert the
      # tensors to indexed slices.
      if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs):
        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)

      for v in inputs:
        if not isinstance(v, composite_tensor.CompositeTensor):
          raise TypeError("Type %s not supported" % type(v))

      for v in inputs[1:]:
        nest.assert_same_structure(inputs[0], v, expand_composites=True)

      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
      merged_results = [
          gen_control_flow_ops.merge(component)
          for component in zip(*flat_inputs)
      ]
      flat_merged = [tensor for (tensor, _) in merged_results]
      chosen_index = merged_results[0][1]
      merged_inputs = nest.pack_sequence_as(
          inputs[0], flat_merged, expand_composites=True)
      return (merged_inputs, chosen_index)


# pylint: enable=protected-access


def _convert_tensorarray_to_flow(tensor_or_tensor_array):
  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
    return tensor_or_tensor_array.flow
  else:
    return tensor_or_tensor_array


def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
  if len(tensors_or_tensorarrays) != len(tensors_or_flows):
    raise ValueError(
        "Lengths of original Tensor list and new list do not match: %d vs. %d" %
%d" % 449 (len(tensors_or_tensorarrays), len(tensors_or_flows))) 450 return [ 451 tensor_array_ops.build_ta_with_new_flow(ta, t_or_flow) if isinstance( 452 ta, tensor_array_ops.TensorArray) else t_or_flow 453 for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows) 454 ] 455 456 457def _ShapeLessThanOrEqual(shape1, shape2): 458 if shape2.dims is None: 459 return True 460 if shape1.ndims != shape2.ndims: 461 return False 462 for dim1, dim2 in zip(shape1.dims, shape2.dims): 463 if dim2.value is not None and dim1.value != dim2.value: 464 return False 465 return True 466 467 468def _get_shape_invariant(var, shape=None): 469 """Returns shape invariant(s) for the given variable. 470 471 Args: 472 var: The tensor whose shape is described. 473 shape: The shape invariant for the tensor. If not specified, then a default 474 shape invariant for `var` is returned. 475 476 Returns: 477 `TensorShape` or `list` of `TensorShape`: The shape invariant for `var` (if 478 it is a `Tensor`), or the shape invariants for the components that comprise 479 `var` (if it is a `CompositeTensor`). 480 """ 481 if isinstance(var, composite_tensor.CompositeTensor): 482 # Get a TypeSpec for `var`. 483 if shape is None: 484 spec = var._type_spec # pylint: disable=protected-access 485 else: 486 spec = _shape_invariant_to_type_spec(var, shape) 487 488 tensor_specs = nest.flatten(spec, expand_composites=True) 489 return [tspec.shape for tspec in tensor_specs] 490 491 elif shape is None: 492 return var.shape 493 elif isinstance(shape, tensor_spec.TensorSpec): 494 if var.dtype != shape.dtype: 495 raise TypeError("TensorSpec %r is not compatible with %r" % (shape, var)) 496 return shape.shape 497 elif isinstance(shape, type_spec.TypeSpec): 498 raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var)) 499 else: 500 return shape 501 502 503def _shape_invariant_to_type_spec(var, shape): 504 """Converts a shape invariant to a TypeSpec. 505 506 Args: 507 var: The tensor whose shape is described by the shape invariant. 508 shape: A `TypeSpec` or `TensorShape`. If `shape` is already a `TypeSpec`, 509 then it is simply returned as-is. 510 511 Returns: 512 A `TypeSpec` for `var`, consistent with the given shape. 513 """ 514 if shape is None: 515 return type_spec.type_spec_from_value(var) 516 elif isinstance(shape, type_spec.TypeSpec): 517 if not shape.is_compatible_with(var): 518 raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var)) 519 return shape 520 elif not isinstance(shape, tensor_shape.TensorShape): 521 raise TypeError( 522 "Expected shape to be a TypeSpec, TensorShape or None, got %r for" 523 " value %r" % (shape, var)) 524 525 if isinstance(var, ops.Tensor): 526 return tensor_spec.TensorSpec(shape, var.dtype) 527 528 elif isinstance(var, composite_tensor.CompositeTensor): 529 try: 530 return var._shape_invariant_to_type_spec(shape) # pylint: disable=protected-access 531 except NotImplementedError: 532 raise TypeError( 533 "To describe or constrain a %s, use a %s instead of a TensorShape." % 534 (type(var).__name__, type(var._type_spec).__name__)) # pylint: disable=protected-access 535 536 else: 537 raise TypeError("Expected var to be a Tensor or CompositeTensor, got %s" 538 % var) 539 540 541def _SetShapeInvariants(input_vars, enter_vars, shapes): 542 """Set the shapes of the tensors in `enter_vars` to `shapes`. 543 544 Args: 545 input_vars: A list of tensors that are inputs to `enter_vars`. 546 enter_vars: A list of tensors whose shapes will be set. 
    shapes: A (possibly nested) list of shapes.

  Raises:
    ValueError: If any tensor in `enter_vars` has a less specific shape
      than its corresponding shape in `shapes`.
  """
  if shapes is None:
    return
  flat_shapes = nest.flatten(shapes)
  if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
  # Check that the shapes of the inputs are less than the shape invariants,
  # and set the shapes of `enter_vars` to the shape invariants.
  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
    if isinstance(var, ops.Tensor):
      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
        raise ValueError(
            "The shape invariant specified for %s is not compatible with "
            "the initial shape of the loop variable. It enters the loop "
            "with shape %s, but the specified shape invariant is %s." %
            (inp.name, inp.get_shape(), shape))
      var.set_shape(shape)
    else:
      raise TypeError("Type %s not supported" % type(var))


def _EnforceShapeInvariant(merge_var, next_var):
  """Check if the shapes of the loop variables are invariant.

  Args:
    merge_var: The list of tensors representing the initial values of the loop
      variables.
    next_var: The list of tensors representing the values of the loop variables
      after one loop iteration.

  Raises:
    ValueError: If any tensor in `merge_var` has a more specific shape than
      its corresponding tensor in `next_var`.
  """
  if isinstance(merge_var, ops.Tensor):
    m_shape = merge_var.get_shape()
    n_shape = next_var.get_shape()
    if not _ShapeLessThanOrEqual(n_shape, m_shape):
      enter = merge_var.op.inputs[0].op
      assert util.IsLoopEnter(enter)
      input_t = enter.inputs[0]
      raise ValueError(
          "Input tensor '%s' enters the loop with shape %s, but has shape %s "
          "after one iteration. To allow the shape to vary across iterations, "
          "use the `shape_invariants` argument of tf.while_loop to specify a "
          "less-specific shape." % (input_t.name, input_t.shape, n_shape))
  else:
    raise TypeError("Type %s not supported" % type(merge_var))


def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
  """Add NextIteration and back edge from v to m."""
  if isinstance(m, ops.Tensor):
    v = ops.convert_to_tensor(v)
    v = _NextIteration(v)
    if enforce_shape_invariant:
      # Make sure the shapes of loop outputs are correct. We do this before
      # calling _update_input, which will raise a less-helpful error message if
      # the types don't match.
      # TODO(skyewm): call this for other cases below (needs testing)
      _EnforceShapeInvariant(m, v)
    m.op._update_input(1, v)  # pylint: disable=protected-access
  elif isinstance(m, composite_tensor.CompositeTensor):
    # pylint: disable=protected-access
    def update_component(m_component, v_component):
      m_component.op._update_input(1, v_component)

    if isinstance(m, ops.IndexedSlices):
      v = math_ops._as_indexed_slices(v, optimize=False)
    # pylint: enable=protected-access
    v = _NextIteration(v)
    return nest.map_structure(update_component, m, v, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(m))
  return v


@six.add_metaclass(abc.ABCMeta)
class ControlFlowContext(object):
  """The base class for control flow context.

  The usage pattern is a sequence of (Enter, Exit) followed by a final
  ExitResult.

  We maintain the following state for control flow contexts during graph
  construction:
    1. graph has _control_flow_context: the current context used to
       construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
    2. op has _control_flow_context: the context to which the op belongs.
       Set at the time the op is created. Immutable.
    3. A ControlFlowContext has _outer_context: the context in which this
       context is created. Set at the time a context is created. Immutable.
    4. A ControlFlowContext has _context_stack.
       Pushed and popped by ctxt.Enter() and ctxt.Exit()
  """

  def __init__(self, values_def=None, import_scope=None):
    self._nested_contexts = []
    self._outer_context = ops.get_default_graph()._get_control_flow_context()
    if self._outer_context:
      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
    self._context_stack = []
    if values_def:
      self._init_values_from_proto(values_def, import_scope=import_scope)
    else:
      # The names of tensors that have been already seen in this context.
      self._values = set()
      # The keys are the names of tensors referenced by but external to this
      # context. Each value is the Tensor that should be used by this context to
      # access the key value (e.g. a switch output guarding a cond input value).
      self._external_values = {}

  def _init_values_from_proto(self, values_def, import_scope=None):
    """Initializes values and external_values from `ValuesDef` protocol buffer.

    Args:
      values_def: `ValuesDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(values_def, control_flow_pb2.ValuesDef)
    self._values = set(
        ops.prepend_name_scope(value, import_scope)
        for value in values_def.values)
    g = ops.get_default_graph()
    self._external_values = {}
    for k, v in values_def.external_values.items():
      k = ops.prepend_name_scope(k, import_scope)
      self._external_values[k] = g.as_graph_element(
          ops.prepend_name_scope(v, import_scope))
    op_names = set([
        op.split(":")[0]
        for op in self._values - set(self._external_values.keys())
    ])
    for op in op_names:
      # pylint: disable=protected-access
      g.as_graph_element(op)._set_control_flow_context(self)
      # pylint: enable=protected-access

  @property
  def name(self):
    return self._name

  @property
  def outer_context(self):
    """Return the context containing this context."""
    return self._outer_context

  @property
  def grad_state(self):
    raise NotImplementedError("Abstract method")

  @property
  def back_prop(self):
    raise NotImplementedError("Abstract method")

  @abc.abstractmethod
  def to_control_flow_context_def(self, context_def, export_scope=None):
    """Serializes this into `context_def`.

    Args:
      context_def: a `ControlFlowContextDef` protocol buffer.
      export_scope: Optional `string`. Name scope to remove.
    """
    raise NotImplementedError("Abstract method")

  def _to_values_def(self, export_scope=None):
    """Converts the values to a `ValuesDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `ValuesDef` protocol buffer.
725 """ 726 values_def = control_flow_pb2.ValuesDef() 727 values_def.values.extend( 728 [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)]) 729 for k, v in self._external_values.items(): 730 k = ops.strip_name_scope(k, export_scope) 731 values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope) 732 return values_def 733 734 def AddName(self, name): 735 self._values.add(name) 736 737 # pylint: disable=protected-access 738 def Enter(self): 739 """Enter this control flow context.""" 740 graph = ops.get_default_graph() 741 self._context_stack.append(graph._get_control_flow_context()) 742 graph._set_control_flow_context(self) 743 744 def Exit(self): 745 """Exit this control flow context.""" 746 graph = ops.get_default_graph() 747 last_context = self._context_stack.pop() 748 graph._set_control_flow_context(last_context) 749 750 def EnterGradientColocation(self, op, gradient_uid): 751 """Start building a gradient colocated with an op.""" 752 if self._outer_context: 753 self._outer_context.EnterGradientColocation(op, gradient_uid) 754 755 def ExitGradientColocation(self, op, gradient_uid): 756 """Start building a gradient colocated with an op.""" 757 if self._outer_context: 758 self._outer_context.ExitGradientColocation(op, gradient_uid) 759 760 def ExitResult(self, result): 761 """Make a list of tensors available in the outer context.""" 762 if self._outer_context: 763 def fn(x): 764 self._outer_context.AddName(x.name) 765 return x 766 nest.map_structure(fn, result, expand_composites=True) 767 768 def GetWhileContext(self): 769 """Return the while context containing this context.""" 770 if self._outer_context: 771 return self._outer_context.GetWhileContext() 772 return None 773 774 def _RemoveExternalControlEdges(self, op): 775 """Remove any external control dependency on this op.""" 776 while_ctxt = self.GetWhileContext() 777 # A control input of `op` is internal if it is in the same while 778 # loop context as the enclosing while loop context of self. 779 if while_ctxt is None: 780 internal_control_inputs = op.control_inputs 781 else: 782 internal_control_inputs = [] 783 for x in op.control_inputs: 784 ctxt = util.GetOutputContext(x) 785 if ctxt is not None and ctxt.GetWhileContext() == while_ctxt: 786 internal_control_inputs.append(x) 787 external_control_inputs = [] 788 if len(internal_control_inputs) != len(op.control_inputs): 789 external_control_inputs = list( 790 set(op.control_inputs) - set(internal_control_inputs)) 791 op._remove_all_control_inputs() 792 op._add_control_inputs(internal_control_inputs) 793 return internal_control_inputs, external_control_inputs 794 795 # pylint: enable=protected-access 796 797 def AddInnerOp(self, op): 798 """Notifies a scope about an operator added to an inner scope.""" 799 if self._outer_context: 800 self._outer_context.AddInnerOp(op) 801 802 def GetControlPivot(self): 803 """Returns the pivot node for this context, or None.""" 804 return None 805 806 def IsWhileContext(self): 807 return False 808 809 def IsCondContext(self): 810 return False 811 812 def IsXLAContext(self): 813 return False 814 815 def __str__(self): 816 return self.name 817 818 819class CondContext(ControlFlowContext): 820 """The context for the conditional construct.""" 821 822 def __init__(self, 823 pred=None, 824 pivot=None, 825 branch=None, 826 name="cond_text", 827 context_def=None, 828 import_scope=None): 829 """Creates a `CondContext`. 830 831 Args: 832 pred: The `boolean` tensor for the conditional predicate. 
      pivot: The predicate tensor in this branch.
      branch: 0 or 1 representing this branch.
      name: Name of the `CondContext` python object.
      context_def: Optional `ContextDef` protocol buffer to initialize the
        `CondContext` object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
    """
    self._name = ops.get_default_graph().unique_name(name)

    if context_def:
      self._init_from_proto(context_def, import_scope=import_scope)
    else:
      # Initializes the default fields.
      ControlFlowContext.__init__(self)
      self._pred = pred  # The boolean tensor for the cond predicate
      self._pivot = pivot  # The predicate tensor in this branch
      self._branch = branch  # 0 or 1 representing this branch

      # Values considered to have been already seen in this context. pred is not
      # included in this context.
      self._values.add(pred.name)
      self._external_values[pred.name] = pred
      self._values.add(pivot.name)
      pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access

  def _init_from_proto(self, context_def, import_scope=None):
    """Creates a new `CondContext` from protocol buffer.

    Args:
      context_def: `CondContextDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(context_def, control_flow_pb2.CondContextDef)
    # Create from context_def.
    g = ops.get_default_graph()
    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
    self._pred = g.as_graph_element(
        ops.prepend_name_scope(context_def.pred_name, import_scope))
    self._pivot = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_name, import_scope))
    self._branch = context_def.branch
    super(CondContext, self).__init__(
        values_def=context_def.values_def, import_scope=import_scope)

  @property
  def pred(self):
    return self._pred

  @property
  def pivot(self):
    return self._pivot

  @property
  def branch(self):
    return self._branch

  @property
  def grad_state(self):
    if self.GetWhileContext():
      return self.GetWhileContext().grad_state
    return None

  @property
  def back_prop(self):
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False

  def GetControlPivot(self):
    return self._pivot

  def to_proto(self, export_scope=None):
    """Converts a `CondContext` to a `CondContextDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `CondContextDef` protocol buffer.
913 """ 914 if (export_scope is None or self.name.startswith(export_scope)): 915 context_def = control_flow_pb2.CondContextDef() 916 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 917 context_def.pred_name = ops.strip_name_scope(self._pred.name, 918 export_scope) 919 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 920 export_scope) 921 context_def.branch = self._branch 922 context_def.values_def.MergeFrom( 923 super(CondContext, self)._to_values_def(export_scope)) 924 for nested in self._nested_contexts: 925 nested_def = context_def.nested_contexts.add() 926 nested.to_control_flow_context_def(nested_def) 927 928 return context_def 929 else: 930 return None 931 932 @staticmethod 933 def from_proto(context_def, import_scope=None): 934 """Returns a `CondContext` object created from `context_def`.""" 935 ret = CondContext(context_def=context_def, import_scope=import_scope) 936 937 ret.Enter() 938 for nested_def in context_def.nested_contexts: 939 from_control_flow_context_def(nested_def, import_scope=import_scope) 940 ret.Exit() 941 return ret 942 943 def to_control_flow_context_def(self, context_def, export_scope=None): 944 context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 945 946 def AddValue(self, val): 947 """Add `val` to the current context and its outer context recursively.""" 948 if val.name in self._values: 949 # Use the real value if it comes from outer context. This is needed in 950 # particular for nested conds. 951 result = self._external_values.get(val.name) 952 result = val if result is None else result 953 else: 954 result = val 955 self._values.add(val.name) 956 if self._outer_context: 957 result = self._outer_context.AddValue(val) 958 self._values.add(result.name) 959 self._external_values[result.name] = result 960 with ops.control_dependencies(None): 961 result = _SwitchRefOrTensor(result, self._pred)[self._branch] 962 if self._outer_context: 963 self._outer_context.AddInnerOp(result.op) 964 965 result.op.graph.prevent_fetching(result.op) 966 # pylint: disable=protected-access 967 result.op._set_control_flow_context(self) 968 # pylint: enable=protected-access 969 970 # Mark Switch output as seen by this context and any outer contexts, 971 # just like what we do for normal op outputs in _AddOpInternal() below. 972 ctxt = self 973 while ctxt is not None: 974 # pylint: disable=protected-access 975 ctxt._values.add(result.name) 976 ctxt = ctxt._outer_context 977 # pylint: enable=protected-access 978 979 self._external_values[val.name] = result 980 return result 981 982 def AddOp(self, op): 983 self._AddOpInternal(op) 984 985 def _AddOpInternal(self, op): 986 """Add `op` to the current context.""" 987 if not op.inputs: 988 # If we're in a while loop, remove any control inputs from outside the 989 # loop. 990 self._RemoveExternalControlEdges(op) 991 992 if not any( 993 util.OpInContext(input_op, self) for input_op in op.control_inputs): 994 # pylint: disable=protected-access 995 op._add_control_input(self._pivot.op) 996 # pylint: enable=protected-access 997 else: 998 # Make each input to 'op' available in this CondContext. If an input is 999 # already part of this context there's nothing to do, but if it's 1000 # external, AddValue() will handle adding the appropriate Switch node and 1001 # other bookkeeping. 
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        if op.type == "Merge" and x.op.type == "NextIteration":
          # Edge case: if we're importing a while loop inside this CondContext,
          # AddValue() will not correctly handle the NextIteration inputs to
          # Merge node. The problem is that the NextIteration should also be
          # part of this context, but if we're importing it won't have been
          # processed and added to the context yet, so AddValue() will try to
          # add a Switch which results in an invalid graph. Instead, we use the
          # NextIteration input as-is here, and it will eventually be added to
          # the context via AddOp().
          real_x = x
        else:
          real_x = self.AddValue(x)
        if real_x != x:
          # pylint: disable=protected-access
          op._update_input(index, real_x)
          # pylint: enable=protected-access
      # Remove any external control dependency on this op.
      self._RemoveExternalControlEdges(op)
      # pylint: disable=protected-access
      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
        op._add_control_input(self._pivot.op)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    ctxt = self
    while ctxt is not None:
      # pylint: disable=protected-access
      ctxt._values.update(output_names)
      ctxt = ctxt._outer_context
      # pylint: enable=protected-access

    if self._outer_context or not util.IsLoopExit(op):
      op.graph.prevent_fetching(op)

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def _ProcessOutputTensor(self, val):
    """Process an output tensor of a conditional branch."""
    real_val = val
    if val.name not in self._values:
      # Handle the special case of lambda: x
      self._values.add(val.name)
      if self._outer_context:
        real_val = self._outer_context.AddValue(val)
        self._values.add(real_val.name)
        self._external_values[real_val.name] = real_val
      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
      self._external_values[val.name] = real_val
    else:
      external_val = self._external_values.get(val.name)
      if external_val is not None:
        real_val = external_val
    return real_val

  def _BuildCondTensor(self, v):
    if isinstance(v, ops.Operation):
      # Use pivot as the proxy for this op.
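      # Returning the branch pivot with a control dependency on `v` means that
      # fetching this cond output also forces the Operation `v` to run.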
      return with_dependencies([v], self._pivot)
    else:
      v = nest.map_structure(
          _convert_tensorarray_to_flow, v, expand_composites=True)
      return self._ProcessOutputTensor(ops.convert_to_tensor(v))

  def BuildCondBranch(self, fn):
    """Add the subgraph defined by fn() to the graph."""
    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    original_result = fn()
    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    if len(post_summaries) > len(pre_summaries):
      new_summaries = post_summaries[len(pre_summaries):]
      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
      summary_ref[:] = pre_summaries
      with ops.control_dependencies(new_summaries):
        if original_result is None:
          return no_op(), None
        elif not isinstance(original_result, ops.Operation):
          original_result = nest.map_structure(
              array_ops.identity, original_result, expand_composites=True)
    if original_result is None:
      return None, None

    result = nest.map_structure(
        self._BuildCondTensor, original_result, expand_composites=True)
    if not isinstance(result, (list, _basetuple)):
      result = [result]
    return original_result, result

  def IsCondContext(self):
    return True


def _UnpackIfSingleton(res):
  if isinstance(res, (list, _basetuple)) and len(res) == 1:
    return res[0]
  else:
    return res


def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
  """Special cases for `cond` when executing eagerly."""
  pred = ops.convert_to_tensor(pred)
  pred_constant_value = tensor_util.constant_value(pred)
  if pred_constant_value is None:
    # Eager tensors from a parallel device may not have a constant
    # value. Running the cond op itself would work, but we don't have logic to
    # build cond ops without wrapping in a function first.
    if (not isinstance(true_fn, def_function.Function)
        or not isinstance(false_fn, def_function.Function)):
      raise TypeError("When running tf.cond on a parallel device, `true_fn` "
                      "and `false_fn` must be decorated with `tf.function`.")

    @def_function.function
    def _parallel_device_cond_wrapper():
      return cond_v2.cond_v2(pred, true_fn, false_fn, name)

    functions_run_eagerly = def_function.functions_run_eagerly()
    if functions_run_eagerly:
      # We need to use tf.function to deal with variable creation inside the
      # cond, and skipping it because of run_functions_eagerly would just
      # crash immediately.
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.run_functions_eagerly. Parallelized tf.cond requires "
          "tf.function to work. This primitive will override the disable.")
      def_function.run_functions_eagerly(False)
    try:
      return _parallel_device_cond_wrapper()
    finally:
      if functions_run_eagerly is not None:
        def_function.run_functions_eagerly(functions_run_eagerly)
  else:
    # For conditions which are eager tensors with a constant value (most of
    # them), we only call the relevant branch function and execute it eagerly.
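    # For instance (hypothetical values): tf.cond(tf.constant(True),
    # lambda: x + 1, lambda: x - 1) run eagerly simply calls the first lambda
    # below; the untaken branch is never traced or executed.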
    with ops.name_scope(name, "cond", [pred]):
      if pred_constant_value:
        result = true_fn()
      else:
        result = false_fn()
      if not strict:
        result = _UnpackIfSingleton(result)
      return result


# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
@tf_export(v1=["cond"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(
    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
    "fn1", "fn2")
def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
  This behavior is disabled by passing `strict=True`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  # We needed to make true_fn/false_fn keyword arguments for
  # backwards-compatibility. This check exists so that we can convert back to
  # having them be positional arguments.
  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
  # `fn1` and `fn2` are deleted.
  if fn1 is not None:
    if true_fn is not None:
      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
    true_fn = fn1
  elif true_fn is None:
    raise TypeError("cond(): true_fn argument required")
  if fn2 is not None:
    if false_fn is not None:
      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
    false_fn = fn2
  elif false_fn is None:
    raise TypeError("cond(): false_fn argument required")

  if not callable(true_fn):
    raise TypeError("true_fn must be callable.")
  if not callable(false_fn):
    raise TypeError("false_fn must be callable.")

  if context.executing_eagerly():
    return _eager_cond_implementation(pred, true_fn, false_fn, strict, name)

  # Always enable control flow v2 if building a function, regardless of toggle.
  if util.EnableControlFlowV2(ops.get_default_graph()):
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)

  with ops.name_scope(name, "cond", [pred]):
    # Add the Switch to the graph.
    if isinstance(pred, bool):
      raise TypeError("pred must not be a Python bool")
    p_2, p_1 = switch(pred, pred)
    pivot_1 = array_ops.identity(p_1, name="switch_t")
    pivot_2 = array_ops.identity(p_2, name="switch_f")
    pred = array_ops.identity(pred, name="pred_id")
    # Disable the fetching of tensors that are only on one branch of cond.
    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.
    context_t = CondContext(pred, pivot_1, branch=1)
    try:
      context_t.Enter()
      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
      if orig_res_t is None:
        raise ValueError("true_fn must have a return value.")
      context_t.ExitResult(res_t)
    finally:
      context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = CondContext(pred, pivot_2, branch=0)
    try:
      context_f.Enter()
      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
      if orig_res_f is None:
        raise ValueError("false_fn must have a return value.")
      context_f.ExitResult(res_f)
    finally:
      context_f.Exit()

    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
    except (TypeError, ValueError):
      nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
      nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
      try:
        nest.assert_same_structure(orig_res_t, orig_res_f,
                                   expand_composites=True)
      except TypeError as e:
        raise TypeError(
            "Incompatible return types of true_fn and false_fn: {}".format(e))
      except ValueError as e:
        raise ValueError(
            "Incompatible return values of true_fn and false_fn: {}".format(e))

    # Add the final merge to the graph.
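    # For each pair of per-branch outputs, a Merge op below forwards whichever
    # branch actually ran; the pairs are ordered (false, true) to match the
    # branch numbering used for the Switch outputs above.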
    if not res_t:
      raise ValueError("true_fn and false_fn must return at least one result.")

    res_t_flat = nest.flatten(res_t, expand_composites=True)
    res_f_flat = nest.flatten(res_f, expand_composites=True)

    for (x, y) in zip(res_t_flat, res_f_flat):
      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      if x.dtype.base_dtype != y.dtype.base_dtype:
        raise ValueError(
            "Outputs of true_fn and false_fn must have the same type: "
            "%s, %s" % (x.dtype.name, y.dtype.name))

    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = _convert_flows_to_tensorarrays(
        nest.flatten(orig_res_t, expand_composites=True), merges)

    # Only add non-nested conds to the collection. Any nested control flow will
    # be encapsulated in the root context.
    assert context_t.outer_context == context_f.outer_context
    if context_t.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    merges = nest.pack_sequence_as(
        structure=orig_res_t, flat_sequence=merges, expand_composites=True)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges


def _cast_indexed_slice_indices(a, b):
  """Cast IndexedSlice.indices from int32 to int64 where necessary.

  If `a` and `b` are both IndexedSlices, and their indices have different
  dtypes, then cast both their dtypes to `int64` (modifies `a` and `b`
  in-place). Otherwise, does nothing.

  Args:
    a: A value, which may be an IndexedSlices.
    b: A value, which may be an IndexedSlices.
  """
  if (isinstance(a, ops.IndexedSlices) and isinstance(b, ops.IndexedSlices)
      and a.indices.dtype != b.indices.dtype):
    # pylint: disable=protected-access
    a._indices = math_ops.cast(a.indices, dtypes.int64)
    b._indices = math_ops.cast(b.indices, dtypes.int64)


# pylint: enable=g-doc-args
# pylint: enable=redefined-outer-name


@tf_export("cond", v1=[])
@dispatch.add_dispatch_support
def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.

  Note: It is illegal to "directly" use tensors created inside a cond branch
  outside it, e.g. by storing a reference to a branch tensor in the python
  state. If you need to use a tensor created in a branch function you should
  return it as an output of the branch function and use the output from
  `tf.cond` instead.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)


def _resource_safe_shape(t):
  """Returns the shape of t or the variable it points to."""
  if t.dtype == dtypes.resource:
    while t.op.inputs:
      t = t.op.inputs[0]
    return tensor_shape.TensorShape(t.op.get_attr("shape"))
  return array_ops.shape_internal(t, optimize=False)


# TODO(yuanbyu): Consider having a unified notion of context for
# not only conditionals and loops but also control dependency and
# subgraphs.
class WhileContext(ControlFlowContext):
  """The context for the loop construct."""

  def __init__(self,
               maximum_iterations=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name="while_context",
               grad_state=None,
               context_def=None,
               import_scope=None):
    """Creates a `WhileContext`.

    Args:
      maximum_iterations: Optional upper bound on number of loop iterations.
      parallel_iterations: The number of iterations allowed to run in parallel.
      back_prop: Whether backprop is enabled for this while loop.
      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
      name: Optional name prefix for the returned tensors.
      grad_state: The gradient loop state.
      context_def: Optional `WhileContextDef` protocol buffer to initialize the
        `WhileContext` python object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
1478 """ 1479 if context_def: 1480 self._init_from_proto(context_def, import_scope=import_scope) 1481 else: 1482 ControlFlowContext.__init__(self) 1483 self._init_from_args(maximum_iterations, parallel_iterations, back_prop, 1484 swap_memory, name) 1485 # The gradient loop state. 1486 self._grad_state = grad_state 1487 1488 def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop, 1489 swap_memory, name): 1490 """Creates a new `WhileContext` from arguments. 1491 1492 Args: 1493 maximum_iterations: Optional upper bound on number of loop iterations. 1494 parallel_iterations: The number of iterations allowed to run in parallel. 1495 back_prop: Whether backprop is enabled for this while loop. 1496 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 1497 name: Optional name prefix for the returned tensors. 1498 1499 Raises: 1500 ValueError: If `parallel_iterations` has invalid value. 1501 """ 1502 if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0): 1503 raise ValueError("`parallel_iterations` must be a positive integer: " 1504 "%s" % parallel_iterations) 1505 self._name = ops.get_default_graph().unique_name(name) 1506 self._maximum_iterations = maximum_iterations 1507 self._parallel_iterations = parallel_iterations 1508 self._back_prop = back_prop 1509 self._swap_memory = swap_memory 1510 # We use this node to control constants created by the pred lambda. 1511 self._pivot_for_pred = None 1512 # We use this node to control constants created by the body lambda. 1513 self._pivot_for_body = None 1514 # The boolean tensor for loop termination condition. Used in code 1515 # generation for gradient computation 1516 self._pivot = None 1517 # The list of exit tensors for loop variables. 1518 self._loop_exits = [] 1519 # The list of enter tensors for loop variables. 1520 self._loop_enters = [] 1521 self._graph = ops.get_default_graph() 1522 1523 def _init_from_proto(self, context_def, import_scope=None): 1524 """Creates a new `WhileContext` from protocol buffer. 1525 1526 Args: 1527 context_def: `WhileContextDef` protocol buffer. 1528 import_scope: Optional `string`. Name scope to add. 1529 """ 1530 assert isinstance(context_def, control_flow_pb2.WhileContextDef) 1531 # Create from context_def. 1532 g = ops.get_default_graph() 1533 self._name = ops.prepend_name_scope(context_def.context_name, import_scope) 1534 if context_def.maximum_iterations_name: 1535 self._maximum_iterations = g.as_graph_element( 1536 ops.prepend_name_scope(context_def.maximum_iterations_name, 1537 import_scope)) 1538 else: 1539 self._maximum_iterations = None 1540 self._parallel_iterations = context_def.parallel_iterations 1541 self._back_prop = context_def.back_prop 1542 self._swap_memory = context_def.swap_memory 1543 self._pivot_for_pred = g.as_graph_element( 1544 ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope)) 1545 # We use this node to control constants created by the body lambda. 1546 self._pivot_for_body = g.as_graph_element( 1547 ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope)) 1548 # The boolean tensor for loop termination condition. Used in code 1549 # generation for gradient computation. 1550 self._pivot = g.as_graph_element( 1551 ops.prepend_name_scope(context_def.pivot_name, import_scope)) 1552 # The list of exit tensors for loop variables. 
1553 self._loop_exits = [ 1554 g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) 1555 for exit_name in context_def.loop_exit_names 1556 ] 1557 # The list of enter tensors for loop variables. 1558 self._loop_enters = [ 1559 g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) 1560 for enter_name in context_def.loop_enter_names 1561 ] 1562 super(WhileContext, self).__init__( 1563 values_def=context_def.values_def, import_scope=import_scope) 1564 1565 # import_scope causes self.name to be different from the original serialized 1566 # context's name. Rewrite "frame_name" attrs with the new name. 1567 if import_scope: 1568 for tensor_name in self._values: 1569 op = g.as_graph_element(tensor_name).op 1570 if util.IsLoopEnter(op): 1571 # pylint: disable=protected-access 1572 op._set_attr("frame_name", 1573 attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) 1574 # pylint: enable=protected-access 1575 self._graph = ops.get_default_graph() 1576 1577 @property 1578 def maximum_iterations(self): 1579 """The maximum number of iterations that will be executed.""" 1580 return self._maximum_iterations 1581 1582 @property 1583 def parallel_iterations(self): 1584 """The number of iterations allowed to run in parallel.""" 1585 return self._parallel_iterations 1586 1587 @property 1588 def back_prop(self): 1589 """True iff backprop is enabled for this while loop.""" 1590 return self._back_prop 1591 1592 @property 1593 def swap_memory(self): 1594 """True iff GPU-CPU memory swap is enabled for this while loop.""" 1595 return self._swap_memory 1596 1597 @property 1598 def pivot(self): 1599 """The boolean tensor representing the loop termination condition.""" 1600 return self._pivot 1601 1602 @property 1603 def loop_enters(self): 1604 """The list of enter tensors for loop variables.""" 1605 return self._loop_enters 1606 1607 @property 1608 def loop_exits(self): 1609 """The list of exit tensors for loop variables.""" 1610 return self._loop_exits 1611 1612 @property 1613 def grad_state(self): 1614 """The gradient loop state.""" 1615 return self._grad_state 1616 1617 def to_proto(self, export_scope=None): 1618 """Converts a `WhileContext` to a `WhileContextDef` protocol buffer. 1619 1620 Args: 1621 export_scope: Optional `string`. Name scope to remove. 1622 1623 Returns: 1624 A `WhileContextDef` protocol buffer. 
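    A hedged sketch of the serialization round trip with `from_proto`
    (illustrative only; `while_ctxt` stands for an existing `WhileContext`
    registered in the default graph):

    ```python
    context_def = while_ctxt.to_proto()
    # ... e.g. carried as part of a serialized graph ...
    restored_ctxt = WhileContext.from_proto(context_def)
    ```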
1625 """ 1626 if (export_scope is None or self.name.startswith(export_scope)): 1627 context_def = control_flow_pb2.WhileContextDef() 1628 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 1629 context_def.parallel_iterations = self._parallel_iterations 1630 if self._maximum_iterations is not None: 1631 context_def.maximum_iterations_name = ops.strip_name_scope( 1632 self._maximum_iterations.name, export_scope) 1633 context_def.back_prop = self._back_prop 1634 context_def.swap_memory = self._swap_memory 1635 context_def.pivot_for_pred_name = ops.strip_name_scope( 1636 self._pivot_for_pred.name, export_scope) 1637 context_def.pivot_for_body_name = ops.strip_name_scope( 1638 self._pivot_for_body.name, export_scope) 1639 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 1640 export_scope) 1641 context_def.loop_exit_names.extend([ 1642 ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits 1643 ]) 1644 context_def.loop_enter_names.extend([ 1645 ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters 1646 ]) 1647 context_def.values_def.MergeFrom( 1648 super(WhileContext, self)._to_values_def(export_scope=export_scope)) 1649 for nested in self._nested_contexts: 1650 nested_def = context_def.nested_contexts.add() 1651 nested.to_control_flow_context_def(nested_def) 1652 1653 return context_def 1654 else: 1655 return None 1656 1657 def to_control_flow_context_def(self, context_def, export_scope=None): 1658 context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 1659 1660 @staticmethod 1661 def from_proto(context_def, import_scope=None): 1662 """Returns a `WhileContext` object created from `context_def`. 1663 1664 Args: 1665 context_def: A `WhileContextDef` protocol buffer. 1666 import_scope: Optional `string`. Name scope to add. 1667 1668 Returns: 1669 A `WhileContext` Python object. 1670 """ 1671 ret = WhileContext(context_def=context_def, import_scope=import_scope) 1672 ret.Enter() 1673 for nested_def in context_def.nested_contexts: 1674 from_control_flow_context_def(nested_def, import_scope=import_scope) 1675 ret.Exit() 1676 return ret 1677 1678 def GetWhileContext(self): 1679 return self 1680 1681 def GetControlPivot(self): 1682 if self._pivot_for_body is not None: 1683 return self._pivot_for_body 1684 return self._pivot_for_pred 1685 1686 def AddValue(self, val): 1687 """Add `val` to the current context and its outer context recursively.""" 1688 result = val 1689 new_value = val.name not in self._values 1690 # Don't treat ops in this context as new values. Usually all known values 1691 # are in self._values, except when we're importing a while loop inside this 1692 # WhileContext. Since there's a cycle in this case, `val` may be part of the 1693 # imported while loop but not yet processed by this context and added to 1694 # self._values in _AddOpInternal. We only want to process external input 1695 # tensors to the while loop here. 1696 new_value &= val.op._control_flow_context is not self # pylint: disable=protected-access 1697 if new_value: 1698 self._values.add(val.name) 1699 1700 # If we are in a grad context and val is from its forward context, 1701 # use GetRealValue(), which adds the logic to save the history of 1702 # val in forward. 
1703 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 1704 if grad_ctxt: 1705 grad_ctxt = grad_ctxt.GetWhileContext() 1706 if grad_ctxt.grad_state: 1707 forward_ctxt = util.GetWhileContext(val.op) 1708 if util.IsLoopExit(val.op): 1709 forward_ctxt = forward_ctxt.outer_context 1710 if forward_ctxt: 1711 forward_ctxt = forward_ctxt.GetWhileContext() 1712 if forward_ctxt == grad_ctxt.grad_state.forward_context: 1713 real_val = grad_ctxt.grad_state.GetRealValue(val) 1714 self._external_values[val.name] = real_val 1715 return real_val 1716 1717 if self._outer_context is not None: 1718 result = self._outer_context.AddValue(val) 1719 # Create an Enter to make `result` known to this loop context. 1720 with ops.control_dependencies(None): 1721 enter = _Enter( 1722 result, 1723 self._name, 1724 is_constant=True, 1725 parallel_iterations=self._parallel_iterations) 1726 enter.graph.prevent_feeding(enter) 1727 if self._outer_context: 1728 self._outer_context.AddInnerOp(enter.op) 1729 # Fix the control inputs and control flow context of these enter ops. 1730 self._FixControlInputsAndContext([enter]) 1731 1732 # Add `enter` in this context. 1733 self._values.add(enter.name) 1734 self._external_values[val.name] = enter 1735 result = enter 1736 else: 1737 actual_val = self._external_values.get(val.name) 1738 if actual_val is not None: 1739 result = actual_val 1740 return result 1741 1742 def AddOp(self, op): 1743 """Add `op` to the current context.""" 1744 # For a reduction op, if op is in a grad context and its input is from 1745 # its forward context, moving op to the forward context means we would 1746 # store the tensor after the reduction as opposed to the tensor before 1747 # reduction, and therefore could significantly reduce memory consumption. 1748 # For now, we do this only for a few ops. 1749 # 1750 # If in XLA context, do not move constant ops to forward pass as pushing to 1751 # and popping from a stack removes the constant property of an op and breaks 1752 # XLA compilation, which requires certain inputs to be constant for certain 1753 # ops. 1754 if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}: 1755 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 1756 if grad_ctxt: 1757 grad_ctxt = grad_ctxt.GetWhileContext() 1758 if grad_ctxt.grad_state: 1759 op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op) 1760 if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context: 1761 op_input_ctxt = op.inputs[0].op._get_control_flow_context() 1762 op._set_control_flow_context(op_input_ctxt) 1763 op_input_ctxt._AddOpInternal(op) 1764 return 1765 self._AddOpInternal(op) 1766 1767 def _AddOpInternal(self, op): 1768 """Add `op` to the current context. 1769 1770 We move any external control dependencies of the op to the loop pivot, to 1771 ensure they get executed. 1772 """ 1773 # This is needed to prevent frame mismatch errors where there are Const 1774 # nodes inside tf.function in v1 while_loop and inlining is turned on. 1775 if op.type in ["PartitionedCall", "StatefulPartitionedCall"]: 1776 op._add_control_input(self.GetControlPivot().op) # pylint: disable=protected-access 1777 if not op.inputs: 1778 # Remove any external control dependency on this op 1779 control_inputs, external_inputs = self._RemoveExternalControlEdges(op) 1780 # Add a control edge from the control pivot to this op. 
1781 if not control_inputs: 1782 # pylint: disable=protected-access 1783 op._add_control_input(self.GetControlPivot().op) 1784 # pylint: enable=protected-access 1785 for x in op.outputs: 1786 self._values.add(x.name) 1787 else: 1788 for index in range(len(op.inputs)): 1789 x = op.inputs[index] 1790 real_x = self.AddValue(x) 1791 if real_x != x: 1792 op._update_input(index, real_x) # pylint: disable=protected-access 1793 # Remove any external control dependency on this op. 1794 _, external_inputs = self._RemoveExternalControlEdges(op) 1795 # Add a control dependency to prevent loop invariants from 1796 # enabling ops that should not be executed. 1797 self._MaybeAddControlDependency(op) 1798 for x in op.outputs: 1799 self._values.add(x.name) 1800 if external_inputs: 1801 # Use an identity to pull control inputs as data inputs. Note that we 1802 # ignore ops which don't have outputs. TODO(apassos): fix that 1803 with ops.control_dependencies(None): 1804 self.Enter() 1805 external_inputs = [ 1806 array_ops.identity(x.outputs[0]).op 1807 for x in external_inputs 1808 if x.outputs 1809 ] 1810 self.Exit() 1811 op._add_control_inputs(external_inputs) # pylint: disable=protected-access 1812 if self._outer_context or not util.IsLoopExit(op): 1813 op.graph.prevent_fetching(op) 1814 for x in op.outputs: 1815 op.graph.prevent_feeding(x) 1816 1817 if self._outer_context: 1818 self._outer_context.AddInnerOp(op) 1819 1820 def _MaybeAddControlDependency(self, op): 1821 """Add a control input to the op if it only depends on loop invariants.""" 1822 1823 def _IsOpFree(op): 1824 """Determines if `op` needs a control dependency.""" 1825 if op.control_inputs: 1826 return False 1827 # pylint: disable=protected-access 1828 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 1829 return True 1830 # pylint: enable=protected-access 1831 for x in op.inputs: 1832 if not util.IsLoopConstantEnter(x.op): 1833 return False 1834 return True 1835 1836 if _IsOpFree(op): 1837 # pylint: disable=protected-access 1838 op._add_control_input(self.GetControlPivot().op) 1839 # pylint: enable=protected-access 1840 1841 def AddForwardLoopCounter(self, outer_grad_state): 1842 """Adds a loop that counts the number of iterations. 1843 1844 This is added to the forward loop at the time when we start to 1845 create the loop for backprop gradient computation. Called in 1846 the outer context of this forward context. 1847 1848 The pseudocode is: 1849 `n = 0; while (_pivot) { n++; }` 1850 1851 Note that a control dependency is added to `n` to ensure the correct 1852 execution order of stack push ops. 1853 1854 Args: 1855 outer_grad_state: The outer grad state. None if not nested. 1856 1857 Returns: 1858 The number of iterations taken by the forward loop and the loop index. 1859 """ 1860 n = constant_op.constant(0, name="f_count") 1861 if outer_grad_state is not None: 1862 # Force the stack pushes of i-th execution of an inner loop to be ordered 1863 # before the pushes of (i+1)-th execution of the same inner loop. 
1864 outer_add_op = outer_grad_state.forward_index.op.inputs[0].op 1865 n.op._add_control_input(outer_add_op) # pylint: disable=protected-access 1866 1867 self.Enter() 1868 self.AddName(n.name) 1869 enter_n = _Enter( 1870 n, 1871 self._name, 1872 is_constant=False, 1873 parallel_iterations=self._parallel_iterations, 1874 name="f_count") 1875 self.loop_enters.append(enter_n) 1876 1877 merge_n = merge([enter_n, enter_n])[0] 1878 switch_n = switch(merge_n, self._pivot) 1879 1880 index = math_ops.add(switch_n[1], 1) 1881 next_n = _NextIteration(index) 1882 merge_n.op._update_input(1, next_n) 1883 1884 total_iterations = exit(switch_n[0], name="f_count") 1885 self.loop_exits.append(total_iterations) 1886 self.ExitResult([total_iterations]) 1887 self.Exit() 1888 return total_iterations, next_n 1889 1890 def AddBackpropLoopCounter(self, count, outer_grad_state): 1891 """Add the backprop loop that controls the iterations. 1892 1893 This is added to the backprop loop. It is used to control the loop 1894 termination of the backprop loop. Called in the outer context of 1895 this grad context. 1896 1897 The pseudocode is: 1898 `n = count; while (n >= 1) { n--; }` 1899 1900 Note that a control dependency is added to `final_zero` to ensure the 1901 correct execution order of stack pop ops. 1902 1903 Args: 1904 count: The number of iterations for backprop. 1905 outer_grad_state: The outer grad state. None if not nested. 1906 1907 Returns: 1908 The loop index. 1909 """ 1910 in_separate_functions = count.graph is not ops.get_default_graph() 1911 if in_separate_functions: 1912 # Brings the count into this graph 1913 count = array_ops.identity(count) 1914 else: 1915 # TODO(apassos) XLA expects this constant to be created outside the loop, 1916 # so doing that for now. 1917 one = constant_op.constant(1, name="b_count") 1918 1919 self.Enter() 1920 self.AddName(count.name) 1921 enter_count = _Enter( 1922 count, 1923 self._name, 1924 is_constant=False, 1925 parallel_iterations=self._parallel_iterations, 1926 name="b_count") 1927 self.loop_enters.append(enter_count) 1928 1929 merge_count = merge([enter_count, enter_count])[0] 1930 self._pivot_for_pred = merge_count 1931 1932 if in_separate_functions: 1933 one = constant_op.constant(1, name="b_count") 1934 pred = math_ops.greater_equal(merge_count, one) 1935 self._pivot = loop_cond(pred, name="b_count") 1936 switch_count = switch(merge_count, self._pivot) 1937 1938 index = math_ops.subtract(switch_count[1], one) 1939 self._pivot_for_body = index 1940 next_count = _NextIteration(index) 1941 merge_count.op._update_input(1, next_count) 1942 1943 final_zero = exit(switch_count[0], name="b_count") 1944 self.loop_exits.append(final_zero) 1945 if outer_grad_state is not None: 1946 # Force the stack pops of i-th execution of an inner loop to be ordered 1947 # before the pops of (i+1)-th execution of the same inner loop. 1948 # pylint: disable=protected-access 1949 outer_grad_state.grad_sync._add_control_input(final_zero.op) 1950 # pylint: enable=protected-access 1951 1952 self.ExitResult([final_zero]) 1953 self.Exit() 1954 return next_count 1955 1956 def AddBackpropAccumulator(self, op, grad): 1957 """Add an accumulation loop for every loop invariant. 1958 1959 This is added to the backprop loop. It is used to accumulate partial 1960 gradients within each loop iteration. Called when in the gradient while 1961 context. 
1962 1963 The pseudocode is: 1964 ``` 1965 acc = 0.0; 1966 while (_pivot) { 1967 acc += grad; 1968 } 1969 ``` 1970 1971 Args: 1972 op: The Enter op for a loop invariant. 1973 grad: The partial gradient of an iteration for a loop invariant. 1974 1975 Returns: 1976 The gradient for a loop invariant. 1977 """ 1978 self.Exit() 1979 # Create a zeros tensor with the right shape for acc. If we don't 1980 # know the full shape statically, we will have to get the shape 1981 # dynamically from the forward inference. Getting the shape right 1982 # for the zeros is only needed for the base case when the loop exits 1983 # without running any iterations. 1984 shape = grad.get_shape() 1985 if shape.is_fully_defined(): 1986 if self.outer_context: 1987 self.outer_context.Enter() 1988 acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc") 1989 if self.outer_context: 1990 self.outer_context.Exit() 1991 else: 1992 value = op.inputs[0] 1993 if (isinstance(self.outer_context, WhileContext) and 1994 self.outer_context.grad_state is not None): 1995 # We are in a nested while loop. 1996 forward_ctxt = self.grad_state.forward_context 1997 forward_ctxt.outer_context.Enter() 1998 zeros_shape = array_ops.shape_internal(value, optimize=False) 1999 forward_ctxt.outer_context.Exit() 2000 outer_grad_state = self.grad_state.outer_grad_state 2001 history_zeros_shape = outer_grad_state.AddForwardAccumulator( 2002 zeros_shape) 2003 self.outer_context.Enter() 2004 real_shape = outer_grad_state.AddBackpropAccumulatedValue( 2005 history_zeros_shape, zeros_shape) 2006 acc = array_ops.zeros(real_shape, grad.dtype) 2007 self.outer_context.Exit() 2008 else: 2009 if self.outer_context: 2010 self.outer_context.Enter() 2011 zeros_shape = array_ops.shape_internal(value, optimize=False) 2012 acc = array_ops.zeros(zeros_shape, grad.dtype) 2013 if self.outer_context: 2014 self.outer_context.Exit() 2015 2016 self.Enter() 2017 self.AddName(acc.name) 2018 enter_acc = _Enter( 2019 acc, 2020 self._name, 2021 is_constant=False, 2022 parallel_iterations=self._parallel_iterations, 2023 name="b_acc") 2024 self.loop_enters.append(enter_acc) 2025 2026 merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0] 2027 switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot) 2028 2029 add_acc = math_ops.add(switch_acc_true, grad) 2030 next_acc = _NextIteration(add_acc) 2031 merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access 2032 2033 result_acc = exit(switch_acc_false, name="b_acc") 2034 self.loop_exits.append(result_acc) 2035 self.ExitResult([result_acc]) 2036 return result_acc 2037 2038 def AddBackpropIndexedSlicesAccumulator(self, op, grad): 2039 """This is used for accumulating gradients that are IndexedSlices. 2040 2041 This is essentially the equivalent of AddBackpropAccumulator but optimized 2042 for things like updating embeddings from within a while loop. 2043 2044 Args: 2045 op: The Enter op for a loop invariant. 2046 grad: The partial gradients represented as an IndexedSlices. 2047 2048 Returns: 2049 The accumulated IndexedSlices gradient of the loop invariant. 
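    The pseudocode is roughly (mirroring `AddBackpropAccumulator`, but growing
    the indices/values accumulators instead of summing; `shape_acc` is only
    kept when `grad.dense_shape` is known):

    ```
    indices_acc = [0];
    values_acc = zeros([1] + values_shape[1:]);
    while (_pivot) {
      indices_acc = concat([indices_acc, grad.indices]);
      values_acc = concat([values_acc, grad.values]);
      shape_acc = maximum(shape_acc, grad.dense_shape);
    }
    ```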
2050 """ 2051 values = grad.values 2052 indices = grad.indices 2053 dense_shape = grad.dense_shape 2054 2055 self.Exit() 2056 if self.outer_context: 2057 self.outer_context.Enter() 2058 if values.get_shape().is_fully_defined(): 2059 values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] + 2060 values.get_shape().dims[1:]) 2061 if self.outer_context: 2062 self.outer_context.Enter() 2063 values_acc = constant_op.constant( 2064 0, values.dtype, shape=values_shape, name="b_acc") 2065 if self.outer_context: 2066 self.outer_context.Exit() 2067 else: 2068 values_shape = _resource_safe_shape(op.inputs[0])[1:] 2069 values_shape = array_ops.concat([[1], values_shape], 0) 2070 values_acc = array_ops.zeros(values_shape, dtype=values.dtype) 2071 indices_acc = constant_op.constant([0], indices.dtype) 2072 shape_acc = None 2073 if dense_shape is not None: 2074 if dense_shape.get_shape().is_fully_defined(): 2075 if self.outer_context: 2076 self.outer_context.Enter() 2077 shape_acc = constant_op.constant( 2078 0, dense_shape.dtype, shape=dense_shape.get_shape()) 2079 if self.outer_context: 2080 self.outer_context.Exit() 2081 else: 2082 shape_acc = array_ops.zeros_like( 2083 array_ops.shape_internal( 2084 op.inputs[0], optimize=False, out_type=dense_shape.dtype), 2085 optimize=False) 2086 2087 if self.outer_context: 2088 self.outer_context.Exit() 2089 2090 self.Enter() 2091 self.AddName(values_acc.name) 2092 self.AddName(indices_acc.name) 2093 init_acc = [indices_acc, values_acc] 2094 if shape_acc is not None: 2095 self.AddName(shape_acc.name) 2096 init_acc.append(shape_acc) 2097 2098 # Set use_input_shape=False since the accumulator tensors will grow in 2099 # size. If use_input_shape=True, the _update_input call below will result in 2100 # incompatible shapes. 2101 enter_acc = [ 2102 _Enter( 2103 x, 2104 self._name, 2105 is_constant=False, 2106 parallel_iterations=self._parallel_iterations, 2107 use_input_shape=False, 2108 name="b_acc") for x in init_acc 2109 ] 2110 # Manually set appropriate partial shapes. 2111 enter_acc[0].set_shape([None]) 2112 if values_acc.shape.dims is not None: 2113 enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:]) 2114 self.loop_enters.extend(enter_acc) 2115 2116 merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc] 2117 switch_acc = [switch(x, self._pivot) for x in merge_acc] 2118 2119 # The actual accumulation. 
2120 acc_indexed_slices = [ 2121 array_ops.concat([xa[1], xv], 0) 2122 for xa, xv in zip(switch_acc[:2], [indices, values]) 2123 ] 2124 if shape_acc is not None: 2125 # For the shape we just keep the maximum 2126 acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1])) 2127 2128 next_acc = [_NextIteration(x) for x in acc_indexed_slices] 2129 for xm, xn in zip(merge_acc, next_acc): 2130 xm.op._update_input(1, xn) # pylint: disable=protected-access 2131 2132 exit_acc = [exit(x[0], name="b_acc") for x in switch_acc] 2133 self.loop_exits.extend(exit_acc) 2134 2135 self.ExitResult(exit_acc) 2136 return ops.IndexedSlices( 2137 indices=exit_acc[0], 2138 values=exit_acc[1], 2139 dense_shape=exit_acc[2] if shape_acc is not None else None) 2140 2141 def _InitializeValues(self, values): 2142 """Makes the values known to this context.""" 2143 self._values = set() 2144 for x in values: 2145 if isinstance(x, ops.Tensor): 2146 self._values.add(x.name) 2147 else: 2148 raise TypeError("Type %s not supported" % type(x)) 2149 2150 def _BuildLoop(self, pred, body, original_loop_vars, loop_vars, 2151 shape_invariants): 2152 """Core: Add the loop termination condition and body to the graph.""" 2153 flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True) 2154 2155 # Let the context know the loop variables so the loop variables 2156 # would be added in the outer contexts properly. 2157 self._InitializeValues(loop_vars) 2158 real_vars = loop_vars 2159 if self._outer_context: 2160 real_vars = [self._outer_context.AddValue(x) for x in loop_vars] 2161 with ops.control_dependencies(None): 2162 enter_vars = [ 2163 _Enter( 2164 x, 2165 self._name, 2166 is_constant=False, 2167 parallel_iterations=self._parallel_iterations, 2168 use_input_shape=(shape_invariants is None)) for x in real_vars 2169 ] 2170 for x in enter_vars: 2171 x.graph.prevent_feeding(x) 2172 if self._outer_context: 2173 self._outer_context.AddInnerOp(x.op) 2174 2175 # Finds the closest enclosing non-None control pivot. 2176 outer_context = self._outer_context 2177 control_pivot = None 2178 while outer_context is not None and control_pivot is None: 2179 control_pivot = outer_context.GetControlPivot() 2180 # pylint: disable=protected-access 2181 outer_context = outer_context._outer_context 2182 # pylint: enable=protected-access 2183 2184 if control_pivot is not None: 2185 for var in enter_vars: 2186 if util.IsLoopConstantEnter(var.op.inputs[0].op): 2187 # pylint: disable=protected-access 2188 var.op._add_control_input(control_pivot.op) 2189 # pylint: enable=protected-access 2190 _SetShapeInvariants(real_vars, enter_vars, shape_invariants) 2191 2192 # Fix the control inputs and control flow context of these enter ops. 2193 self._FixControlInputsAndContext(enter_vars) 2194 self._InitializeValues(enter_vars) 2195 self._loop_enters = enter_vars 2196 2197 merge_vars = [merge([x, x])[0] for x in enter_vars] 2198 self._pivot_for_pred = merge_vars[0] 2199 2200 # Build the graph for pred. 2201 merge_vars_with_tensor_arrays = ( 2202 _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars)) 2203 packed_vars = nest.pack_sequence_as( 2204 structure=original_loop_vars, 2205 flat_sequence=merge_vars_with_tensor_arrays, 2206 expand_composites=True) 2207 c = ops.convert_to_tensor(pred(*packed_vars)) 2208 self._pivot = loop_cond(c, name="LoopCond") 2209 switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars] 2210 2211 # Build the graph for body. 
2212 vars_for_body = [_Identity(x[1]) for x in switch_vars] 2213 self._pivot_for_body = vars_for_body[0] 2214 # Convert TensorArray flow variables inside the context back into 2215 # their associated TensorArrays for calling the body. 2216 vars_for_body_with_tensor_arrays = ( 2217 _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body)) 2218 packed_vars_for_body = nest.pack_sequence_as( 2219 structure=original_loop_vars, 2220 flat_sequence=vars_for_body_with_tensor_arrays, 2221 expand_composites=True) 2222 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2223 body_result = body(*packed_vars_for_body) 2224 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2225 if not nest.is_sequence_or_composite(body_result): 2226 body_result = [body_result] 2227 if len(post_summaries) > len(pre_summaries): 2228 new_summaries = post_summaries[len(pre_summaries):] 2229 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2230 summary_ref[:] = pre_summaries 2231 with ops.control_dependencies(new_summaries): 2232 2233 def map_fn(x): 2234 # TODO(apassos) figure out how to trigger with tensor arrays as well 2235 if isinstance(x, tensor_array_ops.TensorArray): 2236 return x 2237 return array_ops.identity(x) 2238 2239 body_result = nest.map_structure( 2240 map_fn, body_result, expand_composites=True) 2241 2242 # Compare the structure types of input and output of body. 2243 # For backwards compatibility, the first layer is forced to a list 2244 # during this comparison, because inputs are typically lists and 2245 # outputs of the body are typically tuples. 2246 nest.assert_same_structure( 2247 list(packed_vars_for_body), list(body_result), expand_composites=True) 2248 2249 # Store body_result to keep track of TensorArrays returned by body 2250 original_body_result = body_result 2251 # Convert TensorArrays returned by body into their flow variables 2252 result = nest.map_structure( 2253 _convert_tensorarray_to_flow, 2254 nest.flatten(body_result, expand_composites=True), 2255 expand_composites=True) 2256 result = ops.convert_n_to_tensor_or_composite(result) 2257 2258 # Add NextIteration and the back edges to complete the loop. 2259 if len(merge_vars) != len(result): 2260 raise ValueError("Number of inputs and outputs of body must match " 2261 "loop_vars: %d, %d" % (len(merge_vars), len(result))) 2262 next_vars = [] 2263 for m, v in zip(merge_vars, result): 2264 next_vars.append(_AddNextAndBackEdge(m, v)) 2265 2266 # Add the exit ops. 2267 exit_vars = [exit(x[0]) for x in switch_vars] 2268 self._loop_exits = exit_vars 2269 2270 # Exit the loop. 
2271 self.ExitResult(exit_vars) 2272 2273 return original_body_result, exit_vars 2274 2275 def BuildLoop(self, pred, body, loop_vars, shape_invariants, 2276 return_same_structure): 2277 """Add the loop termination condition and body to the graph.""" 2278 2279 # Keep original_loop_vars to identify which are TensorArrays 2280 original_loop_vars = loop_vars 2281 # Convert TensorArrays to their flow variables 2282 loop_vars = nest.map_structure( 2283 _convert_tensorarray_to_flow, 2284 nest.flatten(loop_vars, expand_composites=False), 2285 expand_composites=True) 2286 loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars) 2287 if shape_invariants is None: 2288 shape_invariants = nest.map_structure( 2289 _get_shape_invariant, loop_vars, expand_composites=False) 2290 loop_vars = nest.flatten(loop_vars, expand_composites=True) 2291 try: 2292 self.Enter() 2293 # _BuildLoop calls _update_input in several places. _mutation_lock() 2294 # ensures a Session.run call cannot occur between creating and mutating 2295 # new ops. 2296 with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access 2297 original_body_result, exit_vars = self._BuildLoop( 2298 pred, body, original_loop_vars, loop_vars, shape_invariants) 2299 finally: 2300 self.Exit() 2301 2302 flat_result = nest.flatten(original_body_result, expand_composites=True) 2303 # Convert TensorArray flow variables outside the context back into 2304 # their associated TensorArrays for returning to caller. 2305 exit_vars_with_tensor_arrays = ( 2306 _convert_flows_to_tensorarrays(flat_result, exit_vars)) 2307 packed_exit_vars = nest.pack_sequence_as( 2308 structure=original_body_result, 2309 flat_sequence=exit_vars_with_tensor_arrays, 2310 expand_composites=True) 2311 2312 if return_same_structure: 2313 return packed_exit_vars 2314 else: 2315 return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars 2316 2317 def _FixControlInputsAndContext(self, enters): 2318 graph = ops.get_default_graph() 2319 # pylint: disable=protected-access 2320 for e in enters: 2321 if isinstance(e, ops.Tensor): 2322 xs = [e] 2323 else: 2324 raise TypeError("Type %s not supported" % type(e)) 2325 for x in xs: 2326 inp_op = x.op.inputs[0].op 2327 control_inputs = graph._control_dependencies_for_inputs([inp_op]) 2328 outer_control_inputs = [] 2329 for op in control_inputs: 2330 # We need to keep control inputs that are in any ancestor 2331 # ControlFlowContext, and within outer WhileContext. 2332 keep_as_control_input = True 2333 op_ctxt = util.GetOutputContext(op) 2334 outer_ctxt = self.outer_context 2335 outer_while_context = (None if outer_ctxt is None else 2336 outer_ctxt.GetWhileContext()) 2337 while outer_ctxt != op_ctxt: 2338 if outer_ctxt is None or outer_ctxt == outer_while_context: 2339 keep_as_control_input = False 2340 break 2341 outer_ctxt = outer_ctxt.outer_context 2342 if keep_as_control_input: 2343 outer_control_inputs.append(op) 2344 x.op._set_control_flow_context(self) 2345 x.op._add_control_inputs(outer_control_inputs) 2346 graph._record_op_seen_by_control_dependencies(x.op) 2347 # pylint: enable=protected-access 2348 2349 def IsWhileContext(self): 2350 return True 2351 2352 2353# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature". 2354# pylint: disable=redefined-outer-name 2355@tf_export("while_loop", v1=[]) 2356@deprecation.deprecated_arg_values( 2357 None, 2358 """back_prop=False is deprecated. Consider using tf.stop_gradient instead. 
2359Instead of: 2360results = tf.while_loop(c, b, vars, back_prop=False) 2361Use: 2362results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""", 2363 warn_once=True, 2364 back_prop=False) 2365def while_loop_v2(cond, 2366 body, 2367 loop_vars, 2368 shape_invariants=None, 2369 parallel_iterations=10, 2370 back_prop=True, 2371 swap_memory=False, 2372 maximum_iterations=None, 2373 name=None): 2374 """Repeat `body` while the condition `cond` is true. 2375 2376 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 2377 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 2378 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 2379 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 2380 `cond` and `body`. `cond` and `body` both take as many arguments as there are 2381 `loop_vars`. 2382 2383 In addition to regular Tensors or IndexedSlices, the body may accept and 2384 return TensorArray objects. The flows of the TensorArray objects will 2385 be appropriately forwarded between loops and during gradient calculations. 2386 2387 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 2388 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 2389 stitches together the graph fragments created during the `cond` and `body` 2390 calls with some additional graph nodes to create the graph flow that 2391 repeats `body` until `cond` returns false. 2392 2393 For correctness, `tf.while_loop()` strictly enforces shape invariants for 2394 the loop variables. A shape invariant is a (possibly partial) shape that 2395 is unchanged across the iterations of the loop. An error will be raised 2396 if the shape of a loop variable after an iteration is determined to be more 2397 general than or incompatible with its shape invariant. For example, a shape 2398 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 2399 compatible with [11, 17]. By default (if the argument `shape_invariants` is 2400 not specified), it is assumed that the initial shape of each tensor in 2401 `loop_vars` is the same in every iteration. The `shape_invariants` argument 2402 allows the caller to specify a less specific shape invariant for each loop 2403 variable, which is needed if the shape varies between iterations. The 2404 `tf.Tensor.set_shape` 2405 function may also be used in the `body` function to indicate that 2406 the output loop variable has a particular shape. The shape invariant for 2407 SparseTensor and IndexedSlices are treated specially as follows: 2408 2409 a) If a loop variable is a SparseTensor, the shape invariant must be 2410 TensorShape([r]) where r is the rank of the dense tensor represented 2411 by the sparse tensor. It means the shapes of the three tensors of the 2412 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 2413 is the shape of the SparseTensor.dense_shape property. It must be the shape of 2414 a vector. 2415 2416 b) If a loop variable is an IndexedSlices, the shape invariant must be 2417 a shape invariant of the values tensor of the IndexedSlices. It means 2418 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 2419 [shape.ndims]). 2420 2421 `while_loop` implements non-strict semantics, enabling multiple iterations 2422 to run in parallel. 
The maximum number of parallel iterations can be 2423 controlled by `parallel_iterations`, which gives users some control over 2424 memory consumption and execution order. For correct programs, `while_loop` 2425 should return the same result for any parallel_iterations > 0. 2426 2427 For training, TensorFlow stores the tensors that are produced in the 2428 forward inference and are needed in back propagation. These tensors are a 2429 main source of memory consumption and often cause OOM errors when training 2430 on GPUs. When the flag swap_memory is true, we swap out these tensors from 2431 GPU to CPU. This for example allows us to train RNN models with very long 2432 sequences and large batches. 2433 2434 Args: 2435 cond: A callable that represents the termination condition of the loop. 2436 body: A callable that represents the loop body. 2437 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 2438 `Tensor`, and `TensorArray` objects. 2439 shape_invariants: The shape invariants for the loop variables. 2440 parallel_iterations: The number of iterations allowed to run in parallel. It 2441 must be a positive integer. 2442 back_prop: (optional) Deprecated. False disables support for back 2443 propagation. Prefer using `tf.stop_gradient` instead. 2444 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2445 maximum_iterations: Optional maximum number of iterations of the while loop 2446 to run. If provided, the `cond` output is AND-ed with an additional 2447 condition ensuring the number of iterations executed is no greater than 2448 `maximum_iterations`. 2449 name: Optional name prefix for the returned tensors. 2450 2451 Returns: 2452 The output tensors for the loop variables after the loop. The return value 2453 has the same structure as `loop_vars`. 2454 2455 Raises: 2456 TypeError: if `cond` or `body` is not callable. 2457 ValueError: if `loop_vars` is empty. 2458 2459 Example: 2460 2461 ```python 2462 i = tf.constant(0) 2463 c = lambda i: tf.less(i, 10) 2464 b = lambda i: (tf.add(i, 1), ) 2465 r = tf.while_loop(c, b, [i]) 2466 ``` 2467 2468 Example with nesting and a namedtuple: 2469 2470 ```python 2471 import collections 2472 Pair = collections.namedtuple('Pair', 'j, k') 2473 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 2474 c = lambda i, p: i < 10 2475 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 2476 ijk_final = tf.while_loop(c, b, ijk_0) 2477 ``` 2478 2479 Example using shape_invariants: 2480 2481 ```python 2482 i0 = tf.constant(0) 2483 m0 = tf.ones([2, 2]) 2484 c = lambda i, m: i < 10 2485 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 2486 tf.while_loop( 2487 c, b, loop_vars=[i0, m0], 2488 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 2489 ``` 2490 2491 Example which demonstrates non-strict semantics: In the following 2492 example, the final value of the counter `i` does not depend on `x`. So 2493 the `while_loop` can increment the counter parallel to updates of `x`. 2494 However, because the loop counter at one loop iteration depends 2495 on the value at the previous iteration, the loop counter itself cannot 2496 be incremented in parallel. Hence if we just want the final value of the 2497 counter (which we print on the line `print(sess.run(i))`), then 2498 `x` will never be incremented, but the counter will be updated on a 2499 single thread. 
Conversely, if we want the value of the output (which we 2500 print on the line `print(sess.run(out).shape)`), then the counter may be 2501 incremented on its own thread, while `x` can be incremented in 2502 parallel on a separate thread. In the extreme case, it is conceivable 2503 that the thread incrementing the counter runs until completion before 2504 `x` is incremented even a single time. The only thing that can never 2505 happen is that the thread updating `x` can never get ahead of the 2506 counter thread because the thread incrementing `x` depends on the value 2507 of the counter. 2508 2509 ```python 2510 import tensorflow as tf 2511 2512 n = 10000 2513 x = tf.constant(list(range(n))) 2514 c = lambda i, x: i < n 2515 b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1, 2516 [i], "x:")) 2517 i, out = tf.while_loop(c, b, (0, x)) 2518 with tf.compat.v1.Session() as sess: 2519 print(sess.run(i)) # prints [0] ... [9999] 2520 2521 # The following line may increment the counter and x in parallel. 2522 # The counter thread may get ahead of the other thread, but not the 2523 # other way around. So you may see things like 2524 # [9996] x:[9987] 2525 # meaning that the counter thread is on iteration 9996, 2526 # while the other thread is on iteration 9987 2527 print(sess.run(out).shape) 2528 ``` 2529 2530 """ 2531 return while_loop( 2532 cond=cond, 2533 body=body, 2534 loop_vars=loop_vars, 2535 shape_invariants=shape_invariants, 2536 parallel_iterations=parallel_iterations, 2537 back_prop=back_prop, 2538 swap_memory=swap_memory, 2539 name=name, 2540 maximum_iterations=maximum_iterations, 2541 return_same_structure=True) 2542 2543 2544# pylint: disable=redefined-outer-name 2545@tf_export(v1=["while_loop"]) 2546def while_loop(cond, 2547 body, 2548 loop_vars, 2549 shape_invariants=None, 2550 parallel_iterations=10, 2551 back_prop=True, 2552 swap_memory=False, 2553 name=None, 2554 maximum_iterations=None, 2555 return_same_structure=False): 2556 """Repeat `body` while the condition `cond` is true. 2557 2558 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 2559 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 2560 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 2561 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 2562 `cond` and `body`. `cond` and `body` both take as many arguments as there are 2563 `loop_vars`. 2564 2565 In addition to regular Tensors or IndexedSlices, the body may accept and 2566 return TensorArray objects. The flows of the TensorArray objects will 2567 be appropriately forwarded between loops and during gradient calculations. 2568 2569 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 2570 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 2571 stitches together the graph fragments created during the `cond` and `body` 2572 calls with some additional graph nodes to create the graph flow that 2573 repeats `body` until `cond` returns false. 2574 2575 For correctness, `tf.while_loop()` strictly enforces shape invariants for 2576 the loop variables. A shape invariant is a (possibly partial) shape that 2577 is unchanged across the iterations of the loop. An error will be raised 2578 if the shape of a loop variable after an iteration is determined to be more 2579 general than or incompatible with its shape invariant. 
For example, a shape 2580 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 2581 compatible with [11, 17]. By default (if the argument `shape_invariants` is 2582 not specified), it is assumed that the initial shape of each tensor in 2583 `loop_vars` is the same in every iteration. The `shape_invariants` argument 2584 allows the caller to specify a less specific shape invariant for each loop 2585 variable, which is needed if the shape varies between iterations. The 2586 `tf.Tensor.set_shape` 2587 function may also be used in the `body` function to indicate that 2588 the output loop variable has a particular shape. The shape invariant for 2589 SparseTensor and IndexedSlices are treated specially as follows: 2590 2591 a) If a loop variable is a SparseTensor, the shape invariant must be 2592 TensorShape([r]) where r is the rank of the dense tensor represented 2593 by the sparse tensor. It means the shapes of the three tensors of the 2594 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 2595 is the shape of the SparseTensor.dense_shape property. It must be the shape of 2596 a vector. 2597 2598 b) If a loop variable is an IndexedSlices, the shape invariant must be 2599 a shape invariant of the values tensor of the IndexedSlices. It means 2600 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 2601 [shape.ndims]). 2602 2603 `while_loop` implements non-strict semantics, enabling multiple iterations 2604 to run in parallel. The maximum number of parallel iterations can be 2605 controlled by `parallel_iterations`, which gives users some control over 2606 memory consumption and execution order. For correct programs, `while_loop` 2607 should return the same result for any parallel_iterations > 0. 2608 2609 For training, TensorFlow stores the tensors that are produced in the 2610 forward inference and are needed in back propagation. These tensors are a 2611 main source of memory consumption and often cause OOM errors when training 2612 on GPUs. When the flag swap_memory is true, we swap out these tensors from 2613 GPU to CPU. This for example allows us to train RNN models with very long 2614 sequences and large batches. 2615 2616 Args: 2617 cond: A callable that represents the termination condition of the loop. 2618 body: A callable that represents the loop body. 2619 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 2620 `Tensor`, and `TensorArray` objects. 2621 shape_invariants: The shape invariants for the loop variables. 2622 parallel_iterations: The number of iterations allowed to run in parallel. It 2623 must be a positive integer. 2624 back_prop: Whether backprop is enabled for this while loop. 2625 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2626 name: Optional name prefix for the returned tensors. 2627 maximum_iterations: Optional maximum number of iterations of the while loop 2628 to run. If provided, the `cond` output is AND-ed with an additional 2629 condition ensuring the number of iterations executed is no greater than 2630 `maximum_iterations`. 2631 return_same_structure: If True, output has same structure as `loop_vars`. If 2632 eager execution is enabled, this is ignored (and always treated as True). 2633 2634 Returns: 2635 The output tensors for the loop variables after the loop. 2636 If `return_same_structure` is True, the return value has the same 2637 structure as `loop_vars`. 
2638 If `return_same_structure` is False, the return value is a Tensor, 2639 TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list 2640 otherwise. 2641 2642 Raises: 2643 TypeError: if `cond` or `body` is not callable. 2644 ValueError: if `loop_vars` is empty. 2645 2646 Example: 2647 2648 ```python 2649 i = tf.constant(0) 2650 c = lambda i: tf.less(i, 10) 2651 b = lambda i: tf.add(i, 1) 2652 r = tf.while_loop(c, b, [i]) 2653 ``` 2654 2655 Example with nesting and a namedtuple: 2656 2657 ```python 2658 import collections 2659 Pair = collections.namedtuple('Pair', 'j, k') 2660 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 2661 c = lambda i, p: i < 10 2662 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 2663 ijk_final = tf.while_loop(c, b, ijk_0) 2664 ``` 2665 2666 Example using shape_invariants: 2667 2668 ```python 2669 i0 = tf.constant(0) 2670 m0 = tf.ones([2, 2]) 2671 c = lambda i, m: i < 10 2672 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 2673 tf.while_loop( 2674 c, b, loop_vars=[i0, m0], 2675 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 2676 ``` 2677 2678 Example which demonstrates non-strict semantics: In the following 2679 example, the final value of the counter `i` does not depend on `x`. So 2680 the `while_loop` can increment the counter parallel to updates of `x`. 2681 However, because the loop counter at one loop iteration depends 2682 on the value at the previous iteration, the loop counter itself cannot 2683 be incremented in parallel. Hence if we just want the final value of the 2684 counter (which we print on the line `print(sess.run(i))`), then 2685 `x` will never be incremented, but the counter will be updated on a 2686 single thread. Conversely, if we want the value of the output (which we 2687 print on the line `print(sess.run(out).shape)`), then the counter may be 2688 incremented on its own thread, while `x` can be incremented in 2689 parallel on a separate thread. In the extreme case, it is conceivable 2690 that the thread incrementing the counter runs until completion before 2691 `x` is incremented even a single time. The only thing that can never 2692 happen is that the thread updating `x` can never get ahead of the 2693 counter thread because the thread incrementing `x` depends on the value 2694 of the counter. 2695 2696 ```python 2697 import tensorflow as tf 2698 2699 n = 10000 2700 x = tf.constant(list(range(n))) 2701 c = lambda i, x: i < n 2702 b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1, 2703 [i], "x:")) 2704 i, out = tf.while_loop(c, b, (0, x)) 2705 with tf.compat.v1.Session() as sess: 2706 print(sess.run(i)) # prints [0] ... [9999] 2707 2708 # The following line may increment the counter and x in parallel. 2709 # The counter thread may get ahead of the other thread, but not the 2710 # other way around. So you may see things like 2711 # [9996] x:[9987] 2712 # meaning that the counter thread is on iteration 9996, 2713 # while the other thread is on iteration 9987 2714 print(sess.run(out).shape) 2715 ``` 2716 2717 """ 2718 if not callable(cond): 2719 raise TypeError("cond must be callable.") 2720 if not callable(body): 2721 raise TypeError("body must be callable.") 2722 if parallel_iterations < 1: 2723 raise TypeError("parallel_iterations must be a positive integer.") 2724 2725 # Always enable control flow v2 if building a function, regardless of toggle. 
2726 executing_eagerly = context.executing_eagerly() 2727 if (util.EnableControlFlowV2(ops.get_default_graph()) and 2728 not executing_eagerly): 2729 return while_v2.while_loop( 2730 cond, 2731 body, 2732 loop_vars, 2733 shape_invariants=shape_invariants, 2734 parallel_iterations=parallel_iterations, 2735 maximum_iterations=maximum_iterations, 2736 name=name, 2737 return_same_structure=return_same_structure, 2738 back_prop=back_prop) 2739 2740 with ops.name_scope(name, "while", loop_vars): 2741 if not loop_vars: 2742 raise ValueError("No loop variables provided") 2743 try_to_pack = (len(loop_vars) == 1 and not return_same_structure) 2744 if maximum_iterations is not None: 2745 maximum_iterations = ops.convert_to_tensor( 2746 maximum_iterations, name="maximum_iterations") 2747 if maximum_iterations.shape.ndims != 0: 2748 raise ValueError("maximum_iterations must be a scalar, saw shape: %s" % 2749 maximum_iterations.shape) 2750 2751 if executing_eagerly: 2752 counter = 0 2753 maximum_iterations = int(maximum_iterations.numpy()) 2754 else: 2755 counter = constant_op.constant( 2756 0, dtype=maximum_iterations.dtype, name="iteration_counter") 2757 orig_cond = cond 2758 orig_body = body 2759 if try_to_pack: 2760 loop_vars = (counter, loop_vars[0]) 2761 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 2762 math_ops.logical_and(i < maximum_iterations, orig_cond(lv))) 2763 body = lambda i, lv: (i + 1, orig_body(lv)) 2764 else: 2765 loop_vars = (counter, loop_vars) 2766 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 2767 math_ops.logical_and(i < maximum_iterations, orig_cond(*lv))) 2768 body = lambda i, lv: (i + 1, orig_body(*lv)) 2769 try_to_pack = False 2770 2771 if executing_eagerly: 2772 packed = False # whether the body result was packed into a 1-item tuple 2773 2774 loop_var_structure = nest.map_structure(type_spec.type_spec_from_value, 2775 list(loop_vars)) 2776 while cond(*loop_vars): 2777 loop_vars = body(*loop_vars) 2778 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)): 2779 packed = True 2780 loop_vars = (loop_vars,) 2781 nest.assert_same_structure(loop_var_structure, list(loop_vars)) 2782 2783 def convert(x): 2784 if isinstance(x, tensor_array_ops.TensorArray): 2785 return x 2786 return ops.convert_to_tensor(x) 2787 2788 loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True) 2789 if maximum_iterations is not None: 2790 return loop_vars[1] 2791 else: 2792 return loop_vars[0] if packed else loop_vars 2793 2794 if shape_invariants is not None: 2795 if maximum_iterations is not None: 2796 shape_invariants = (tensor_shape.TensorShape([]), shape_invariants) 2797 2798 nest.assert_same_structure( 2799 loop_vars, shape_invariants, expand_composites=False) 2800 shape_invariants = nest.map_structure( 2801 _get_shape_invariant, 2802 loop_vars, 2803 shape_invariants, 2804 expand_composites=False) 2805 2806 loop_context = WhileContext( 2807 maximum_iterations=maximum_iterations, 2808 parallel_iterations=parallel_iterations, 2809 back_prop=back_prop, 2810 swap_memory=swap_memory) 2811 # Only add non-nested loops to the collection. Any nested control flow will 2812 # be encapsulated in the root context. 
2813 if loop_context.outer_context is None:
2814 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
2815 result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
2816 return_same_structure)
2817 if maximum_iterations is not None:
2818 return result[1]
2819 else:
2820 return result
2821
2822
2823# pylint: enable=redefined-outer-name
2824
2825
2826def _AsTensorList(x, p):
2827 """Return x as a list of Tensors or IndexedSlices.
2828
2829 For entries of `x` that are Operations, this returns an Identity of `p`
2830 with a dependency on the operation.
2831
2832 Args:
2833 x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
2834 p: A Tensor to return for entries in `x` that are Operations.
2835
2836 Returns:
2837 A list of Tensors or IndexedSlices.
2838 """
2839 if not isinstance(x, (list, _basetuple)):
2840 x = [x]
2841
2842 l = []
2843 for v in x:
2844 if isinstance(v, ops.Operation):
2845 v = with_dependencies([v], p)
2846 v = ops.convert_to_tensor_or_composite(v)
2847 if isinstance(v, ops.Tensor):
2848 l.append(array_ops.identity(v))
2849 else:
2850 l.append(
2851 ops.IndexedSlices(
2852 array_ops.identity(v.values), array_ops.identity(v.indices)))
2853 return l
2854
2855
2856def _CheckResults(a, b):
2857 assert len(a) == len(b), (
2858 "Values returned by a() and b() must have the same length.")
2859 for x, y in zip(a, b):
2860 assert x.dtype == y.dtype, (
2861 "Values returned by a() [%s] and b() [%s] must have "
2862 "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
2863
2864
2865def with_dependencies(dependencies, output_tensor, name=None):
2866 """Produces the content of `output_tensor` only after `dependencies`.
2867
2868 In some cases, a user may want the output of an operation to be
2869 consumed externally only after some other dependencies have run
2870 first. This function returns `output_tensor`, but only after all
2871 operations in `dependencies` have run. Note that this only gates the
2872 returned value: there is no guarantee that `output_tensor` itself will be
2873 evaluated after any `dependencies` have run.
2874
2875 See also `tf.tuple` and `tf.group`.
2876
2877 Args:
2878 dependencies: Iterable of operations to run before this op finishes.
2879 output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
2880 name: (Optional) A name for this operation.
2881
2882 Returns:
2883 Same as `output_tensor`.
2884
2885 Raises:
2886 TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
2887 """
2888 if context.executing_eagerly():
2889 return output_tensor
2890 with ops.name_scope(name, "control_dependency",
2891 list(dependencies) + [output_tensor]) as name:
2892 with ops.colocate_with(output_tensor):
2893 with ops.control_dependencies(dependencies):
2894 output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
2895 if isinstance(output_tensor, ops.Tensor):
2896 return _Identity(output_tensor, name=name)
2897 else:
2898 return ops.IndexedSlices(
2899 _Identity(output_tensor.values, name=name), output_tensor.indices,
2900 output_tensor.dense_shape)
2901
2902
2903def _GroupControlDeps(dev, deps, name=None):
2904 with ops.control_dependencies(deps):
2905 if dev is None:
2906 return no_op(name=name)
2907 else:
2908 with ops.device(dev):
2909 return no_op(name=name)
2910
2911
2912# TODO(touts): Accept "inputs" as a list.
2913@tf_export("group")
2914def group(*inputs, **kwargs):
2915 """Create an op that groups multiple operations.
2916
2917 When this op finishes, all ops in `inputs` have finished.
This op has no
2918 output.
2919
2920 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
2921 this method, as ops execute in the expected order thanks to automatic control
2922 dependencies.* Only use `tf.group` when working with v1
2923 `tf.Graph` code.
2924
2925 When operating in a v1-style graph context, ops are not executed in the same
2926 order as specified in the code; TensorFlow will attempt to execute ops in
2927 parallel or in an order convenient to the result it is computing. `tf.group`
2928 allows you to request that one or more results finish before execution
2929 continues.
2930
2931 `tf.group` creates a single op (of type `NoOp`), and then adds appropriate
2932 control dependencies. Thus, `c = tf.group(a, b)` will compute the same graph
2933 as this:
2934
2935 with tf.control_dependencies([a, b]):
2936 c = tf.no_op()
2937
2938 See also `tf.tuple` and
2939 `tf.control_dependencies`.
2940
2941 Args:
2942 *inputs: Zero or more tensors to group.
2943 name: A name for this operation (optional).
2944
2945 Returns:
2946 An Operation that executes all its inputs.
2947
2948 Raises:
2949 ValueError: If an unknown keyword argument is provided.
2950 """
2951 if context.executing_eagerly():
2952 return None
2953 name = kwargs.pop("name", None)
2954 if kwargs:
2955 raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
2956 with ops.name_scope(name, "group_deps", inputs) as name:
2957 # Grouping no inputs means do nothing
2958 if not inputs:
2959 return no_op(name=name)
2960
2961 # Sorts *inputs according to their devices.
2962 ops_on_device = {} # device -> operations specified on the device.
2963 for inp in nest.flatten(inputs, expand_composites=True):
2964 if not hasattr(inp, "device"):
2965 raise TypeError("tf.group() expected Tensor arguments, not "
2966 "'%s' with type '%s'" % (inp, type(inp)))
2967 dev = inp.device
2968 if dev in ops_on_device:
2969 ops_on_device[dev].append(inp)
2970 else:
2971 ops_on_device[dev] = [inp]
2972 if len(ops_on_device) == 1:
2973 # 1-level tree. The root node is the returned NoOp node.
2974 (dev, deps), = ops_on_device.items()
2975 return _GroupControlDeps(dev, deps, name=name)
2976
2977 # 2-level tree. The root node is the returned NoOp node.
2978 # deps contains 1 NoOp node for each device.
2979 deps = []
2980
2981 def device_key(dev):
2982 """A sort key that allows None to be compared to strings."""
2983 return "" if dev is None else dev
2984
2985 for dev in sorted(ops_on_device, key=device_key):
2986 deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
2987
2988 with ops.control_dependencies(deps):
2989 return no_op(name=name)
2990
2991
2992@tf_export("tuple", v1=[])
2993@dispatch.add_dispatch_support
2994def tuple_v2(tensors, control_inputs=None, name=None):
2995 """Groups tensors together.
2996
2997 The returned tensors have the same value as the input tensors, but they
2998 are computed only after all the input tensors have been computed.
2999
3000 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
3001 this method, as ops execute in the expected order thanks to automatic control
3002 dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code.
3003
3004 See also `tf.group` and `tf.control_dependencies`.
3005
3006 Args:
3007 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
3008 control_inputs: List of additional ops to finish before returning.
3009 name: (optional) A name to use as a `name_scope` for the operation.
3010 3011 Returns: 3012 Same as `tensors`. 3013 3014 Raises: 3015 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 3016 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 3017 objects. 3018 3019 """ 3020 return tuple(tensors=tensors, name=name, control_inputs=control_inputs) # pylint: disable=redefined-builtin 3021 3022 3023@tf_export(v1=["tuple"]) 3024@dispatch.add_dispatch_support 3025def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin 3026 """Group tensors together. 3027 3028 This creates a tuple of tensors with the same values as the `tensors` 3029 argument, except that the value of each tensor is only returned after the 3030 values of all tensors have been computed. 3031 3032 `control_inputs` contains additional ops that have to finish before this op 3033 finishes, but whose outputs are not returned. 3034 3035 This can be used as a "join" mechanism for parallel computations: all the 3036 argument tensors can be computed in parallel, but the values of any tensor 3037 returned by `tuple` are only available after all the parallel computations 3038 are done. 3039 3040 See also `tf.group` and 3041 `tf.control_dependencies`. 3042 3043 Args: 3044 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. 3045 name: (optional) A name to use as a `name_scope` for the operation. 3046 control_inputs: List of additional ops to finish before returning. 3047 3048 Returns: 3049 Same as `tensors`. 3050 3051 Raises: 3052 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 3053 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 3054 objects. 3055 3056 """ 3057 if context.executing_eagerly(): 3058 return tensors 3059 with ops.name_scope(name, "tuple", tensors) as name: 3060 tensors = [ 3061 t if (isinstance(t, ops.Operation) or tensor_util.is_tf_type(t) or 3062 t is None) else ops.convert_to_tensor(t) for t in tensors 3063 ] 3064 gating_ops = [ 3065 t if isinstance(t, ops.Operation) else t.op 3066 for t in tensors 3067 if t is not None 3068 ] 3069 if control_inputs: 3070 for c in control_inputs: 3071 if isinstance(c, ops.Tensor): 3072 c = c.op 3073 elif not isinstance(c, ops.Operation): 3074 raise TypeError("Control input must be Operation or Tensor: %s" % c) 3075 gating_ops.append(c) 3076 # Note that in order to ensure ordering in the pbtxt, we must take care to 3077 # ensure the order here. 3078 gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops. 3079 if not gating_ops: 3080 raise ValueError("Must have at least one Tensor: %s" % tensors) 3081 gate = group(*gating_ops) 3082 tpl = [] 3083 for t in tensors: 3084 if tensor_util.is_tf_type(t): 3085 tpl.append(with_dependencies([gate], t)) 3086 elif isinstance(t, ops.Operation): 3087 with ops.control_dependencies([gate]): 3088 tpl.append(group(t)) 3089 else: 3090 tpl.append(None) 3091 return tpl 3092 3093 3094def _assert_at_most_n_true(predicates, n, msg): 3095 """Returns an Assert op that checks that at most n predicates are True. 3096 3097 Args: 3098 predicates: list of bool scalar tensors. 3099 n: maximum number of true predicates allowed. 3100 msg: Error message. 
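 For example, `_assert_at_most_n_true([p0, p1], n=1, msg="...")` returns an op
 roughly equivalent to this sketch (illustrative only; `p0` and `p1` stand for
 any scalar bool tensors):

 ```python
 preds = tf.stack([p0, p1])
 num_true = tf.reduce_sum(tf.cast(preds, tf.int32))
 assert_op = tf.debugging.Assert(num_true <= 1, ["...", preds], summarize=2)
 ```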
3101 """
3102 preds_c = array_ops.stack(predicates, name="preds_c")
3103 num_true_conditions = math_ops.reduce_sum(
3104 math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
3105 condition = math_ops.less_equal(num_true_conditions,
3106 constant_op.constant(n, name="n_true_conds"))
3107 preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
3108 error_msg = [
3109 "%s: more than %d conditions (%s) evaluated as True:" %
3110 (msg, n, preds_names), preds_c
3111 ]
3112 return Assert(condition, data=error_msg, summarize=len(predicates))
3113
3114
3115def _case_create_default_action(predicates, actions):
3116 """Creates a default action for a list of actions and their predicates.
3117
3118 It picks an arbitrary action from `actions` to serve as the default and wraps
3119 it in assertions that the corresponding predicates hold when it runs.
3120
3121 Args:
3122 predicates: a list of bool scalar tensors.
3123 actions: a list of callable objects which return tensors.
3124
3125 Returns:
3126 A tuple of (default action callable, remaining predicates, remaining actions).
3127 """
3128 k = len(predicates) - 1 # could pick any
3129 predicate, action = predicates[k], actions[k]
3130 other_predicates, other_actions = predicates[:k], actions[:k]
3131
3132 def default_action():
3133 others_msg = ("Implementation error: "
3134 "selected default action #%d was called, but some of the other "
3135 "predicates are True: " % k)
3136 default_msg = ("Input error: "
3137 "None of the conditions evaluated as True:",
3138 array_ops.stack(predicates, name="preds_c"))
3139 with ops.control_dependencies([
3140 _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
3141 Assert(predicate, data=default_msg)
3142 ]):
3143 return action()
3144
3145 return default_action, other_predicates, other_actions
3146
3147
3148def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
3149 allow_python_preds):
3150 """Verifies input arguments for the case function.
3151
3152 Args:
3153 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
3154 callable which returns a list of tensors.
3155 exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3156 name: A name for the case operation.
3157 allow_python_preds: if true, pred_fn_pairs may contain Python bools in
3158 addition to boolean Tensors.
3159
3160 Raises:
3161 TypeError: If `pred_fn_pairs` is not a list/dictionary.
3162 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3163 TypeError: If `fns[i]` is not callable for any i, or `default` is not
3164 callable.
3165
3166 Returns:
3167 a tuple <list of scalar bool tensors, list of callables>.
3168 """
3169 if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
3170 raise TypeError("fns must be a list, tuple, or dict")
3171
3172 if isinstance(pred_fn_pairs, collections.OrderedDict):
3173 pred_fn_pairs = pred_fn_pairs.items()
3174 elif isinstance(pred_fn_pairs, dict):
3175 if context.executing_eagerly():
3176 # No name to sort on in eager mode. Use dictionary traversal order,
3177 # which is nondeterministic in versions of Python < 3.6.
3178 if not exclusive:
3179 raise ValueError("Unordered dictionaries are not supported for the "
3180 "`pred_fn_pairs` argument when `exclusive=False` and "
3181 "eager mode is enabled.")
3182 pred_fn_pairs = list(pred_fn_pairs.items())
3183 else:
3184 pred_fn_pairs = sorted(
3185 pred_fn_pairs.items(), key=lambda item: item[0].name)
3186 if not exclusive:
3187 logging.warn(
3188 "%s: An unordered dictionary of predicate/fn pairs was "
3189 "provided, but exclusive=False. 
The order of conditional " 3190 "tests is deterministic but not guaranteed.", name) 3191 for pred_fn_pair in pred_fn_pairs: 3192 if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2: 3193 raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple") 3194 pred, fn = pred_fn_pair 3195 3196 if isinstance(pred, ops.Tensor): 3197 if pred.dtype != dtypes.bool: 3198 raise TypeError("pred must be Tensor of type bool: %s" % pred.name) 3199 elif not allow_python_preds: 3200 raise TypeError("pred must be a Tensor, got: %s" % pred) 3201 elif not isinstance(pred, bool): 3202 raise TypeError("pred must be a Tensor or bool, got: %s" % pred) 3203 3204 if not callable(fn): 3205 raise TypeError("fn for pred %s must be callable." % pred.name) 3206 3207 predicates, actions = zip(*pred_fn_pairs) 3208 return predicates, actions 3209 3210 3211def _case_helper(cond_fn, 3212 pred_fn_pairs, 3213 default, 3214 exclusive, 3215 name, 3216 allow_python_preds=False, 3217 **cond_kwargs): 3218 """Implementation of case that allows for different cond functions. 3219 3220 Args: 3221 cond_fn: method that has signature and semantics of `cond` above. 3222 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a 3223 callable which returns a list of tensors. 3224 default: Optional callable that returns a list of tensors. 3225 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3226 name: A name for this operation (optional). 3227 allow_python_preds: if true, pred_fn_pairs may contain Python bools in 3228 addition to boolean Tensors 3229 **cond_kwargs: keyword arguments that will be passed to `cond_fn`. 3230 3231 Returns: 3232 The tensors returned by the first pair whose predicate evaluated to True, or 3233 those returned by `default` if none does. 3234 3235 Raises: 3236 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3237 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3238 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3239 callable. 3240 """ 3241 predicates, actions = _case_verify_and_canonicalize_args( 3242 pred_fn_pairs, exclusive, name, allow_python_preds) 3243 with ops.name_scope(name, "case", [predicates]): 3244 if default is None: 3245 default, predicates, actions = _case_create_default_action( 3246 predicates, actions) 3247 fn = default 3248 # To eval conditions in direct order we create nested conditions in reverse: 3249 # cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...)) 3250 for predicate, action in reversed(list(zip(predicates, actions))): 3251 fn = functools.partial( 3252 cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs) 3253 if exclusive: 3254 with ops.control_dependencies([ 3255 _assert_at_most_n_true( 3256 predicates, n=1, msg="Input error: exclusive=True") 3257 ]): 3258 return fn() 3259 else: 3260 return fn() 3261 3262 3263def _indexed_case_verify_and_canonicalize_args(branch_fns, default, 3264 branch_index): 3265 """Verifies input arguments for the case function. 3266 3267 Args: 3268 branch_fns: Dict or list of pairs of an `int` and a callable which 3269 returns a list of tensors. 3270 default: Optional callable that returns a list of tensors. 3271 branch_index: Optional int `Tensor`, which selects for the corresponding 3272 pred_fn_pair. 3273 3274 Raises: 3275 TypeError: If `branch_fns` is not a list/dictionary. 3276 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3277 callables. 
3278 TypeError: If `fns[i]` is not callable for any i, or `default` is not
3279 callable.
3280
3281 Returns:
3282 branch_fns: validated list of callables for each branch (default last).
3283 """
3284 if not isinstance(branch_index, ops.Tensor):
3285 raise TypeError("branch_index must be a Tensor, got {}".format(
3286 type(branch_index)))
3287 if not branch_index.dtype.is_integer:
3288 raise TypeError("branch_index must be an integer Tensor, got {}".format(
3289 branch_index.dtype))
3290
3291 if not branch_fns:
3292 raise ValueError("Must provide at least one item in branch_fns")
3293 if not isinstance(branch_fns, (list, _basetuple, dict)):
3294 raise TypeError("branch_fns must be a list, tuple, or dict")
3295
3296 if isinstance(branch_fns, dict):
3297 branch_fns = branch_fns.items()
3298
3299 if all(callable(fn) for fn in branch_fns):
3300 branch_fns = list(enumerate(branch_fns))
3301
3302 for key_fn_pair in branch_fns:
3303 if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2:
3304 raise TypeError("Each entry in branch_fns must be a 2-tuple")
3305 key, branch_fn = key_fn_pair
3306
3307 if not isinstance(key, int):
3308 raise TypeError("key must be a Python `int`, got {}".format(type(key)))
3309
3310 if not callable(branch_fn):
3311 raise TypeError("fn for key {} must be callable.".format(key))
3312
3313 keys = [p[0] for p in branch_fns]
3314 if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
3315 raise ValueError(
3316 "branch indices (keys) must form contiguous range of [0 to {}) but "
3317 "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
3318 actions = [p[1] for p in sorted(branch_fns)]
3319 if default is not None:
3320 actions.append(default)
3321 return actions
3322
3323
3324def _indexed_case_helper(branch_fns,
3325 default,
3326 branch_index,
3327 name,
3328 lower_using_switch_merge=None):
3329 """Implementation of case that emits the n-way indexed Case op.
3330
3331 Args:
3332 branch_fns: Dict or list of pairs of an `int` and a callable which
3333 returns a list of tensors.
3334 default: Optional callable that returns a list of tensors.
3335 branch_index: An int `Tensor` that selects which of the `branch_fns`
3336 entries to execute.
3337 name: A name for this operation (optional).
3338 lower_using_switch_merge: Lower this op using switch merge ops (optional).
3339
3340 Returns:
3341 The tensors returned by the pair whose key matched branch_index, or
3342 those returned by `default` if none does.
3343
3344 Raises:
3345 TypeError: If `branch_fns` is not a list/dictionary.
3346 TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3347 callables.
3348 TypeError: If `fns[i]` is not callable for any i, or `default` is not
3349 callable.
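 For example, in eager mode an out-of-range `branch_index` falls through to the
 last callable, which is `default` when one is supplied. A minimal sketch
 (illustrative only; assumes eager execution):

 ```python
 fns = [lambda: tf.constant(0), lambda: tf.constant(1)]
 result = _indexed_case_helper(fns, lambda: tf.constant(-1),
                               tf.constant(5), "case")
 # result holds -1, i.e. the default branch ran.
 ```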
3350 """ 3351 branch_fns = _indexed_case_verify_and_canonicalize_args( 3352 branch_fns, default, branch_index) 3353 with ops.name_scope(name, "case", [branch_index]): 3354 if context.executing_eagerly() and not hasattr(branch_index, "graph"): 3355 branch_index = array_ops.where( 3356 math_ops.less(branch_index, 0) 3357 | math_ops.greater_equal(branch_index, len(branch_fns)), 3358 len(branch_fns) - 1, branch_index) 3359 return branch_fns[int(branch_index)]() 3360 return cond_v2.indexed_case( 3361 branch_index, 3362 branch_fns, 3363 lower_using_switch_merge=lower_using_switch_merge) 3364 3365 3366@tf_export("case", v1=[]) 3367@dispatch.add_dispatch_support 3368def case_v2(pred_fn_pairs, 3369 default=None, 3370 exclusive=False, 3371 strict=False, 3372 name="case"): 3373 """Create a case operation. 3374 3375 See also `tf.switch_case`. 3376 3377 The `pred_fn_pairs` parameter is a list of pairs of size N. 3378 Each pair contains a boolean scalar tensor and a python callable that 3379 creates the tensors to be returned if the boolean evaluates to True. 3380 `default` is a callable generating a list of tensors. All the callables 3381 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3382 number and types of tensors. 3383 3384 If `exclusive==True`, all predicates are evaluated, and an exception is 3385 thrown if more than one of the predicates evaluates to `True`. 3386 If `exclusive==False`, execution stops at the first predicate which 3387 evaluates to True, and the tensors generated by the corresponding function 3388 are returned immediately. If none of the predicates evaluate to True, this 3389 operation returns the tensors generated by `default`. 3390 3391 `tf.case` supports nested structures as implemented in 3392 `tf.contrib.framework.nest`. All of the callables must return the same 3393 (possibly nested) value structure of lists, tuples, and/or named tuples. 3394 Singleton lists and tuples form the only exceptions to this: when returned by 3395 a callable, they are implicitly unpacked to single values. This 3396 behavior is disabled by passing `strict=True`. 3397 3398 @compatibility(v2) 3399 `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and 3400 tf.Variable are no longer hashable in v2, so cannot be used as a key for a 3401 dictionary. Please use a list or a tuple instead. 3402 @end_compatibility 3403 3404 3405 **Example 1:** 3406 3407 Pseudocode: 3408 3409 ``` 3410 if (x < y) return 17; 3411 else return 23; 3412 ``` 3413 3414 Expressions: 3415 3416 ```python 3417 f1 = lambda: tf.constant(17) 3418 f2 = lambda: tf.constant(23) 3419 r = tf.case([(tf.less(x, y), f1)], default=f2) 3420 ``` 3421 3422 **Example 2:** 3423 3424 Pseudocode: 3425 3426 ``` 3427 if (x < y && x > z) raise OpError("Only one predicate may evaluate to True"); 3428 if (x < y) return 17; 3429 else if (x > z) return 23; 3430 else return -1; 3431 ``` 3432 3433 Expressions: 3434 3435 ```python 3436 def f1(): return tf.constant(17) 3437 def f2(): return tf.constant(23) 3438 def f3(): return tf.constant(-1) 3439 r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)], 3440 default=f3, exclusive=True) 3441 ``` 3442 3443 Args: 3444 pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which 3445 returns a list of tensors. 3446 default: Optional callable that returns a list of tensors. 3447 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3448 strict: A boolean that enables/disables 'strict' mode; see above. 
3449 name: A name for this operation (optional). 3450 3451 Returns: 3452 The tensors returned by the first pair whose predicate evaluated to True, or 3453 those returned by `default` if none does. 3454 3455 Raises: 3456 TypeError: If `pred_fn_pairs` is not a list/tuple. 3457 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3458 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3459 callable. 3460 """ 3461 return _case_helper( 3462 cond, 3463 pred_fn_pairs, 3464 default, 3465 exclusive, 3466 name, 3467 allow_python_preds=False, 3468 strict=strict) 3469 3470 3471@tf_export(v1=["case"]) 3472@dispatch.add_dispatch_support 3473def case(pred_fn_pairs, 3474 default=None, 3475 exclusive=False, 3476 strict=False, 3477 name="case"): 3478 """Create a case operation. 3479 3480 See also `tf.switch_case`. 3481 3482 The `pred_fn_pairs` parameter is a dict or list of pairs of size N. 3483 Each pair contains a boolean scalar tensor and a python callable that 3484 creates the tensors to be returned if the boolean evaluates to True. 3485 `default` is a callable generating a list of tensors. All the callables 3486 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3487 number and types of tensors. 3488 3489 If `exclusive==True`, all predicates are evaluated, and an exception is 3490 thrown if more than one of the predicates evaluates to `True`. 3491 If `exclusive==False`, execution stops at the first predicate which 3492 evaluates to True, and the tensors generated by the corresponding function 3493 are returned immediately. If none of the predicates evaluate to True, this 3494 operation returns the tensors generated by `default`. 3495 3496 `tf.case` supports nested structures as implemented in 3497 `tf.contrib.framework.nest`. All of the callables must return the same 3498 (possibly nested) value structure of lists, tuples, and/or named tuples. 3499 Singleton lists and tuples form the only exceptions to this: when returned by 3500 a callable, they are implicitly unpacked to single values. This 3501 behavior is disabled by passing `strict=True`. 3502 3503 If an unordered dictionary is used for `pred_fn_pairs`, the order of the 3504 conditional tests is not guaranteed. However, the order is guaranteed to be 3505 deterministic, so that variables created in conditional branches are created 3506 in fixed order across runs. 3507 3508 @compatibility(eager) 3509 Unordered dictionaries are not supported in eager mode when `exclusive=False`. 3510 Use a list of tuples instead. 3511 @end_compatibility 3512 3513 3514 **Example 1:** 3515 3516 Pseudocode: 3517 3518 ``` 3519 if (x < y) return 17; 3520 else return 23; 3521 ``` 3522 3523 Expressions: 3524 3525 ```python 3526 f1 = lambda: tf.constant(17) 3527 f2 = lambda: tf.constant(23) 3528 r = tf.case([(tf.less(x, y), f1)], default=f2) 3529 ``` 3530 3531 **Example 2:** 3532 3533 Pseudocode: 3534 3535 ``` 3536 if (x < y && x > z) raise OpError("Only one predicate may evaluate to True"); 3537 if (x < y) return 17; 3538 else if (x > z) return 23; 3539 else return -1; 3540 ``` 3541 3542 Expressions: 3543 3544 ```python 3545 def f1(): return tf.constant(17) 3546 def f2(): return tf.constant(23) 3547 def f3(): return tf.constant(-1) 3548 r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2}, 3549 default=f3, exclusive=True) 3550 ``` 3551 3552 Args: 3553 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a 3554 callable which returns a list of tensors. 
3555 default: Optional callable that returns a list of tensors. 3556 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3557 strict: A boolean that enables/disables 'strict' mode; see above. 3558 name: A name for this operation (optional). 3559 3560 Returns: 3561 The tensors returned by the first pair whose predicate evaluated to True, or 3562 those returned by `default` if none does. 3563 3564 Raises: 3565 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3566 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3567 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3568 callable. 3569 """ 3570 return _case_helper( 3571 cond, 3572 pred_fn_pairs, 3573 default, 3574 exclusive, 3575 name, 3576 allow_python_preds=False, 3577 strict=strict) 3578 3579 3580@tf_export("switch_case") 3581def switch_case(branch_index, 3582 branch_fns, 3583 default=None, 3584 name="switch_case"): 3585 """Create a switch/case operation, i.e. an integer-indexed conditional. 3586 3587 See also `tf.case`. 3588 3589 This op can be substantially more efficient than `tf.case` when exactly one 3590 branch will be selected. `tf.switch_case` is more like a C++ switch/case 3591 statement than `tf.case`, which is more like an if/elif/elif/else chain. 3592 3593 The `branch_fns` parameter is either a dict from `int` to callables, or list 3594 of (`int`, callable) pairs, or simply a list of callables (in which case the 3595 index is implicitly the key). The `branch_index` `Tensor` is used to select an 3596 element in `branch_fns` with matching `int` key, falling back to `default` 3597 if none match, or `max(keys)` if no `default` is provided. The keys must form 3598 a contiguous set from `0` to `len(branch_fns) - 1`. 3599 3600 `tf.switch_case` supports nested structures as implemented in `tf.nest`. All 3601 callables must return the same (possibly nested) value structure of lists, 3602 tuples, and/or named tuples. 3603 3604 **Example:** 3605 3606 Pseudocode: 3607 3608 ```c++ 3609 switch (branch_index) { // c-style switch 3610 case 0: return 17; 3611 case 1: return 31; 3612 default: return -1; 3613 } 3614 ``` 3615 or 3616 ```python 3617 branches = {0: lambda: 17, 1: lambda: 31} 3618 branches.get(branch_index, lambda: -1)() 3619 ``` 3620 3621 Expressions: 3622 3623 ```python 3624 def f1(): return tf.constant(17) 3625 def f2(): return tf.constant(31) 3626 def f3(): return tf.constant(-1) 3627 r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3) 3628 # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3}) 3629 ``` 3630 3631 Args: 3632 branch_index: An int Tensor specifying which of `branch_fns` should be 3633 executed. 3634 branch_fns: A `dict` mapping `int`s to callables, or a `list` of 3635 (`int`, callable) pairs, or simply a list of callables (in which case the 3636 index serves as the key). Each callable must return a matching structure 3637 of tensors. 3638 default: Optional callable that returns a structure of tensors. 3639 name: A name for this operation (optional). 3640 3641 Returns: 3642 The tensors returned by the callable identified by `branch_index`, or those 3643 returned by `default` if no key matches and `default` was provided, or those 3644 returned by the max-keyed `branch_fn` if no `default` is provided. 3645 3646 Raises: 3647 TypeError: If `branch_fns` is not a list/dictionary. 3648 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3649 callables. 
3650 TypeError: If `fns[i]` is not callable for any i, or `default` is not
3651 callable.
3652 """
3653 return _indexed_case_helper(branch_fns, default, branch_index, name)
3654
3655
3656@tf_export("__internal__.execute_fn_for_device", v1=[])
3657def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
3658 """Executes one of the provided callables based on the device placement.
3659
3660 This API is used when the implementation of a high-level function depends on
3661 the underlying device placement. It takes a dictionary mapping device types
3662 ("CPU", "GPU", "TPU", etc.) to callables. When the type of the device on
3663 which this op runs matches a key in `device_branch_fns`,
3664 the corresponding callable is executed, falling back to `default_fn` if none
3665 matches.
3666
3667 **Example:**
3668 ```python
3669 def f1(): return tf.constant(1)
3670 def f2(): return tf.constant(2)
3671 r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1)
3672 ```
3673 `r` evaluates to 1 when running on a CPU, to 2 on a GPU, and to 1 on any
3674 other device type.
3675
3676
3677 Args:
3678 device_branch_fns: A dictionary mapping device types to callables. Each
3679 callable must return a matching structure of tensors.
3680 default_fn: Fallback callable for when the underlying device does not match
3681 any key in `device_branch_fns`.
3682 name: A name for this operation (optional).
3683
3684 Returns:
3685 The tensors returned by the callable identified by device type during
3686 execution, or those returned by `default_fn` if no key matches.
3687 """
3688 # Always execute the default fn under XLA, to avoid complicating the graph
3689 # with a Case op. See more discussion in b/167276293.
3690 is_in_xla = util.GraphOrParentsInXlaContext(ops.get_default_graph())
3691 if is_in_xla:
3692 return default_fn()
3693 device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
3694 branch_fns = list(device_branch_fns_upper.values())
3695 devices = list(device_branch_fns_upper.keys())
3696 device_index = gen_functional_ops.device_index(device_names=devices)
3697 return _indexed_case_helper(
3698 branch_fns,
3699 default_fn,
3700 device_index,
3701 name,
3702 lower_using_switch_merge=False)
3703
3704
3705class XLAControlFlowContext(ControlFlowContext):
3706 """Base class for XLA and TPU control flow contexts."""
3707
3708 def __init__(self):
3709 super(XLAControlFlowContext, self).__init__()
3710 self._name = "XLAControlFlowContext"
3711
3712 def to_control_flow_context_def(self, context_def, export_scope=None):
3713 # pylint: disable=useless-super-delegation
3714 # NOTE(slebedev): the method is required by `ControlFlowContext`.
3715 super(XLAControlFlowContext,
3716 self).to_control_flow_context_def(context_def, export_scope)
3717
3718 def IsXLAContext(self):
3719 return True
3720
3721 def AddOp(self, _):
3722 pass
3723
3724 def AddValue(self, x):
3725 return x
3726
3727 def RequiresUniqueFunctionRetracing(self):
3728 """Returns whether the tf.function should be retraced if the context changes.
3729 """ 3730 return False 3731 3732 3733@tf_export("__internal__.get_enclosing_xla_context", v1=[]) 3734def get_enclosing_xla_context(): 3735 """Recursively find and return the XLAControlFlowContext.""" 3736 graph = ops.get_default_graph() 3737 while graph is not None: 3738 # pylint: disable=protected-access 3739 context_ = graph._get_control_flow_context() 3740 # pylint: enable=protected-access 3741 while context_ is not None: 3742 if isinstance(context_, XLAControlFlowContext): 3743 return context_ 3744 context_ = context_.outer_context 3745 # This may be a FuncGraph due to defuns or v2 control flow. We need to 3746 # find the original graph with the XLAControlFlowContext. 3747 graph = getattr(graph, "outer_graph", None) 3748 return None 3749 3750 3751def from_control_flow_context_def(context_def, import_scope=None): 3752 """Deserializes `context_def` into the appropriate ControlFlowContext. 3753 3754 Args: 3755 context_def: ControlFlowContextDef proto 3756 import_scope: Optional `string`. Name scope to add. 3757 3758 Returns: 3759 A ControlFlowContext subclass 3760 """ 3761 if context_def.HasField("cond_ctxt"): 3762 return CondContext.from_proto( 3763 context_def.cond_ctxt, import_scope=import_scope) 3764 if context_def.HasField("while_ctxt"): 3765 return WhileContext.from_proto( 3766 context_def.while_ctxt, import_scope=import_scope) 3767 raise NotImplementedError("Unknown ControlFlowContextDef field: %s" % 3768 context_def.WhichOneof("ctxt")) 3769 3770 3771ops.register_proto_function( 3772 ops.GraphKeys.COND_CONTEXT, 3773 proto_type=control_flow_pb2.CondContextDef, 3774 to_proto=CondContext.to_proto, 3775 from_proto=CondContext.from_proto) 3776 3777ops.register_proto_function( 3778 ops.GraphKeys.WHILE_CONTEXT, 3779 proto_type=control_flow_pb2.WhileContextDef, 3780 to_proto=WhileContext.to_proto, 3781 from_proto=WhileContext.from_proto) 3782
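# The two `register_proto_function` calls above are what allow v1 control flow
# contexts to survive graph serialization: `export_meta_graph` serializes the
# COND_CONTEXT and WHILE_CONTEXT collections using the registered `to_proto`
# functions, and `import_meta_graph` restores the contexts via `from_proto`.
# A minimal user-level sketch (illustrative only; assumes TF1-style graphs
# where v1 control flow is in use):
#
#   with tf.Graph().as_default():
#     i = tf.constant(0)
#     tf.while_loop(lambda i: i < 10, lambda i: i + 1, [i])
#     meta_graph_def = tf.compat.v1.train.export_meta_graph()
#   with tf.Graph().as_default():
#     tf.compat.v1.train.import_meta_graph(meta_graph_def)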