# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control Flow Operations.

See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import functools

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# This is to avoid a circular dependency:
# cond_v2 -> gradients_util -> control_flow_ops
cond_v2 = LazyLoader("cond_v2", globals(),
                     "tensorflow.python.ops.cond_v2")

# This is to avoid circular dependencies:
# while_v2 -> control_flow_ops
# while_v2 -> gradients_util -> control_flow_ops
while_v2 = LazyLoader("while_v2", globals(),
                      "tensorflow.python.ops.while_v2")

# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple


def _summarize_eager(tensor, summarize=None):
  """Returns a summarized string representation of eager `tensor`.

  Args:
    tensor: EagerTensor to summarize.
    summarize: Include this many of the first elements of the tensor.
  """
  # Emulate the behavior of Tensor::SummarizeValue()
  if summarize is None:
    summarize = 3
  elif summarize < 0:
    summarize = array_ops.size(tensor)

  # reshape((-1,)) is the fastest way to get a flat array view
  if tensor._rank():  # pylint: disable=protected-access
    flat = tensor.numpy().reshape((-1,))
    lst = [str(x) for x in flat[:summarize]]
    if len(lst) < flat.size:
      lst.append("...")
  else:
    # tensor.numpy() returns a scalar for zero dimensional arrays
    if gen_math_ops.not_equal(summarize, 0):
      lst = [str(tensor.numpy())]
    else:
      lst = []

  return ", ".join(lst)


# pylint: disable=protected-access


# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  NOTE: In graph mode, to ensure that Assert executes, one usually attaches
  a dependency:

  ```python
  # Ensure maximum element of x is smaller or equal to 1
  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  with tf.control_dependencies([assert_op]):
    ... code using x ...
  ```

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    @compatibility(eager)
    `tf.errors.InvalidArgumentError` if `condition` is not true
    @end_compatibility
  """
  if context.executing_eagerly():
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return

  with ops.name_scope(name, "Assert", [condition, data]) as name:
    xs = ops.convert_n_to_tensor(data)
    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If that is not the case,
      # we will pay the price of copying the tensor to host memory.
      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
    else:
      condition = ops.convert_to_tensor(condition, name="Condition")

      def true_assert():
        return gen_logging_ops._assert(
            condition, data, summarize, name="Assert")

      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
      if context.executing_eagerly():
        return
      return guarded_assert.op


def _Identity(data, name=None):
  """Return a tensor with the same shape and contents as the input tensor.

  Args:
    data: A Tensor.
    name: A name for this operation (optional).

  Returns:
    A Tensor with the same type and value as the input Tensor.
188 """ 189 data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True) 190 if isinstance(data, ops.Tensor): 191 if data.dtype._is_ref_dtype: # pylint: disable=protected-access 192 return gen_array_ops.ref_identity(data, name=name) 193 else: 194 return array_ops.identity(data, name=name) 195 elif isinstance(data, composite_tensor.CompositeTensor): 196 return nest.map_structure(_Identity, data, expand_composites=True) 197 else: 198 raise TypeError("Type %s not supported" % type(data)) 199 200 201def _NextIteration(data, name=None): 202 data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True) 203 if isinstance(data, ops.Tensor): 204 if data.dtype._is_ref_dtype: # pylint: disable=protected-access 205 return ref_next_iteration(data, name=name) 206 else: 207 return next_iteration(data, name=name) 208 elif isinstance(data, composite_tensor.CompositeTensor): 209 return nest.map_structure(_NextIteration, data, expand_composites=True) 210 else: 211 raise TypeError("Type %s not supported" % type(data)) 212 213 214def _Enter(data, 215 frame_name, 216 is_constant=False, 217 parallel_iterations=10, 218 use_ref=True, 219 use_input_shape=True, 220 name=None): 221 """Creates or finds a child frame, and makes `data` available to it. 222 223 The unique `frame_name` is used by the `Executor` to identify frames. If 224 `is_constant` is true, `data` is a constant in the child frame; otherwise 225 it may be changed in the child frame. At most `parallel_iterations` 226 iterations are run in parallel in the child frame. 227 228 Args: 229 data: The tensor to be made available to the child frame. 230 frame_name: The name of the child frame. 231 is_constant: If true, the output is constant within the child frame. 232 parallel_iterations: The number of iterations allowed to run in parallel. 233 use_ref: If true, use ref_enter if data is of ref type. 234 use_input_shape: If true, set the result's shape based on data's shape. 235 name: A name for this operation (optional). 236 237 Returns: 238 The same tensor as `data`. 239 """ 240 data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True) 241 if isinstance(data, ops.Tensor): 242 if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access 243 result = gen_control_flow_ops.ref_enter( 244 data, frame_name, is_constant, parallel_iterations, name=name) 245 else: 246 result = gen_control_flow_ops.enter( 247 data, frame_name, is_constant, parallel_iterations, name=name) 248 if use_input_shape: 249 result.set_shape(data.get_shape()) 250 return result 251 elif isinstance(data, composite_tensor.CompositeTensor): 252 253 def enter_component(t): 254 return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref, 255 use_input_shape) 256 257 return nest.map_structure(enter_component, data, expand_composites=True) 258 else: 259 raise TypeError("Type %s not supported" % type(data)) 260 261 262def exit(data, name=None): # pylint: disable=redefined-builtin 263 """Exits the current frame to its parent frame. 264 265 Exit makes its input `data` available to the parent frame. 266 267 Args: 268 data: The tensor to be made available to the parent frame. 269 name: A name for this operation (optional). 270 271 Returns: 272 The same tensor as `data`. 
273 """ 274 data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True) 275 if isinstance(data, ops.Tensor): 276 if data.dtype._is_ref_dtype: # pylint: disable=protected-access 277 return gen_control_flow_ops.ref_exit(data, name) 278 else: 279 return gen_control_flow_ops._exit(data, name) 280 elif isinstance(data, composite_tensor.CompositeTensor): 281 return nest.map_structure(exit, data, expand_composites=True) 282 else: 283 raise TypeError("Type %s not supported" % type(data)) 284 285 286def switch(data, pred, dtype=None, name=None): 287 """Forwards `data` to an output determined by `pred`. 288 289 If `pred` is false, the `data` input is forwarded to the first output. 290 Otherwise, the data goes to the second output. 291 292 This op handles `Tensor`s and `IndexedSlices`. 293 294 Args: 295 data: The tensor to be forwarded to the appropriate output. 296 pred: A scalar that specifies which output port will receive data. 297 dtype: Optional element type for the returned tensor. If missing, the type 298 is inferred from the type of `value`. 299 name: A name for this operation (optional). 300 301 Returns: 302 `(output_false, output_true)`: If `pred` is true, data will be forwarded 303 to `output_true`, otherwise it goes to `output_false`. 304 """ 305 with ops.name_scope(name, "Switch", [data, pred]) as name: 306 data = ops.internal_convert_to_tensor_or_composite( 307 data, dtype=dtype, name="data", as_ref=True) 308 pred = ops.convert_to_tensor(pred, name="pred") 309 if isinstance(data, ops.Tensor): 310 return gen_control_flow_ops.switch(data, pred, name=name) 311 else: 312 if not isinstance(data, composite_tensor.CompositeTensor): 313 raise TypeError("Type %s not supported" % type(data)) 314 tensors = nest.flatten(data, expand_composites=True) 315 mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors] 316 mapped_f, mapped_t = zip(*mapped) 317 return (nest.pack_sequence_as(data, mapped_f, expand_composites=True), 318 nest.pack_sequence_as(data, mapped_t, expand_composites=True)) 319 320 321def _SwitchRefOrTensor(data, pred, name="Switch"): 322 """Forwards `data` to an output determined by `pred`. 323 324 If `pred` is false, the `data` input is forwarded to the first output. 325 Otherwise, the data goes to the second output. 326 327 This op handles `Tensor`s and `IndexedSlices`. 328 329 Args: 330 data: The tensor to be forwarded to the appropriate output. 331 pred: A scalar that specifies which output port will receive data. 332 name: A name for this operation (optional). 333 334 Returns: 335 `(output_false, output_true)`: If `pred` is true, data will be forwarded to 336 `output_true`, otherwise it goes to `output_false`. 337 338 Raises: 339 TypeError: if data is not a Tensor or IndexedSlices 340 """ 341 data = ops.convert_to_tensor_or_composite(data, name="data") 342 # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below 343 # addresses the following scenario. 344 # 345 # Assume you execute Optimizer.apply_gradients() in a branch of a cond(). 346 # 347 # 1. The update op is created inside a `with ops.colocate(var):` block 348 # 349 # 2. Some tensor `data` is captured and a switch is created in a 350 # `with ops.colocate_with(data):` block. 351 # 352 # with ops.colocate_with(var): 353 # with ops.colocate_with(data): 354 # op = ... 355 # 356 # var and data may be pinned to different devices, so we want to ops 357 # created within ops.colocate_with(data) to ignore the existing stack. 
  with ops.colocate_with(data, ignore_existing=True):
    if isinstance(data, ops.Tensor):
      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
        return ref_switch(data, pred, name=name)
    return switch(data, pred, name=name)


def merge(inputs, name=None):
  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of
  them is available. If it finds an available tensor, it returns it and its
  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor
  in `inputs` is available, the returned tensor and index are not set.

  This op handles both `Tensor`s and `IndexedSlices`. If `inputs` contains a
  mix of `Tensor`s and `IndexedSlices`, all inputs are converted to
  IndexedSlices before merging.

  Args:
    inputs: The input tensors, at most one of which is available.
    name: A name for this operation (optional).

  Returns:
    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:
    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
      some but not all have a dense_shape property.
  """
  if any(inp is None for inp in inputs):
    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
  with ops.name_scope(name, "Merge", inputs) as name:
    inputs = [
        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
        for inp in inputs
    ]
    if all(isinstance(v, ops.Tensor) for v in inputs):
      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
        return gen_control_flow_ops.ref_merge(inputs, name)
      else:
        return gen_control_flow_ops.merge(inputs, name)
    else:
      # If there is a mix of tensors and indexed slices, then convert the
      # tensors to indexed slices.
      if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs):
        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)

      for v in inputs:
        if not isinstance(v, composite_tensor.CompositeTensor):
          raise TypeError("Type %s not supported" % type(v))

      for v in inputs[1:]:
        nest.assert_same_structure(inputs[0], v, expand_composites=True)

      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
      merged_results = [
          gen_control_flow_ops.merge(component)
          for component in zip(*flat_inputs)
      ]
      flat_merged = [tensor for (tensor, _) in merged_results]
      chosen_index = merged_results[0][1]
      merged_inputs = nest.pack_sequence_as(
          inputs[0], flat_merged, expand_composites=True)
      return (merged_inputs, chosen_index)


# pylint: enable=protected-access
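
# Example sketch (hypothetical user code, graph mode): `switch` and `merge`
# are the low-level primitives that `cond` below is built from. A minimal
# conditional could be wired up by hand roughly as follows:
#
#   pred = tf.compat.v1.placeholder(tf.bool)
#   x = tf.constant(7)
#   x_false, x_true = switch(x, pred)
#   out, chosen_index = merge([x_false - 1, x_true + 1])
#
# Only the Switch output selected by `pred` produces a live tensor at runtime;
# the other output is dead, and `merge` forwards the live value along with the
# index of the input that produced it.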
%d" % 441 (len(tensors_or_tensorarrays), len(tensors_or_flows))) 442 return [ 443 tensor_array_ops.build_ta_with_new_flow(ta, t_or_flow) if isinstance( 444 ta, tensor_array_ops.TensorArray) else t_or_flow 445 for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows) 446 ] 447 448 449def _ShapeLessThanOrEqual(shape1, shape2): 450 if shape2.dims is None: 451 return True 452 if shape1.ndims != shape2.ndims: 453 return False 454 for dim1, dim2 in zip(shape1.dims, shape2.dims): 455 if dim2.value is not None and dim1.value != dim2.value: 456 return False 457 return True 458 459 460def _get_shape_invariant(var, shape=None): 461 """Returns shape invariant(s) for the given variable. 462 463 Args: 464 var: The tensor whose shape is described. 465 shape: The shape invariant for the tensor. If not specified, then a default 466 shape invariant for `var` is returned. 467 468 Returns: 469 `TensorShape` or `list` of `TensorShape`: The shape invariant for `var` (if 470 it is a `Tensor`), or the shape invariants for the components that comprise 471 `var` (if it is a `CompositeTensor`). 472 """ 473 if isinstance(var, composite_tensor.CompositeTensor): 474 # Get a TypeSpec for `var`. 475 if shape is None: 476 spec = var._type_spec # pylint: disable=protected-access 477 else: 478 spec = _shape_invariant_to_type_spec(var, shape) 479 480 tensor_specs = nest.flatten(spec, expand_composites=True) 481 return [tspec.shape for tspec in tensor_specs] 482 483 elif shape is None: 484 return var.shape 485 elif isinstance(shape, tensor_spec.TensorSpec): 486 if var.dtype != shape.dtype: 487 raise TypeError("TensorSpec %r is not compatible with %r" % (shape, var)) 488 return shape.shape 489 elif isinstance(shape, type_spec.TypeSpec): 490 raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var)) 491 else: 492 return shape 493 494 495def _shape_invariant_to_type_spec(var, shape): 496 """Converts a shape invariant to a TypeSpec. 497 498 Args: 499 var: The tensor whose shape is described by the shape invariant. 500 shape: A `TypeSpec` or `TensorShape`. If `shape` is already a `TypeSpec`, 501 then it is simply returned as-is. 502 503 Returns: 504 A `TypeSpec` for `var`, consistent with the given shape. 505 """ 506 if shape is None: 507 return type_spec.type_spec_from_value(var) 508 elif isinstance(shape, type_spec.TypeSpec): 509 if not shape.is_compatible_with(var): 510 raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var)) 511 return shape 512 elif not isinstance(shape, tensor_shape.TensorShape): 513 raise TypeError( 514 "Expected shape to be a TypeSpec, TensorShape or None, got %r for" 515 " value %r" % (shape, var)) 516 517 if isinstance(var, ops.Tensor): 518 return tensor_spec.TensorSpec(shape, var.dtype) 519 520 elif isinstance(var, composite_tensor.CompositeTensor): 521 try: 522 return var._shape_invariant_to_type_spec(shape) # pylint: disable=protected-access 523 except NotImplementedError: 524 raise TypeError( 525 "To describe or constrain a %s, use a %s instead of a TensorShape." % 526 (type(var).__name__, type(var._type_spec).__name__)) # pylint: disable=protected-access 527 528 else: 529 raise TypeError("Expected var to be a Tensor or CompositeTensor, got %s" 530 % var) 531 532 533def _SetShapeInvariants(input_vars, enter_vars, shapes): 534 """Set the shapes of the tensors in `enter_vars` to `shapes`. 535 536 Args: 537 input_vars: A list of tensors that are inputs to `enter_vars`. 538 enter_vars: A list of tensors whose shapes will be set. 
    shapes: A (possibly nested) list of shapes.

  Raises:
    ValueError: If any tensor in `enter_vars` has a less specific shape
      than its corresponding shape in `shapes`.
  """
  if shapes is None:
    return
  flat_shapes = nest.flatten(shapes)
  if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
  # Check that the shapes of the inputs are less than the shape invariants,
  # and set the shapes of `enter_vars` to the shape invariants.
  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
    if isinstance(var, ops.Tensor):
      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
        raise ValueError(
            "The shape invariant specified for %s is not compatible with "
            "the initial shape of the loop variable. It enters the loop "
            "with shape %s, but the specified shape invariant is %s." %
            (inp.name, inp.get_shape(), shape))
      var.set_shape(shape)
    else:
      raise TypeError("Type %s not supported" % type(var))


def _EnforceShapeInvariant(merge_var, next_var):
  """Check if the shapes of the loop variables are invariant.

  Args:
    merge_var: The list of tensors representing the initial values of the loop
      variables.
    next_var: The list of tensors representing the values of the loop variables
      after one loop iteration.

  Raises:
    ValueError: If any tensor in `merge_var` has a more specific shape than
      its corresponding tensor in `next_var`.
  """
  if isinstance(merge_var, ops.Tensor):
    m_shape = merge_var.get_shape()
    n_shape = next_var.get_shape()
    if not _ShapeLessThanOrEqual(n_shape, m_shape):
      enter = merge_var.op.inputs[0].op
      assert util.IsLoopEnter(enter)
      input_t = enter.inputs[0]
      raise ValueError(
          "Input tensor '%s' enters the loop with shape %s, but has shape %s "
          "after one iteration. To allow the shape to vary across iterations, "
          "use the `shape_invariants` argument of tf.while_loop to specify a "
          "less-specific shape." % (input_t.name, input_t.shape, n_shape))
  else:
    raise TypeError("Type %s not supported" % type(merge_var))
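

# Illustrative sketch (hypothetical user code) of the failure mode the check
# above guards against: a loop variable whose shape changes across iterations
# needs an explicit, less specific entry in `shape_invariants`, e.g.
#
#   i0 = tf.constant(0)
#   m0 = tf.ones([2, 2])
#   _, m = tf.while_loop(
#       lambda i, m: i < 3,
#       lambda i, m: (i + 1, tf.concat([m, m], axis=0)),
#       loop_vars=[i0, m0],
#       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])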


def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
  """Add NextIteration and back edge from v to m."""
  if isinstance(m, ops.Tensor):
    v = ops.convert_to_tensor(v)
    v = _NextIteration(v)
    if enforce_shape_invariant:
      # Make sure the shapes of loop outputs are correct. We do this before
      # calling _update_input, which will raise a less-helpful error message if
      # the types don't match.
      # TODO(skyewm): call this for other cases below (needs testing)
      _EnforceShapeInvariant(m, v)
    m.op._update_input(1, v)  # pylint: disable=protected-access
  elif isinstance(m, composite_tensor.CompositeTensor):
    # pylint: disable=protected-access
    def update_component(m_component, v_component):
      m_component.op._update_input(1, v_component)

    if isinstance(m, ops.IndexedSlices):
      v = math_ops._as_indexed_slices(v, optimize=False)
    # pylint: enable=protected-access
    v = _NextIteration(v)
    return nest.map_structure(update_component, m, v, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(m))
  return v


@six.add_metaclass(abc.ABCMeta)
class ControlFlowContext(object):
  """The base class for control flow context.

  The usage pattern is a sequence of (Enter, Exit) followed by a final
  ExitResult.

  We maintain the following state for control flow contexts during graph
  construction:
    1. graph has _control_flow_context: the current context used to
       construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
    2. op has _control_flow_context: the context to which the op belongs.
       Set at the time the op is created. Immutable.
    3. A ControlFlowContext has _outer_context: the context in which this
       context is created. Set at the time a context is created. Immutable.
    4. A ControlFlowContext has _context_stack.
       Pushed and popped by ctxt.Enter() and ctxt.Exit()
  """

  def __init__(self, values_def=None, import_scope=None):
    self._nested_contexts = []
    self._outer_context = ops.get_default_graph()._get_control_flow_context()
    if self._outer_context:
      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
    self._context_stack = []
    if values_def:
      self._init_values_from_proto(values_def, import_scope=import_scope)
    else:
      # The names of tensors that have been already seen in this context.
      self._values = set()
      # The keys are the names of tensors referenced by but external to this
      # context. Each value is the Tensor that should be used by this context to
      # access the key value (e.g. a switch output guarding a cond input value).
      self._external_values = {}

  def _init_values_from_proto(self, values_def, import_scope=None):
    """Initializes values and external_values from `ValuesDef` protocol buffer.

    Args:
      values_def: `ValuesDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(values_def, control_flow_pb2.ValuesDef)
    self._values = set(
        ops.prepend_name_scope(value, import_scope)
        for value in values_def.values)
    g = ops.get_default_graph()
    self._external_values = {}
    for k, v in values_def.external_values.items():
      k = ops.prepend_name_scope(k, import_scope)
      self._external_values[k] = g.as_graph_element(
          ops.prepend_name_scope(v, import_scope))
    op_names = set([
        op.split(":")[0]
        for op in self._values - set(self._external_values.keys())
    ])
    for op in op_names:
      # pylint: disable=protected-access
      g.as_graph_element(op)._set_control_flow_context(self)
      # pylint: enable=protected-access

  @property
  def name(self):
    return self._name

  @property
  def outer_context(self):
    """Return the context containing this context."""
    return self._outer_context

  @property
  def grad_state(self):
    raise NotImplementedError("Abstract method")

  @property
  def back_prop(self):
    raise NotImplementedError("Abstract method")

  @abc.abstractmethod
  def to_control_flow_context_def(self, context_def, export_scope=None):
    """Serializes this into `context_def`.

    Args:
      context_def: a `ControlFlowContextDef` protocol buffer.
      export_scope: Optional `string`. Name scope to remove.
    """
    raise NotImplementedError("Abstract method")

  def _to_values_def(self, export_scope=None):
    """Converts the values to a `ValuesDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `ValuesDef` protocol buffer.
717 """ 718 values_def = control_flow_pb2.ValuesDef() 719 values_def.values.extend( 720 [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)]) 721 for k, v in self._external_values.items(): 722 k = ops.strip_name_scope(k, export_scope) 723 values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope) 724 return values_def 725 726 def AddName(self, name): 727 self._values.add(name) 728 729 # pylint: disable=protected-access 730 def Enter(self): 731 """Enter this control flow context.""" 732 graph = ops.get_default_graph() 733 self._context_stack.append(graph._get_control_flow_context()) 734 graph._set_control_flow_context(self) 735 736 def Exit(self): 737 """Exit this control flow context.""" 738 graph = ops.get_default_graph() 739 last_context = self._context_stack.pop() 740 graph._set_control_flow_context(last_context) 741 742 def EnterGradientColocation(self, op, gradient_uid): 743 """Start building a gradient colocated with an op.""" 744 if self._outer_context: 745 self._outer_context.EnterGradientColocation(op, gradient_uid) 746 747 def ExitGradientColocation(self, op, gradient_uid): 748 """Start building a gradient colocated with an op.""" 749 if self._outer_context: 750 self._outer_context.ExitGradientColocation(op, gradient_uid) 751 752 def ExitResult(self, result): 753 """Make a list of tensors available in the outer context.""" 754 if self._outer_context: 755 def fn(x): 756 self._outer_context.AddName(x.name) 757 return x 758 nest.map_structure(fn, result, expand_composites=True) 759 760 def GetWhileContext(self): 761 """Return the while context containing this context.""" 762 if self._outer_context: 763 return self._outer_context.GetWhileContext() 764 return None 765 766 def _RemoveExternalControlEdges(self, op): 767 """Remove any external control dependency on this op.""" 768 while_ctxt = self.GetWhileContext() 769 # A control input of `op` is internal if it is in the same while 770 # loop context as the enclosing while loop context of self. 771 if while_ctxt is None: 772 internal_control_inputs = op.control_inputs 773 else: 774 internal_control_inputs = [] 775 for x in op.control_inputs: 776 ctxt = util.GetOutputContext(x) 777 if ctxt is not None and ctxt.GetWhileContext() == while_ctxt: 778 internal_control_inputs.append(x) 779 external_control_inputs = [] 780 if len(internal_control_inputs) != len(op.control_inputs): 781 external_control_inputs = list( 782 set(op.control_inputs) - set(internal_control_inputs)) 783 op._remove_all_control_inputs() 784 op._add_control_inputs(internal_control_inputs) 785 return internal_control_inputs, external_control_inputs 786 787 # pylint: enable=protected-access 788 789 def AddInnerOp(self, op): 790 """Notifies a scope about an operator added to an inner scope.""" 791 if self._outer_context: 792 self._outer_context.AddInnerOp(op) 793 794 def GetControlPivot(self): 795 """Returns the pivot node for this context, or None.""" 796 return None 797 798 def IsWhileContext(self): 799 return False 800 801 def IsCondContext(self): 802 return False 803 804 def IsXLAContext(self): 805 return False 806 807 def __str__(self): 808 return self.name 809 810 811class CondContext(ControlFlowContext): 812 """The context for the conditional construct.""" 813 814 def __init__(self, 815 pred=None, 816 pivot=None, 817 branch=None, 818 name="cond_text", 819 context_def=None, 820 import_scope=None): 821 """Creates a `CondContext`. 822 823 Args: 824 pred: The `boolean` tensor for the conditional predicate. 
      pivot: The predicate tensor in this branch.
      branch: 0 or 1 representing this branch.
      name: Name of the `CondContext` python object.
      context_def: Optional `ContextDef` protocol buffer to initialize the
        `CondContext` object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
    """
    self._name = ops.get_default_graph().unique_name(name)

    if context_def:
      self._init_from_proto(context_def, import_scope=import_scope)
    else:
      # Initializes the default fields.
      ControlFlowContext.__init__(self)
      self._pred = pred  # The boolean tensor for the cond predicate
      self._pivot = pivot  # The predicate tensor in this branch
      self._branch = branch  # 0 or 1 representing this branch

      # Values considered to have been already seen in this context. pred is not
      # included in this context.
      self._values.add(pred.name)
      self._external_values[pred.name] = pred
      self._values.add(pivot.name)
      pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access

  def _init_from_proto(self, context_def, import_scope=None):
    """Creates a new `CondContext` from protocol buffer.

    Args:
      context_def: `CondContextDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(context_def, control_flow_pb2.CondContextDef)
    # Create from context_def.
    g = ops.get_default_graph()
    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
    self._pred = g.as_graph_element(
        ops.prepend_name_scope(context_def.pred_name, import_scope))
    self._pivot = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_name, import_scope))
    self._branch = context_def.branch
    super(CondContext, self).__init__(
        values_def=context_def.values_def, import_scope=import_scope)

  @property
  def pred(self):
    return self._pred

  @property
  def pivot(self):
    return self._pivot

  @property
  def branch(self):
    return self._branch

  @property
  def grad_state(self):
    if self.GetWhileContext():
      return self.GetWhileContext().grad_state
    return None

  @property
  def back_prop(self):
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False

  def GetControlPivot(self):
    return self._pivot

  def to_proto(self, export_scope=None):
    """Converts a `CondContext` to a `CondContextDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `CondContextDef` protocol buffer.
905 """ 906 if (export_scope is None or self.name.startswith(export_scope)): 907 context_def = control_flow_pb2.CondContextDef() 908 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 909 context_def.pred_name = ops.strip_name_scope(self._pred.name, 910 export_scope) 911 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 912 export_scope) 913 context_def.branch = self._branch 914 context_def.values_def.MergeFrom( 915 super(CondContext, self)._to_values_def(export_scope)) 916 for nested in self._nested_contexts: 917 nested_def = context_def.nested_contexts.add() 918 nested.to_control_flow_context_def(nested_def) 919 920 return context_def 921 else: 922 return None 923 924 @staticmethod 925 def from_proto(context_def, import_scope=None): 926 """Returns a `CondContext` object created from `context_def`.""" 927 ret = CondContext(context_def=context_def, import_scope=import_scope) 928 929 ret.Enter() 930 for nested_def in context_def.nested_contexts: 931 from_control_flow_context_def(nested_def, import_scope=import_scope) 932 ret.Exit() 933 return ret 934 935 def to_control_flow_context_def(self, context_def, export_scope=None): 936 context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 937 938 def AddValue(self, val): 939 """Add `val` to the current context and its outer context recursively.""" 940 if val.name in self._values: 941 # Use the real value if it comes from outer context. This is needed in 942 # particular for nested conds. 943 result = self._external_values.get(val.name) 944 result = val if result is None else result 945 else: 946 result = val 947 self._values.add(val.name) 948 if self._outer_context: 949 result = self._outer_context.AddValue(val) 950 self._values.add(result.name) 951 self._external_values[result.name] = result 952 with ops.control_dependencies(None): 953 result = _SwitchRefOrTensor(result, self._pred)[self._branch] 954 if self._outer_context: 955 self._outer_context.AddInnerOp(result.op) 956 957 result.op.graph.prevent_fetching(result.op) 958 # pylint: disable=protected-access 959 result.op._set_control_flow_context(self) 960 # pylint: enable=protected-access 961 962 # Mark Switch output as seen by this context and any outer contexts, 963 # just like what we do for normal op outputs in _AddOpInternal() below. 964 ctxt = self 965 while ctxt is not None: 966 # pylint: disable=protected-access 967 ctxt._values.add(result.name) 968 ctxt = ctxt._outer_context 969 # pylint: enable=protected-access 970 971 self._external_values[val.name] = result 972 return result 973 974 def AddOp(self, op): 975 self._AddOpInternal(op) 976 977 def _AddOpInternal(self, op): 978 """Add `op` to the current context.""" 979 if not op.inputs: 980 # If we're in a while loop, remove any control inputs from outside the 981 # loop. 982 self._RemoveExternalControlEdges(op) 983 984 if not any( 985 util.OpInContext(input_op, self) for input_op in op.control_inputs): 986 # pylint: disable=protected-access 987 op._add_control_input(self._pivot.op) 988 # pylint: enable=protected-access 989 else: 990 # Make each input to 'op' available in this CondContext. If an input is 991 # already part of this context there's nothing to do, but if it's 992 # external, AddValue() will handle adding the appropriate Switch node and 993 # other bookkeeping. 
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        if op.type == "Merge" and x.op.type == "NextIteration":
          # Edge case: if we're importing a while loop inside this CondContext,
          # AddValue() will not correctly handle the NextIteration inputs to
          # the Merge node. The problem is that the NextIteration should also
          # be part of this context, but if we're importing it won't have been
          # processed and added to the context yet, so AddValue() will try to
          # add a Switch which results in an invalid graph. Instead, we use the
          # NextIteration input as-is here, and it will eventually be added to
          # the context via AddOp().
          real_x = x
        else:
          real_x = self.AddValue(x)
        if real_x != x:
          # pylint: disable=protected-access
          op._update_input(index, real_x)
          # pylint: enable=protected-access
      # Remove any external control dependency on this op.
      self._RemoveExternalControlEdges(op)
      # pylint: disable=protected-access
      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
        op._add_control_input(self._pivot.op)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    ctxt = self
    while ctxt is not None:
      # pylint: disable=protected-access
      ctxt._values.update(output_names)
      ctxt = ctxt._outer_context
      # pylint: enable=protected-access

    if self._outer_context or not util.IsLoopExit(op):
      op.graph.prevent_fetching(op)

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def _ProcessOutputTensor(self, val):
    """Process an output tensor of a conditional branch."""
    real_val = val
    if val.name not in self._values:
      # Handle the special case of lambda: x
      self._values.add(val.name)
      if self._outer_context:
        real_val = self._outer_context.AddValue(val)
        self._values.add(real_val.name)
        self._external_values[real_val.name] = real_val
      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
      self._external_values[val.name] = real_val
    else:
      external_val = self._external_values.get(val.name)
      if external_val is not None:
        real_val = external_val
    return real_val

  def _BuildCondTensor(self, v):
    if isinstance(v, ops.Operation):
      # Use pivot as the proxy for this op.
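      # An Operation carries no value to merge, so the branch pivot is
      # returned with the Operation attached as a control dependency.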
      return with_dependencies([v], self._pivot)
    else:
      v = nest.map_structure(
          _convert_tensorarray_to_flow, v, expand_composites=True)
      return self._ProcessOutputTensor(ops.convert_to_tensor(v))

  def BuildCondBranch(self, fn):
    """Add the subgraph defined by fn() to the graph."""
    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    original_result = fn()
    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    if len(post_summaries) > len(pre_summaries):
      new_summaries = post_summaries[len(pre_summaries):]
      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
      summary_ref[:] = pre_summaries
      with ops.control_dependencies(new_summaries):
        if original_result is None:
          return no_op(), None
        elif not isinstance(original_result, ops.Operation):
          original_result = nest.map_structure(
              array_ops.identity, original_result, expand_composites=True)
    if original_result is None:
      return None, None

    result = nest.map_structure(
        self._BuildCondTensor, original_result, expand_composites=True)
    if not isinstance(result, (list, _basetuple)):
      result = [result]
    return original_result, result

  def IsCondContext(self):
    return True


def _UnpackIfSingleton(res):
  if isinstance(res, (list, _basetuple)) and len(res) == 1:
    return res[0]
  else:
    return res


# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
@tf_export(v1=["cond"])
@deprecation.deprecated_args(
    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
    "fn1", "fn2")
def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and the `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
  This behavior is disabled by passing `strict=True`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  # Always enable control flow v2 if building a function, regardless of toggle.
  if (util.EnableControlFlowV2(ops.get_default_graph()) and
      not context.executing_eagerly()):
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)

  # We needed to make true_fn/false_fn keyword arguments for
  # backwards-compatibility. This check exists so that we can convert back to
  # having them be positional arguments.
  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
  # `fn1` and `fn2` are deleted.
  if fn1 is not None:
    if true_fn is not None:
      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
    true_fn = fn1
  elif true_fn is None:
    raise TypeError("cond(): true_fn argument required")
  if fn2 is not None:
    if false_fn is not None:
      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
    false_fn = fn2
  elif false_fn is None:
    raise TypeError("cond(): false_fn argument required")

  if not callable(true_fn):
    raise TypeError("true_fn must be callable.")
  if not callable(false_fn):
    raise TypeError("false_fn must be callable.")

  with ops.name_scope(name, "cond", [pred]):
    if context.executing_eagerly():
      if pred:
        result = true_fn()
      else:
        result = false_fn()
      if not strict:
        result = _UnpackIfSingleton(result)
      return result

    # Add the Switch to the graph.
    if isinstance(pred, bool):
      raise TypeError("pred must not be a Python bool")
    p_2, p_1 = switch(pred, pred)
    pivot_1 = array_ops.identity(p_1, name="switch_t")
    pivot_2 = array_ops.identity(p_2, name="switch_f")
    pred = array_ops.identity(pred, name="pred_id")
    # Disable the fetching of tensors that are only on one branch of cond.
    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.
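    # Each branch is built under its own CondContext: external tensors used by
    # a branch are routed through Switch outputs gated on `pred`, and the
    # branch outputs are combined with Merge ops further below.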
    context_t = CondContext(pred, pivot_1, branch=1)
    try:
      context_t.Enter()
      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
      if orig_res_t is None:
        raise ValueError("true_fn must have a return value.")
      context_t.ExitResult(res_t)
    finally:
      context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = CondContext(pred, pivot_2, branch=0)
    try:
      context_f.Enter()
      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
      if orig_res_f is None:
        raise ValueError("false_fn must have a return value.")
      context_f.ExitResult(res_f)
    finally:
      context_f.Exit()

    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
    except (TypeError, ValueError):
      nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
      nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
      try:
        nest.assert_same_structure(orig_res_t, orig_res_f,
                                   expand_composites=True)
      except TypeError as e:
        raise TypeError(
            "Incompatible return types of true_fn and false_fn: {}".format(e))
      except ValueError as e:
        raise ValueError(
            "Incompatible return values of true_fn and false_fn: {}".format(e))

    # Add the final merge to the graph.
    if not res_t:
      raise ValueError("true_fn and false_fn must return at least one result.")

    res_t_flat = nest.flatten(res_t, expand_composites=True)
    res_f_flat = nest.flatten(res_f, expand_composites=True)

    for (x, y) in zip(res_t_flat, res_f_flat):
      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      if x.dtype.base_dtype != y.dtype.base_dtype:
        raise ValueError(
            "Outputs of true_fn and false_fn must have the same type: "
            "%s, %s" % (x.dtype.name, y.dtype.name))

    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = _convert_flows_to_tensorarrays(
        nest.flatten(orig_res_t, expand_composites=True), merges)

    # Only add non-nested conds to the collection. Any nested control flow will
    # be encapsulated in the root context.
    assert context_t.outer_context == context_f.outer_context
    if context_t.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    merges = nest.pack_sequence_as(
        structure=orig_res_t, flat_sequence=merges, expand_composites=True)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges


def _cast_indexed_slice_indices(a, b):
  """Cast IndexedSlice.indices from int32 to int64 where necessary.

  If `a` and `b` are both IndexedSlices, and their indices have different
  dtypes, then cast both their dtypes to `int64` (modifies `a` and `b`
  in-place). Otherwise, does nothing.

  Args:
    a: A value, which may be an IndexedSlices.
    b: A value, which may be an IndexedSlices.
1309 """ 1310 if (isinstance(a, ops.IndexedSlices) and isinstance(b, ops.IndexedSlices) 1311 and a.indices.dtype != b.indices.dtype): 1312 # pylint: disable=protected-access 1313 a._indices = math_ops.cast(a.indices, dtypes.int64) 1314 b._indices = math_ops.cast(b.indices, dtypes.int64) 1315 1316 1317# pylint: enable=g-doc-args 1318# pylint: enable=redefined-outer-name 1319 1320 1321@tf_export("cond", v1=[]) 1322def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None): 1323 """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. 1324 1325 `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and 1326 `false_fn` must have the same non-zero number and type of outputs. 1327 1328 **WARNING**: Any Tensors or Operations created outside of `true_fn` and 1329 `false_fn` will be executed regardless of which branch is selected at runtime. 1330 1331 Although this behavior is consistent with the dataflow model of TensorFlow, 1332 it has frequently surprised users who expected a lazier semantics. 1333 Consider the following simple program: 1334 1335 ```python 1336 z = tf.multiply(a, b) 1337 result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)) 1338 ``` 1339 1340 If `x < y`, the `tf.add` operation will be executed and `tf.square` 1341 operation will not be executed. Since `z` is needed for at least one 1342 branch of the `cond`, the `tf.multiply` operation is always executed, 1343 unconditionally. 1344 1345 Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the 1346 call to `cond`, and not at all during `Session.run()`). `cond` 1347 stitches together the graph fragments created during the `true_fn` and 1348 `false_fn` calls with some additional graph nodes to ensure that the right 1349 branch gets executed depending on the value of `pred`. 1350 1351 `tf.cond` supports nested structures as implemented in 1352 `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the 1353 same (possibly nested) value structure of lists, tuples, and/or named tuples. 1354 Singleton lists and tuples form the only exceptions to this: when returned by 1355 `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. 1356 1357 Note: It is illegal to "directly" use tensors created inside a cond branch 1358 outside it, e.g. by storing a reference to a branch tensor in the python 1359 state. If you need to use a tensor created in a branch function you should 1360 return it as an output of the branch function and use the output from 1361 `tf.cond` instead. 1362 1363 Args: 1364 pred: A scalar determining whether to return the result of `true_fn` or 1365 `false_fn`. 1366 true_fn: The callable to be performed if pred is true. 1367 false_fn: The callable to be performed if pred is false. 1368 name: Optional name prefix for the returned tensors. 1369 1370 Returns: 1371 Tensors returned by the call to either `true_fn` or `false_fn`. If the 1372 callables return a singleton list, the element is extracted from the list. 1373 1374 Raises: 1375 TypeError: if `true_fn` or `false_fn` is not callable. 1376 ValueError: if `true_fn` and `false_fn` do not return the same number of 1377 tensors, or return tensors of different types. 1378 1379 Example: 1380 1381 ```python 1382 x = tf.constant(2) 1383 y = tf.constant(5) 1384 def f1(): return tf.multiply(x, 17) 1385 def f2(): return tf.add(y, 23) 1386 r = tf.cond(tf.less(x, y), f1, f2) 1387 # r is set to f1(). 1388 # Operations in f2 (e.g., tf.add) are not executed. 
  ```

  """
  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)


def _resource_safe_shape(t):
  """Returns the shape of t or the variable it points to."""
  if t.dtype == dtypes.resource:
    while t.op.inputs:
      t = t.op.inputs[0]
    return tensor_shape.TensorShape(t.op.get_attr("shape"))
  return array_ops.shape_internal(t, optimize=False)


# TODO(yuanbyu): Consider having a unified notion of context for
# not only conditionals and loops but also control dependency and
# subgraphs.
class WhileContext(ControlFlowContext):
  """The context for the loop construct."""

  def __init__(self,
               maximum_iterations=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name="while_context",
               grad_state=None,
               context_def=None,
               import_scope=None):
    """Creates a `WhileContext`.

    Args:
      maximum_iterations: Optional upper bound on number of loop iterations.
      parallel_iterations: The number of iterations allowed to run in parallel.
      back_prop: Whether backprop is enabled for this while loop.
      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
      name: Optional name prefix for the returned tensors.
      grad_state: The gradient loop state.
      context_def: Optional `WhileContextDef` protocol buffer to initialize the
        `WhileContext` python object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
    """
    if context_def:
      self._init_from_proto(context_def, import_scope=import_scope)
    else:
      ControlFlowContext.__init__(self)
      self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
                           swap_memory, name)
    # The gradient loop state.
    self._grad_state = grad_state

  def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
                      swap_memory, name):
    """Creates a new `WhileContext` from arguments.

    Args:
      maximum_iterations: Optional upper bound on number of loop iterations.
      parallel_iterations: The number of iterations allowed to run in parallel.
      back_prop: Whether backprop is enabled for this while loop.
      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
      name: Optional name prefix for the returned tensors.

    Raises:
      ValueError: If `parallel_iterations` has invalid value.
    """
    if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
      raise ValueError("`parallel_iterations` must be a positive integer: "
                       "%s" % parallel_iterations)
    self._name = ops.get_default_graph().unique_name(name)
    self._maximum_iterations = maximum_iterations
    self._parallel_iterations = parallel_iterations
    self._back_prop = back_prop
    self._swap_memory = swap_memory
    # We use this node to control constants created by the pred lambda.
    self._pivot_for_pred = None
    # We use this node to control constants created by the body lambda.
    self._pivot_for_body = None
    # The boolean tensor for loop termination condition. Used in code
    # generation for gradient computation
    self._pivot = None
    # The list of exit tensors for loop variables.
    self._loop_exits = []
    # The list of enter tensors for loop variables.
    self._loop_enters = []
    self._graph = ops.get_default_graph()

  def _init_from_proto(self, context_def, import_scope=None):
    """Creates a new `WhileContext` from protocol buffer.

    Args:
      context_def: `WhileContextDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(context_def, control_flow_pb2.WhileContextDef)
    # Create from context_def.
    g = ops.get_default_graph()
    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
    if context_def.maximum_iterations_name:
      self._maximum_iterations = g.as_graph_element(
          ops.prepend_name_scope(context_def.maximum_iterations_name,
                                 import_scope))
    else:
      self._maximum_iterations = None
    self._parallel_iterations = context_def.parallel_iterations
    self._back_prop = context_def.back_prop
    self._swap_memory = context_def.swap_memory
    self._pivot_for_pred = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
    # We use this node to control constants created by the body lambda.
    self._pivot_for_body = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
    # The boolean tensor for loop termination condition. Used in code
    # generation for gradient computation.
    self._pivot = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_name, import_scope))
    # The list of exit tensors for loop variables.
    self._loop_exits = [
        g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
        for exit_name in context_def.loop_exit_names
    ]
    # The list of enter tensors for loop variables.
    self._loop_enters = [
        g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
        for enter_name in context_def.loop_enter_names
    ]
    super(WhileContext, self).__init__(
        values_def=context_def.values_def, import_scope=import_scope)

    # import_scope causes self.name to be different from the original serialized
    # context's name. Rewrite "frame_name" attrs with the new name.
1521 if import_scope: 1522 for tensor_name in self._values: 1523 op = g.as_graph_element(tensor_name).op 1524 if util.IsLoopEnter(op): 1525 # pylint: disable=protected-access 1526 op._set_attr("frame_name", 1527 attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) 1528 # pylint: enable=protected-access 1529 self._graph = ops.get_default_graph() 1530 1531 @property 1532 def maximum_iterations(self): 1533 """The maximum number of iterations that will be executed.""" 1534 return self._maximum_iterations 1535 1536 @property 1537 def parallel_iterations(self): 1538 """The number of iterations allowed to run in parallel.""" 1539 return self._parallel_iterations 1540 1541 @property 1542 def back_prop(self): 1543 """True iff backprop is enabled for this while loop.""" 1544 return self._back_prop 1545 1546 @property 1547 def swap_memory(self): 1548 """True iff GPU-CPU memory swap is enabled for this while loop.""" 1549 return self._swap_memory 1550 1551 @property 1552 def pivot(self): 1553 """The boolean tensor representing the loop termination condition.""" 1554 return self._pivot 1555 1556 @property 1557 def loop_enters(self): 1558 """The list of enter tensors for loop variables.""" 1559 return self._loop_enters 1560 1561 @property 1562 def loop_exits(self): 1563 """The list of exit tensors for loop variables.""" 1564 return self._loop_exits 1565 1566 @property 1567 def grad_state(self): 1568 """The gradient loop state.""" 1569 return self._grad_state 1570 1571 def to_proto(self, export_scope=None): 1572 """Converts a `WhileContext` to a `WhileContextDef` protocol buffer. 1573 1574 Args: 1575 export_scope: Optional `string`. Name scope to remove. 1576 1577 Returns: 1578 A `WhileContextDef` protocol buffer. 1579 """ 1580 if (export_scope is None or self.name.startswith(export_scope)): 1581 context_def = control_flow_pb2.WhileContextDef() 1582 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 1583 context_def.parallel_iterations = self._parallel_iterations 1584 if self._maximum_iterations is not None: 1585 context_def.maximum_iterations_name = ops.strip_name_scope( 1586 self._maximum_iterations.name, export_scope) 1587 context_def.back_prop = self._back_prop 1588 context_def.swap_memory = self._swap_memory 1589 context_def.pivot_for_pred_name = ops.strip_name_scope( 1590 self._pivot_for_pred.name, export_scope) 1591 context_def.pivot_for_body_name = ops.strip_name_scope( 1592 self._pivot_for_body.name, export_scope) 1593 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 1594 export_scope) 1595 context_def.loop_exit_names.extend([ 1596 ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits 1597 ]) 1598 context_def.loop_enter_names.extend([ 1599 ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters 1600 ]) 1601 context_def.values_def.MergeFrom( 1602 super(WhileContext, self)._to_values_def(export_scope=export_scope)) 1603 for nested in self._nested_contexts: 1604 nested_def = context_def.nested_contexts.add() 1605 nested.to_control_flow_context_def(nested_def) 1606 1607 return context_def 1608 else: 1609 return None 1610 1611 def to_control_flow_context_def(self, context_def, export_scope=None): 1612 context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 1613 1614 @staticmethod 1615 def from_proto(context_def, import_scope=None): 1616 """Returns a `WhileContext` object created from `context_def`. 1617 1618 Args: 1619 context_def: A `WhileContextDef` protocol buffer. 1620 import_scope: Optional `string`. 
Name scope to add. 1621 1622 Returns: 1623 A `WhileContext` Python object. 1624 """ 1625 ret = WhileContext(context_def=context_def, import_scope=import_scope) 1626 ret.Enter() 1627 for nested_def in context_def.nested_contexts: 1628 from_control_flow_context_def(nested_def, import_scope=import_scope) 1629 ret.Exit() 1630 return ret 1631 1632 def GetWhileContext(self): 1633 return self 1634 1635 def GetControlPivot(self): 1636 if self._pivot_for_body is not None: 1637 return self._pivot_for_body 1638 return self._pivot_for_pred 1639 1640 def AddValue(self, val): 1641 """Add `val` to the current context and its outer context recursively.""" 1642 result = val 1643 new_value = val.name not in self._values 1644 # Don't treat ops in this context as new values. Usually all known values 1645 # are in self._values, except when we're importing a while loop inside this 1646 # WhileContext. Since there's a cycle in this case, `val` may be part of the 1647 # imported while loop but not yet processed by this context and added to 1648 # self._values in _AddOpInternal. We only want to process external input 1649 # tensors to the while loop here. 1650 new_value &= val.op._control_flow_context is not self # pylint: disable=protected-access 1651 if new_value: 1652 self._values.add(val.name) 1653 1654 # If we are in a grad context and val is from its forward context, 1655 # use GetRealValue(), which adds the logic to save the history of 1656 # val in forward. 1657 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 1658 if grad_ctxt: 1659 grad_ctxt = grad_ctxt.GetWhileContext() 1660 if grad_ctxt.grad_state: 1661 forward_ctxt = util.GetWhileContext(val.op) 1662 if util.IsLoopExit(val.op): 1663 forward_ctxt = forward_ctxt.outer_context 1664 if forward_ctxt: 1665 forward_ctxt = forward_ctxt.GetWhileContext() 1666 if forward_ctxt == grad_ctxt.grad_state.forward_context: 1667 real_val = grad_ctxt.grad_state.GetRealValue(val) 1668 self._external_values[val.name] = real_val 1669 return real_val 1670 1671 if self._outer_context is not None: 1672 result = self._outer_context.AddValue(val) 1673 # Create an Enter to make `result` known to this loop context. 1674 with ops.control_dependencies(None): 1675 enter = _Enter( 1676 result, 1677 self._name, 1678 is_constant=True, 1679 parallel_iterations=self._parallel_iterations) 1680 enter.graph.prevent_feeding(enter) 1681 if self._outer_context: 1682 self._outer_context.AddInnerOp(enter.op) 1683 # Fix the control inputs and control flow context of these enter ops. 1684 self._FixControlInputsAndContext([enter]) 1685 1686 # Add `enter` in this context. 1687 self._values.add(enter.name) 1688 self._external_values[val.name] = enter 1689 result = enter 1690 else: 1691 actual_val = self._external_values.get(val.name) 1692 if actual_val is not None: 1693 result = actual_val 1694 return result 1695 1696 def AddOp(self, op): 1697 """Add `op` to the current context.""" 1698 # For a reduction op, if op is in a grad context and its input is from 1699 # its forward context, moving op to the forward context means we would 1700 # store the tensor after the reduction as opposed to the tensor before 1701 # reduction, and therefore could significantly reduce memory consumption. 1702 # For now, we do this only for a few ops. 
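    # (The ops currently treated this way are Shape, Size and Rank; see the
    # op.type check below.)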
1703 # 1704 # If in XLA context, do not move constant ops to forward pass as pushing to 1705 # and popping from a stack removes the constant property of an op and breaks 1706 # XLA compilation, which requires certain inputs to be constant for certain 1707 # ops. 1708 if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}: 1709 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 1710 if grad_ctxt: 1711 grad_ctxt = grad_ctxt.GetWhileContext() 1712 if grad_ctxt.grad_state: 1713 op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op) 1714 if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context: 1715 op_input_ctxt = op.inputs[0].op._get_control_flow_context() 1716 op._set_control_flow_context(op_input_ctxt) 1717 op_input_ctxt._AddOpInternal(op) 1718 return 1719 self._AddOpInternal(op) 1720 1721 def _AddOpInternal(self, op): 1722 """Add `op` to the current context. 1723 1724 We move any external control dependencies of the op to the loop pivot, to 1725 ensure they get executed. 1726 """ 1727 if not op.inputs: 1728 # Remove any external control dependency on this op 1729 control_inputs, external_inputs = self._RemoveExternalControlEdges(op) 1730 # Add a control edge from the control pivot to this op. 1731 if not control_inputs: 1732 # pylint: disable=protected-access 1733 op._add_control_input(self.GetControlPivot().op) 1734 # pylint: enable=protected-access 1735 for x in op.outputs: 1736 self._values.add(x.name) 1737 else: 1738 for index in range(len(op.inputs)): 1739 x = op.inputs[index] 1740 real_x = self.AddValue(x) 1741 if real_x != x: 1742 op._update_input(index, real_x) # pylint: disable=protected-access 1743 # Remove any external control dependency on this op. 1744 _, external_inputs = self._RemoveExternalControlEdges(op) 1745 # Add a control dependency to prevent loop invariants from 1746 # enabling ops that should not be executed. 1747 self._MaybeAddControlDependency(op) 1748 for x in op.outputs: 1749 self._values.add(x.name) 1750 if external_inputs: 1751 # Use an identity to pull control inputs as data inputs. Note that we 1752 # ignore ops which don't have outputs. TODO(apassos): fix that 1753 with ops.control_dependencies(None): 1754 self.Enter() 1755 external_inputs = [ 1756 array_ops.identity(x.outputs[0]).op 1757 for x in external_inputs 1758 if x.outputs 1759 ] 1760 self.Exit() 1761 op._add_control_inputs(external_inputs) # pylint: disable=protected-access 1762 if self._outer_context or not util.IsLoopExit(op): 1763 op.graph.prevent_fetching(op) 1764 for x in op.outputs: 1765 op.graph.prevent_feeding(x) 1766 1767 if self._outer_context: 1768 self._outer_context.AddInnerOp(op) 1769 1770 def _MaybeAddControlDependency(self, op): 1771 """Add a control input to the op if it only depends on loop invariants.""" 1772 1773 def _IsOpFree(op): 1774 """Determines if `op` needs a control dependency.""" 1775 if op.control_inputs: 1776 return False 1777 # pylint: disable=protected-access 1778 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 1779 return True 1780 # pylint: enable=protected-access 1781 for x in op.inputs: 1782 if not util.IsLoopConstantEnter(x.op): 1783 return False 1784 return True 1785 1786 if _IsOpFree(op): 1787 # pylint: disable=protected-access 1788 op._add_control_input(self.GetControlPivot().op) 1789 # pylint: enable=protected-access 1790 1791 def AddForwardLoopCounter(self, outer_grad_state): 1792 """Adds a loop that counts the number of iterations. 
1793 1794 This is added to the forward loop at the time when we start to 1795 create the loop for backprop gradient computation. Called in 1796 the outer context of this forward context. 1797 1798 The pseudocode is: 1799 `n = 0; while (_pivot) { n++; }` 1800 1801 Note that a control dependency is added to `n` to ensure the correct 1802 execution order of stack push ops. 1803 1804 Args: 1805 outer_grad_state: The outer grad state. None if not nested. 1806 1807 Returns: 1808 The number of iterations taken by the forward loop and the loop index. 1809 """ 1810 n = constant_op.constant(0, name="f_count") 1811 if outer_grad_state is not None: 1812 # Force the stack pushes of i-th execution of an inner loop to be ordered 1813 # before the pushes of (i+1)-th execution of the same inner loop. 1814 outer_add_op = outer_grad_state.forward_index.op.inputs[0].op 1815 n.op._add_control_input(outer_add_op) # pylint: disable=protected-access 1816 1817 self.Enter() 1818 self.AddName(n.name) 1819 enter_n = _Enter( 1820 n, 1821 self._name, 1822 is_constant=False, 1823 parallel_iterations=self._parallel_iterations, 1824 name="f_count") 1825 self.loop_enters.append(enter_n) 1826 1827 merge_n = merge([enter_n, enter_n])[0] 1828 switch_n = switch(merge_n, self._pivot) 1829 1830 index = math_ops.add(switch_n[1], 1) 1831 next_n = _NextIteration(index) 1832 merge_n.op._update_input(1, next_n) 1833 1834 total_iterations = exit(switch_n[0], name="f_count") 1835 self.loop_exits.append(total_iterations) 1836 self.ExitResult([total_iterations]) 1837 self.Exit() 1838 return total_iterations, next_n 1839 1840 def AddBackpropLoopCounter(self, count, outer_grad_state): 1841 """Add the backprop loop that controls the iterations. 1842 1843 This is added to the backprop loop. It is used to control the loop 1844 termination of the backprop loop. Called in the outer context of 1845 this grad context. 1846 1847 The pseudocode is: 1848 `n = count; while (n >= 1) { n--; }` 1849 1850 Note that a control dependency is added to `final_zero` to ensure the 1851 correct execution order of stack pop ops. 1852 1853 Args: 1854 count: The number of iterations for backprop. 1855 outer_grad_state: The outer grad state. None if not nested. 1856 1857 Returns: 1858 The loop index. 1859 """ 1860 in_separate_functions = count.graph is not ops.get_default_graph() 1861 if in_separate_functions: 1862 # Brings the count into this graph 1863 count = array_ops.identity(count) 1864 else: 1865 # TODO(apassos) XLA expects this constant to be created outside the loop, 1866 # so doing that for now. 
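      # `one` is reused below both for the loop predicate (n >= 1) and for the
      # per-iteration decrement.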
1867 one = constant_op.constant(1, name="b_count") 1868 1869 self.Enter() 1870 self.AddName(count.name) 1871 enter_count = _Enter( 1872 count, 1873 self._name, 1874 is_constant=False, 1875 parallel_iterations=self._parallel_iterations, 1876 name="b_count") 1877 self.loop_enters.append(enter_count) 1878 1879 merge_count = merge([enter_count, enter_count])[0] 1880 self._pivot_for_pred = merge_count 1881 1882 if in_separate_functions: 1883 one = constant_op.constant(1, name="b_count") 1884 pred = math_ops.greater_equal(merge_count, one) 1885 self._pivot = loop_cond(pred, name="b_count") 1886 switch_count = switch(merge_count, self._pivot) 1887 1888 index = math_ops.subtract(switch_count[1], one) 1889 self._pivot_for_body = index 1890 next_count = _NextIteration(index) 1891 merge_count.op._update_input(1, next_count) 1892 1893 final_zero = exit(switch_count[0], name="b_count") 1894 self.loop_exits.append(final_zero) 1895 if outer_grad_state is not None: 1896 # Force the stack pops of i-th execution of an inner loop to be ordered 1897 # before the pops of (i+1)-th execution of the same inner loop. 1898 # pylint: disable=protected-access 1899 outer_grad_state.grad_sync._add_control_input(final_zero.op) 1900 # pylint: enable=protected-access 1901 1902 self.ExitResult([final_zero]) 1903 self.Exit() 1904 return next_count 1905 1906 def AddBackpropAccumulator(self, op, grad): 1907 """Add an accumulation loop for every loop invariant. 1908 1909 This is added to the backprop loop. It is used to accumulate partial 1910 gradients within each loop iteration. Called when in the gradient while 1911 context. 1912 1913 The pseudocode is: 1914 ``` 1915 acc = 0.0; 1916 while (_pivot) { 1917 acc += grad; 1918 } 1919 ``` 1920 1921 Args: 1922 op: The Enter op for a loop invariant. 1923 grad: The partial gradient of an iteration for a loop invariant. 1924 1925 Returns: 1926 The gradient for a loop invariant. 1927 """ 1928 self.Exit() 1929 # Create a zeros tensor with the right shape for acc. If we don't 1930 # know the full shape statically, we will have to get the shape 1931 # dynamically from the forward inference. Getting the shape right 1932 # for the zeros is only needed for the base case when the loop exits 1933 # without running any iterations. 1934 shape = grad.get_shape() 1935 if shape.is_fully_defined(): 1936 if self.outer_context: 1937 self.outer_context.Enter() 1938 acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc") 1939 if self.outer_context: 1940 self.outer_context.Exit() 1941 else: 1942 value = op.inputs[0] 1943 if (isinstance(self.outer_context, WhileContext) and 1944 self.outer_context.grad_state is not None): 1945 # We are in a nested while loop. 
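        # The dynamic shape of the invariant's value must be captured in the
        # forward loop (via a forward accumulator, i.e. a stack push) so this
        # backprop loop can retrieve it and build zeros of the right shape.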
1946 forward_ctxt = self.grad_state.forward_context 1947 forward_ctxt.outer_context.Enter() 1948 zeros_shape = array_ops.shape_internal(value, optimize=False) 1949 forward_ctxt.outer_context.Exit() 1950 outer_grad_state = self.grad_state.outer_grad_state 1951 history_zeros_shape = outer_grad_state.AddForwardAccumulator( 1952 zeros_shape) 1953 self.outer_context.Enter() 1954 real_shape = outer_grad_state.AddBackpropAccumulatedValue( 1955 history_zeros_shape, zeros_shape) 1956 acc = array_ops.zeros(real_shape, grad.dtype) 1957 self.outer_context.Exit() 1958 else: 1959 if self.outer_context: 1960 self.outer_context.Enter() 1961 zeros_shape = array_ops.shape_internal(value, optimize=False) 1962 acc = array_ops.zeros(zeros_shape, grad.dtype) 1963 if self.outer_context: 1964 self.outer_context.Exit() 1965 1966 self.Enter() 1967 self.AddName(acc.name) 1968 enter_acc = _Enter( 1969 acc, 1970 self._name, 1971 is_constant=False, 1972 parallel_iterations=self._parallel_iterations, 1973 name="b_acc") 1974 self.loop_enters.append(enter_acc) 1975 1976 merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0] 1977 switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot) 1978 1979 add_acc = math_ops.add(switch_acc_true, grad) 1980 next_acc = _NextIteration(add_acc) 1981 merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access 1982 1983 result_acc = exit(switch_acc_false, name="b_acc") 1984 self.loop_exits.append(result_acc) 1985 self.ExitResult([result_acc]) 1986 return result_acc 1987 1988 def AddBackpropIndexedSlicesAccumulator(self, op, grad): 1989 """This is used for accumulating gradients that are IndexedSlices. 1990 1991 This is essentially the equivalent of AddBackpropAccumulator but optimized 1992 for things like updating embeddings from within a while loop. 1993 1994 Args: 1995 op: The Enter op for a loop invariant. 1996 grad: The partial gradients represented as an IndexedSlices. 1997 1998 Returns: 1999 The accumulated IndexedSlices gradient of the loop invariant. 
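
    Roughly, the accumulation built below is (a pseudocode sketch, in the same
    style as `AddBackpropAccumulator` above):
    ```
    acc = (indices=[0], values=zeros_like(grad.values[0:1]), shape=zeros);
    while (_pivot) {
      acc.indices = concat([acc.indices, grad.indices]);
      acc.values = concat([acc.values, grad.values]);
      # Only when grad.dense_shape is tracked:
      acc.shape = maximum(acc.shape, grad.dense_shape);
    }
    ```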
2000 """ 2001 values = grad.values 2002 indices = grad.indices 2003 dense_shape = grad.dense_shape 2004 2005 self.Exit() 2006 if self.outer_context: 2007 self.outer_context.Enter() 2008 if values.get_shape().is_fully_defined(): 2009 values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] + 2010 values.get_shape().dims[1:]) 2011 if self.outer_context: 2012 self.outer_context.Enter() 2013 values_acc = constant_op.constant( 2014 0, values.dtype, shape=values_shape, name="b_acc") 2015 if self.outer_context: 2016 self.outer_context.Exit() 2017 else: 2018 values_shape = _resource_safe_shape(op.inputs[0])[1:] 2019 values_shape = array_ops.concat([[1], values_shape], 0) 2020 values_acc = array_ops.zeros(values_shape, dtype=values.dtype) 2021 indices_acc = constant_op.constant([0], indices.dtype) 2022 shape_acc = None 2023 if dense_shape is not None: 2024 if dense_shape.get_shape().is_fully_defined(): 2025 if self.outer_context: 2026 self.outer_context.Enter() 2027 shape_acc = constant_op.constant( 2028 0, dense_shape.dtype, shape=dense_shape.get_shape()) 2029 if self.outer_context: 2030 self.outer_context.Exit() 2031 else: 2032 shape_acc = array_ops.zeros_like( 2033 array_ops.shape_internal( 2034 op.inputs[0], optimize=False, out_type=dense_shape.dtype), 2035 optimize=False) 2036 2037 if self.outer_context: 2038 self.outer_context.Exit() 2039 2040 self.Enter() 2041 self.AddName(values_acc.name) 2042 self.AddName(indices_acc.name) 2043 init_acc = [indices_acc, values_acc] 2044 if shape_acc is not None: 2045 self.AddName(shape_acc.name) 2046 init_acc.append(shape_acc) 2047 2048 # Set use_input_shape=False since the accumulator tensors will grow in 2049 # size. If use_input_shape=True, the _update_input call below will result in 2050 # incompatible shapes. 2051 enter_acc = [ 2052 _Enter( 2053 x, 2054 self._name, 2055 is_constant=False, 2056 parallel_iterations=self._parallel_iterations, 2057 use_input_shape=False, 2058 name="b_acc") for x in init_acc 2059 ] 2060 # Manually set appropriate partial shapes. 2061 enter_acc[0].set_shape([None]) 2062 if values_acc.shape.dims is not None: 2063 enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:]) 2064 self.loop_enters.extend(enter_acc) 2065 2066 merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc] 2067 switch_acc = [switch(x, self._pivot) for x in merge_acc] 2068 2069 # The actual accumulation. 
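    # Indices and values are accumulated by concatenating the incoming grad
    # onto the running accumulators along axis 0; the dense shape (if tracked)
    # is accumulated with an element-wise maximum.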
2070 acc_indexed_slices = [ 2071 array_ops.concat([xa[1], xv], 0) 2072 for xa, xv in zip(switch_acc[:2], [indices, values]) 2073 ] 2074 if shape_acc is not None: 2075 # For the shape we just keep the maximum 2076 acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1])) 2077 2078 next_acc = [_NextIteration(x) for x in acc_indexed_slices] 2079 for xm, xn in zip(merge_acc, next_acc): 2080 xm.op._update_input(1, xn) # pylint: disable=protected-access 2081 2082 exit_acc = [exit(x[0], name="b_acc") for x in switch_acc] 2083 self.loop_exits.extend(exit_acc) 2084 2085 self.ExitResult(exit_acc) 2086 return ops.IndexedSlices( 2087 indices=exit_acc[0], 2088 values=exit_acc[1], 2089 dense_shape=exit_acc[2] if shape_acc is not None else None) 2090 2091 def _InitializeValues(self, values): 2092 """Makes the values known to this context.""" 2093 self._values = set() 2094 for x in values: 2095 if isinstance(x, ops.Tensor): 2096 self._values.add(x.name) 2097 else: 2098 raise TypeError("Type %s not supported" % type(x)) 2099 2100 def _BuildLoop(self, pred, body, original_loop_vars, loop_vars, 2101 shape_invariants): 2102 """Core: Add the loop termination condition and body to the graph.""" 2103 flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True) 2104 2105 # Let the context know the loop variables so the loop variables 2106 # would be added in the outer contexts properly. 2107 self._InitializeValues(loop_vars) 2108 real_vars = loop_vars 2109 if self._outer_context: 2110 real_vars = [self._outer_context.AddValue(x) for x in loop_vars] 2111 with ops.control_dependencies(None): 2112 enter_vars = [ 2113 _Enter( 2114 x, 2115 self._name, 2116 is_constant=False, 2117 parallel_iterations=self._parallel_iterations, 2118 use_input_shape=(shape_invariants is None)) for x in real_vars 2119 ] 2120 for x in enter_vars: 2121 x.graph.prevent_feeding(x) 2122 if self._outer_context: 2123 self._outer_context.AddInnerOp(x.op) 2124 2125 # Finds the closest enclosing non-None control pivot. 2126 outer_context = self._outer_context 2127 control_pivot = None 2128 while outer_context is not None and control_pivot is None: 2129 control_pivot = outer_context.GetControlPivot() 2130 # pylint: disable=protected-access 2131 outer_context = outer_context._outer_context 2132 # pylint: enable=protected-access 2133 2134 if control_pivot is not None: 2135 for var in enter_vars: 2136 if util.IsLoopConstantEnter(var.op.inputs[0].op): 2137 # pylint: disable=protected-access 2138 var.op._add_control_input(control_pivot.op) 2139 # pylint: enable=protected-access 2140 _SetShapeInvariants(real_vars, enter_vars, shape_invariants) 2141 2142 # Fix the control inputs and control flow context of these enter ops. 2143 self._FixControlInputsAndContext(enter_vars) 2144 self._InitializeValues(enter_vars) 2145 self._loop_enters = enter_vars 2146 2147 merge_vars = [merge([x, x])[0] for x in enter_vars] 2148 self._pivot_for_pred = merge_vars[0] 2149 2150 # Build the graph for pred. 2151 merge_vars_with_tensor_arrays = ( 2152 _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars)) 2153 packed_vars = nest.pack_sequence_as( 2154 structure=original_loop_vars, 2155 flat_sequence=merge_vars_with_tensor_arrays, 2156 expand_composites=True) 2157 c = ops.convert_to_tensor(pred(*packed_vars)) 2158 self._pivot = loop_cond(c, name="LoopCond") 2159 switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars] 2160 2161 # Build the graph for body. 
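    # Each Switch yields a (false, true) output pair: the true outputs feed the
    # loop body built below, the false outputs feed the Exit ops.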
2162 vars_for_body = [_Identity(x[1]) for x in switch_vars] 2163 self._pivot_for_body = vars_for_body[0] 2164 # Convert TensorArray flow variables inside the context back into 2165 # their associated TensorArrays for calling the body. 2166 vars_for_body_with_tensor_arrays = ( 2167 _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body)) 2168 packed_vars_for_body = nest.pack_sequence_as( 2169 structure=original_loop_vars, 2170 flat_sequence=vars_for_body_with_tensor_arrays, 2171 expand_composites=True) 2172 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2173 body_result = body(*packed_vars_for_body) 2174 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2175 if not nest.is_sequence_or_composite(body_result): 2176 body_result = [body_result] 2177 if len(post_summaries) > len(pre_summaries): 2178 new_summaries = post_summaries[len(pre_summaries):] 2179 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2180 summary_ref[:] = pre_summaries 2181 with ops.control_dependencies(new_summaries): 2182 2183 def map_fn(x): 2184 # TODO(apassos) figure out how to trigger with tensor arrays as well 2185 if isinstance(x, tensor_array_ops.TensorArray): 2186 return x 2187 return array_ops.identity(x) 2188 2189 body_result = nest.map_structure( 2190 map_fn, body_result, expand_composites=True) 2191 2192 # Compare the structure types of input and output of body. 2193 # For backwards compatibility, the first layer is forced to a list 2194 # during this comparison, because inputs are typically lists and 2195 # outputs of the body are typically tuples. 2196 nest.assert_same_structure( 2197 list(packed_vars_for_body), list(body_result), expand_composites=True) 2198 2199 # Store body_result to keep track of TensorArrays returned by body 2200 original_body_result = body_result 2201 # Convert TensorArrays returned by body into their flow variables 2202 result = nest.map_structure( 2203 _convert_tensorarray_to_flow, 2204 nest.flatten(body_result, expand_composites=True), 2205 expand_composites=True) 2206 result = ops.convert_n_to_tensor_or_composite(result) 2207 2208 # Add NextIteration and the back edges to complete the loop. 2209 if len(merge_vars) != len(result): 2210 raise ValueError("Number of inputs and outputs of body must match " 2211 "loop_vars: %d, %d" % (len(merge_vars), len(result))) 2212 next_vars = [] 2213 for m, v in zip(merge_vars, result): 2214 next_vars.append(_AddNextAndBackEdge(m, v)) 2215 2216 # Add the exit ops. 2217 exit_vars = [exit(x[0]) for x in switch_vars] 2218 self._loop_exits = exit_vars 2219 2220 # Exit the loop. 
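    # Make the exit tensors available to the outer context so they can be
    # consumed as ordinary values outside this while loop.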
2221 self.ExitResult(exit_vars) 2222 2223 return original_body_result, exit_vars 2224 2225 def BuildLoop(self, pred, body, loop_vars, shape_invariants, 2226 return_same_structure): 2227 """Add the loop termination condition and body to the graph.""" 2228 2229 # Keep original_loop_vars to identify which are TensorArrays 2230 original_loop_vars = loop_vars 2231 # Convert TensorArrays to their flow variables 2232 loop_vars = nest.map_structure( 2233 _convert_tensorarray_to_flow, 2234 nest.flatten(loop_vars, expand_composites=False), 2235 expand_composites=True) 2236 loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars) 2237 if shape_invariants is None: 2238 shape_invariants = nest.map_structure( 2239 _get_shape_invariant, loop_vars, expand_composites=False) 2240 loop_vars = nest.flatten(loop_vars, expand_composites=True) 2241 try: 2242 self.Enter() 2243 # _BuildLoop calls _update_input in several places. _mutation_lock() 2244 # ensures a Session.run call cannot occur between creating and mutating 2245 # new ops. 2246 with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access 2247 original_body_result, exit_vars = self._BuildLoop( 2248 pred, body, original_loop_vars, loop_vars, shape_invariants) 2249 finally: 2250 self.Exit() 2251 2252 flat_result = nest.flatten(original_body_result, expand_composites=True) 2253 # Convert TensorArray flow variables outside the context back into 2254 # their associated TensorArrays for returning to caller. 2255 exit_vars_with_tensor_arrays = ( 2256 _convert_flows_to_tensorarrays(flat_result, exit_vars)) 2257 packed_exit_vars = nest.pack_sequence_as( 2258 structure=original_body_result, 2259 flat_sequence=exit_vars_with_tensor_arrays, 2260 expand_composites=True) 2261 2262 if return_same_structure: 2263 return packed_exit_vars 2264 else: 2265 return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars 2266 2267 def _FixControlInputsAndContext(self, enters): 2268 graph = ops.get_default_graph() 2269 # pylint: disable=protected-access 2270 for e in enters: 2271 if isinstance(e, ops.Tensor): 2272 xs = [e] 2273 else: 2274 raise TypeError("Type %s not supported" % type(e)) 2275 for x in xs: 2276 inp_op = x.op.inputs[0].op 2277 control_inputs = graph._control_dependencies_for_inputs([inp_op]) 2278 outer_control_inputs = [] 2279 for op in control_inputs: 2280 # We need to keep control inputs that are in any ancestor 2281 # ControlFlowContext, and within outer WhileContext. 2282 keep_as_control_input = True 2283 op_ctxt = util.GetOutputContext(op) 2284 outer_ctxt = self.outer_context 2285 outer_while_context = (None if outer_ctxt is None else 2286 outer_ctxt.GetWhileContext()) 2287 while outer_ctxt != op_ctxt: 2288 if outer_ctxt is None or outer_ctxt == outer_while_context: 2289 keep_as_control_input = False 2290 break 2291 outer_ctxt = outer_ctxt.outer_context 2292 if keep_as_control_input: 2293 outer_control_inputs.append(op) 2294 x.op._set_control_flow_context(self) 2295 x.op._add_control_inputs(outer_control_inputs) 2296 graph._record_op_seen_by_control_dependencies(x.op) 2297 # pylint: enable=protected-access 2298 2299 def IsWhileContext(self): 2300 return True 2301 2302 2303# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature". 2304# pylint: disable=redefined-outer-name 2305@tf_export("while_loop", v1=[]) 2306@deprecation.deprecated_arg_values( 2307 None, 2308 """back_prop=False is deprecated. Consider using tf.stop_gradient instead. 
2309Instead of: 2310results = tf.while_loop(c, b, vars, back_prop=False) 2311Use: 2312results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""", 2313 warn_once=True, 2314 back_prop=False) 2315def while_loop_v2(cond, 2316 body, 2317 loop_vars, 2318 shape_invariants=None, 2319 parallel_iterations=10, 2320 back_prop=True, 2321 swap_memory=False, 2322 maximum_iterations=None, 2323 name=None): 2324 """Repeat `body` while the condition `cond` is true. 2325 2326 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 2327 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 2328 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 2329 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 2330 `cond` and `body`. `cond` and `body` both take as many arguments as there are 2331 `loop_vars`. 2332 2333 In addition to regular Tensors or IndexedSlices, the body may accept and 2334 return TensorArray objects. The flows of the TensorArray objects will 2335 be appropriately forwarded between loops and during gradient calculations. 2336 2337 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 2338 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 2339 stitches together the graph fragments created during the `cond` and `body` 2340 calls with some additional graph nodes to create the graph flow that 2341 repeats `body` until `cond` returns false. 2342 2343 For correctness, `tf.while_loop()` strictly enforces shape invariants for 2344 the loop variables. A shape invariant is a (possibly partial) shape that 2345 is unchanged across the iterations of the loop. An error will be raised 2346 if the shape of a loop variable after an iteration is determined to be more 2347 general than or incompatible with its shape invariant. For example, a shape 2348 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 2349 compatible with [11, 17]. By default (if the argument `shape_invariants` is 2350 not specified), it is assumed that the initial shape of each tensor in 2351 `loop_vars` is the same in every iteration. The `shape_invariants` argument 2352 allows the caller to specify a less specific shape invariant for each loop 2353 variable, which is needed if the shape varies between iterations. The 2354 `tf.Tensor.set_shape` 2355 function may also be used in the `body` function to indicate that 2356 the output loop variable has a particular shape. The shape invariant for 2357 SparseTensor and IndexedSlices are treated specially as follows: 2358 2359 a) If a loop variable is a SparseTensor, the shape invariant must be 2360 TensorShape([r]) where r is the rank of the dense tensor represented 2361 by the sparse tensor. It means the shapes of the three tensors of the 2362 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 2363 is the shape of the SparseTensor.dense_shape property. It must be the shape of 2364 a vector. 2365 2366 b) If a loop variable is an IndexedSlices, the shape invariant must be 2367 a shape invariant of the values tensor of the IndexedSlices. It means 2368 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 2369 [shape.ndims]). 2370 2371 `while_loop` implements non-strict semantics, enabling multiple iterations 2372 to run in parallel. 
The maximum number of parallel iterations can be 2373 controlled by `parallel_iterations`, which gives users some control over 2374 memory consumption and execution order. For correct programs, `while_loop` 2375 should return the same result for any parallel_iterations > 0. 2376 2377 For training, TensorFlow stores the tensors that are produced in the 2378 forward inference and are needed in back propagation. These tensors are a 2379 main source of memory consumption and often cause OOM errors when training 2380 on GPUs. When the flag swap_memory is true, we swap out these tensors from 2381 GPU to CPU. This for example allows us to train RNN models with very long 2382 sequences and large batches. 2383 2384 Args: 2385 cond: A callable that represents the termination condition of the loop. 2386 body: A callable that represents the loop body. 2387 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 2388 `Tensor`, and `TensorArray` objects. 2389 shape_invariants: The shape invariants for the loop variables. 2390 parallel_iterations: The number of iterations allowed to run in parallel. It 2391 must be a positive integer. 2392 back_prop: (optional) Deprecated. False disables support for back 2393 propagation. Prefer using `tf.stop_gradient` instead. 2394 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2395 maximum_iterations: Optional maximum number of iterations of the while loop 2396 to run. If provided, the `cond` output is AND-ed with an additional 2397 condition ensuring the number of iterations executed is no greater than 2398 `maximum_iterations`. 2399 name: Optional name prefix for the returned tensors. 2400 2401 Returns: 2402 The output tensors for the loop variables after the loop. The return value 2403 has the same structure as `loop_vars`. 2404 2405 Raises: 2406 TypeError: if `cond` or `body` is not callable. 2407 ValueError: if `loop_vars` is empty. 2408 2409 Example: 2410 2411 ```python 2412 i = tf.constant(0) 2413 c = lambda i: tf.less(i, 10) 2414 b = lambda i: (tf.add(i, 1), ) 2415 r = tf.while_loop(c, b, [i]) 2416 ``` 2417 2418 Example with nesting and a namedtuple: 2419 2420 ```python 2421 import collections 2422 Pair = collections.namedtuple('Pair', 'j, k') 2423 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 2424 c = lambda i, p: i < 10 2425 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 2426 ijk_final = tf.while_loop(c, b, ijk_0) 2427 ``` 2428 2429 Example using shape_invariants: 2430 2431 ```python 2432 i0 = tf.constant(0) 2433 m0 = tf.ones([2, 2]) 2434 c = lambda i, m: i < 10 2435 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 2436 tf.while_loop( 2437 c, b, loop_vars=[i0, m0], 2438 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 2439 ``` 2440 2441 Example which demonstrates non-strict semantics: In the following 2442 example, the final value of the counter `i` does not depend on `x`. So 2443 the `while_loop` can increment the counter parallel to updates of `x`. 2444 However, because the loop counter at one loop iteration depends 2445 on the value at the previous iteration, the loop counter itself cannot 2446 be incremented in parallel. Hence if we just want the final value of the 2447 counter (which we print on the line `print(sess.run(i))`), then 2448 `x` will never be incremented, but the counter will be updated on a 2449 single thread. 
Conversely, if we want the value of the output (which we 2450 print on the line `print(sess.run(out).shape)`), then the counter may be 2451 incremented on its own thread, while `x` can be incremented in 2452 parallel on a separate thread. In the extreme case, it is conceivable 2453 that the thread incrementing the counter runs until completion before 2454 `x` is incremented even a single time. The only thing that can never 2455 happen is that the thread updating `x` can never get ahead of the 2456 counter thread because the thread incrementing `x` depends on the value 2457 of the counter. 2458 2459 ```python 2460 import tensorflow as tf 2461 2462 n = 10000 2463 x = tf.constant(list(range(n))) 2464 c = lambda i, x: i < n 2465 b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1, 2466 [i], "x:")) 2467 i, out = tf.while_loop(c, b, (0, x)) 2468 with tf.compat.v1.Session() as sess: 2469 print(sess.run(i)) # prints [0] ... [9999] 2470 2471 # The following line may increment the counter and x in parallel. 2472 # The counter thread may get ahead of the other thread, but not the 2473 # other way around. So you may see things like 2474 # [9996] x:[9987] 2475 # meaning that the counter thread is on iteration 9996, 2476 # while the other thread is on iteration 9987 2477 print(sess.run(out).shape) 2478 ``` 2479 2480 """ 2481 return while_loop( 2482 cond=cond, 2483 body=body, 2484 loop_vars=loop_vars, 2485 shape_invariants=shape_invariants, 2486 parallel_iterations=parallel_iterations, 2487 back_prop=back_prop, 2488 swap_memory=swap_memory, 2489 name=name, 2490 maximum_iterations=maximum_iterations, 2491 return_same_structure=True) 2492 2493 2494# pylint: disable=redefined-outer-name 2495@tf_export(v1=["while_loop"]) 2496def while_loop(cond, 2497 body, 2498 loop_vars, 2499 shape_invariants=None, 2500 parallel_iterations=10, 2501 back_prop=True, 2502 swap_memory=False, 2503 name=None, 2504 maximum_iterations=None, 2505 return_same_structure=False): 2506 """Repeat `body` while the condition `cond` is true. 2507 2508 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 2509 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 2510 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 2511 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 2512 `cond` and `body`. `cond` and `body` both take as many arguments as there are 2513 `loop_vars`. 2514 2515 In addition to regular Tensors or IndexedSlices, the body may accept and 2516 return TensorArray objects. The flows of the TensorArray objects will 2517 be appropriately forwarded between loops and during gradient calculations. 2518 2519 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 2520 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 2521 stitches together the graph fragments created during the `cond` and `body` 2522 calls with some additional graph nodes to create the graph flow that 2523 repeats `body` until `cond` returns false. 2524 2525 For correctness, `tf.while_loop()` strictly enforces shape invariants for 2526 the loop variables. A shape invariant is a (possibly partial) shape that 2527 is unchanged across the iterations of the loop. An error will be raised 2528 if the shape of a loop variable after an iteration is determined to be more 2529 general than or incompatible with its shape invariant. 
For example, a shape 2530 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 2531 compatible with [11, 17]. By default (if the argument `shape_invariants` is 2532 not specified), it is assumed that the initial shape of each tensor in 2533 `loop_vars` is the same in every iteration. The `shape_invariants` argument 2534 allows the caller to specify a less specific shape invariant for each loop 2535 variable, which is needed if the shape varies between iterations. The 2536 `tf.Tensor.set_shape` 2537 function may also be used in the `body` function to indicate that 2538 the output loop variable has a particular shape. The shape invariant for 2539 SparseTensor and IndexedSlices are treated specially as follows: 2540 2541 a) If a loop variable is a SparseTensor, the shape invariant must be 2542 TensorShape([r]) where r is the rank of the dense tensor represented 2543 by the sparse tensor. It means the shapes of the three tensors of the 2544 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 2545 is the shape of the SparseTensor.dense_shape property. It must be the shape of 2546 a vector. 2547 2548 b) If a loop variable is an IndexedSlices, the shape invariant must be 2549 a shape invariant of the values tensor of the IndexedSlices. It means 2550 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 2551 [shape.ndims]). 2552 2553 `while_loop` implements non-strict semantics, enabling multiple iterations 2554 to run in parallel. The maximum number of parallel iterations can be 2555 controlled by `parallel_iterations`, which gives users some control over 2556 memory consumption and execution order. For correct programs, `while_loop` 2557 should return the same result for any parallel_iterations > 0. 2558 2559 For training, TensorFlow stores the tensors that are produced in the 2560 forward inference and are needed in back propagation. These tensors are a 2561 main source of memory consumption and often cause OOM errors when training 2562 on GPUs. When the flag swap_memory is true, we swap out these tensors from 2563 GPU to CPU. This for example allows us to train RNN models with very long 2564 sequences and large batches. 2565 2566 Args: 2567 cond: A callable that represents the termination condition of the loop. 2568 body: A callable that represents the loop body. 2569 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 2570 `Tensor`, and `TensorArray` objects. 2571 shape_invariants: The shape invariants for the loop variables. 2572 parallel_iterations: The number of iterations allowed to run in parallel. It 2573 must be a positive integer. 2574 back_prop: Whether backprop is enabled for this while loop. 2575 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2576 name: Optional name prefix for the returned tensors. 2577 maximum_iterations: Optional maximum number of iterations of the while loop 2578 to run. If provided, the `cond` output is AND-ed with an additional 2579 condition ensuring the number of iterations executed is no greater than 2580 `maximum_iterations`. 2581 return_same_structure: If True, output has same structure as `loop_vars`. If 2582 eager execution is enabled, this is ignored (and always treated as True). 2583 2584 Returns: 2585 The output tensors for the loop variables after the loop. 2586 If `return_same_structure` is True, the return value has the same 2587 structure as `loop_vars`. 
2588 If `return_same_structure` is False, the return value is a Tensor, 2589 TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list 2590 otherwise. 2591 2592 Raises: 2593 TypeError: if `cond` or `body` is not callable. 2594 ValueError: if `loop_vars` is empty. 2595 2596 Example: 2597 2598 ```python 2599 i = tf.constant(0) 2600 c = lambda i: tf.less(i, 10) 2601 b = lambda i: tf.add(i, 1) 2602 r = tf.while_loop(c, b, [i]) 2603 ``` 2604 2605 Example with nesting and a namedtuple: 2606 2607 ```python 2608 import collections 2609 Pair = collections.namedtuple('Pair', 'j, k') 2610 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 2611 c = lambda i, p: i < 10 2612 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 2613 ijk_final = tf.while_loop(c, b, ijk_0) 2614 ``` 2615 2616 Example using shape_invariants: 2617 2618 ```python 2619 i0 = tf.constant(0) 2620 m0 = tf.ones([2, 2]) 2621 c = lambda i, m: i < 10 2622 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 2623 tf.while_loop( 2624 c, b, loop_vars=[i0, m0], 2625 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 2626 ``` 2627 2628 Example which demonstrates non-strict semantics: In the following 2629 example, the final value of the counter `i` does not depend on `x`. So 2630 the `while_loop` can increment the counter parallel to updates of `x`. 2631 However, because the loop counter at one loop iteration depends 2632 on the value at the previous iteration, the loop counter itself cannot 2633 be incremented in parallel. Hence if we just want the final value of the 2634 counter (which we print on the line `print(sess.run(i))`), then 2635 `x` will never be incremented, but the counter will be updated on a 2636 single thread. Conversely, if we want the value of the output (which we 2637 print on the line `print(sess.run(out).shape)`), then the counter may be 2638 incremented on its own thread, while `x` can be incremented in 2639 parallel on a separate thread. In the extreme case, it is conceivable 2640 that the thread incrementing the counter runs until completion before 2641 `x` is incremented even a single time. The only thing that can never 2642 happen is that the thread updating `x` can never get ahead of the 2643 counter thread because the thread incrementing `x` depends on the value 2644 of the counter. 2645 2646 ```python 2647 import tensorflow as tf 2648 2649 n = 10000 2650 x = tf.constant(list(range(n))) 2651 c = lambda i, x: i < n 2652 b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1, 2653 [i], "x:")) 2654 i, out = tf.while_loop(c, b, (0, x)) 2655 with tf.compat.v1.Session() as sess: 2656 print(sess.run(i)) # prints [0] ... [9999] 2657 2658 # The following line may increment the counter and x in parallel. 2659 # The counter thread may get ahead of the other thread, but not the 2660 # other way around. So you may see things like 2661 # [9996] x:[9987] 2662 # meaning that the counter thread is on iteration 9996, 2663 # while the other thread is on iteration 9987 2664 print(sess.run(out).shape) 2665 ``` 2666 2667 """ 2668 if not callable(cond): 2669 raise TypeError("cond must be callable.") 2670 if not callable(body): 2671 raise TypeError("body must be callable.") 2672 if parallel_iterations < 1: 2673 raise TypeError("parallel_iterations must be a positive integer.") 2674 2675 # Always enable control flow v2 if building a function, regardless of toggle. 
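  # Under control flow v2 the loop is built by while_v2 as a single functional
  # While op; otherwise the legacy path below assembles it from
  # Enter/Merge/Switch/NextIteration/Exit ops via WhileContext.BuildLoop.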
2676 executing_eagerly = context.executing_eagerly() 2677 if (util.EnableControlFlowV2(ops.get_default_graph()) and 2678 not executing_eagerly): 2679 return while_v2.while_loop( 2680 cond, 2681 body, 2682 loop_vars, 2683 shape_invariants=shape_invariants, 2684 parallel_iterations=parallel_iterations, 2685 maximum_iterations=maximum_iterations, 2686 name=name, 2687 return_same_structure=return_same_structure, 2688 back_prop=back_prop) 2689 2690 with ops.name_scope(name, "while", loop_vars): 2691 if not loop_vars: 2692 raise ValueError("No loop variables provided") 2693 try_to_pack = (len(loop_vars) == 1 and not return_same_structure) 2694 if maximum_iterations is not None: 2695 maximum_iterations = ops.convert_to_tensor( 2696 maximum_iterations, name="maximum_iterations") 2697 if maximum_iterations.shape.ndims != 0: 2698 raise ValueError("maximum_iterations must be a scalar, saw shape: %s" % 2699 maximum_iterations.shape) 2700 2701 if executing_eagerly: 2702 counter = 0 2703 maximum_iterations = int(maximum_iterations.numpy()) 2704 else: 2705 counter = constant_op.constant( 2706 0, dtype=maximum_iterations.dtype, name="iteration_counter") 2707 orig_cond = cond 2708 orig_body = body 2709 if try_to_pack: 2710 loop_vars = (counter, loop_vars[0]) 2711 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 2712 math_ops.logical_and(i < maximum_iterations, orig_cond(lv))) 2713 body = lambda i, lv: (i + 1, orig_body(lv)) 2714 else: 2715 loop_vars = (counter, loop_vars) 2716 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 2717 math_ops.logical_and(i < maximum_iterations, orig_cond(*lv))) 2718 body = lambda i, lv: (i + 1, orig_body(*lv)) 2719 try_to_pack = False 2720 2721 if executing_eagerly: 2722 packed = False # whether the body result was packed into a 1-item tuple 2723 2724 loop_var_structure = nest.map_structure(type_spec.type_spec_from_value, 2725 list(loop_vars)) 2726 while cond(*loop_vars): 2727 loop_vars = body(*loop_vars) 2728 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)): 2729 packed = True 2730 loop_vars = (loop_vars,) 2731 nest.assert_same_structure(loop_var_structure, list(loop_vars)) 2732 2733 def convert(x): 2734 if isinstance(x, tensor_array_ops.TensorArray): 2735 return x 2736 return ops.convert_to_tensor(x) 2737 2738 loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True) 2739 if maximum_iterations is not None: 2740 return loop_vars[1] 2741 else: 2742 return loop_vars[0] if packed else loop_vars 2743 2744 if shape_invariants is not None: 2745 if maximum_iterations is not None: 2746 shape_invariants = (tensor_shape.TensorShape([]), shape_invariants) 2747 2748 nest.assert_same_structure( 2749 loop_vars, shape_invariants, expand_composites=False) 2750 shape_invariants = nest.map_structure( 2751 _get_shape_invariant, 2752 loop_vars, 2753 shape_invariants, 2754 expand_composites=False) 2755 2756 loop_context = WhileContext( 2757 maximum_iterations=maximum_iterations, 2758 parallel_iterations=parallel_iterations, 2759 back_prop=back_prop, 2760 swap_memory=swap_memory) 2761 # Only add non-nested loops to the collection. Any nested control flow will 2762 # be encapsulated in the root context. 
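    # (Nested contexts are reachable through the root context's
    # `nested_contexts` field when it is serialized, so only the root needs to
    # be in the collection.)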
2763 if loop_context.outer_context is None: 2764 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) 2765 result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, 2766 return_same_structure) 2767 if maximum_iterations is not None: 2768 return result[1] 2769 else: 2770 return result 2771 2772 2773# pylint: enable=redefined-outer-name 2774 2775 2776def _AsTensorList(x, p): 2777 """Return x as a list of Tensors or IndexedSlices. 2778 2779 For entries of `x` that are Operations, this returns an Identity of `p` 2780 with a dependency on the operation. 2781 2782 Args: 2783 x: A Tensor/IndexedSlices/Operation or a list or tuple of them. 2784 p: A Tensor to return for entries in `x` that are Operations. 2785 2786 Returns: 2787 A list of Tensors or IndexedSlices. 2788 """ 2789 if not isinstance(x, (list, _basetuple)): 2790 x = [x] 2791 2792 l = [] 2793 for v in x: 2794 if isinstance(v, ops.Operation): 2795 v = with_dependencies([v], p) 2796 v = ops.convert_to_tensor_or_composite(v) 2797 if isinstance(v, ops.Tensor): 2798 l.append(array_ops.identity(v)) 2799 else: 2800 l.append( 2801 ops.IndexedSlices( 2802 array_ops.identity(v.values), array_ops.identity(v.indices))) 2803 return l 2804 2805 2806def _CheckResults(a, b): 2807 assert len(a) == len(b), ( 2808 "Values returned by a() and b() must have the same length.") 2809 for x, y in zip(a, b): 2810 assert x.dtype == y.dtype, ( 2811 "Values returned by a() [%s] and b() [%s] must have " 2812 "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name)) 2813 2814 2815def with_dependencies(dependencies, output_tensor, name=None): 2816 """Produces the content of `output_tensor` only after `dependencies`. 2817 2818 In some cases, a user may want the output of an operation to be 2819 consumed externally only after some other dependencies have run 2820 first. This function ensures returns `output_tensor`, but only after all 2821 operations in `dependencies` have run. Note that this means that there is 2822 no guarantee that `output_tensor` will be evaluated after any `dependencies` 2823 have run. 2824 2825 See also `tf.tuple` and `tf.group`. 2826 2827 Args: 2828 dependencies: Iterable of operations to run before this op finishes. 2829 output_tensor: A `Tensor` or `IndexedSlices` that will be returned. 2830 name: (Optional) A name for this operation. 2831 2832 Returns: 2833 Same as `output_tensor`. 2834 2835 Raises: 2836 TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`. 2837 """ 2838 if context.executing_eagerly(): 2839 return output_tensor 2840 with ops.name_scope(name, "control_dependency", 2841 list(dependencies) + [output_tensor]) as name: 2842 with ops.colocate_with(output_tensor): 2843 with ops.control_dependencies(dependencies): 2844 output_tensor = ops.convert_to_tensor_or_composite(output_tensor) 2845 if isinstance(output_tensor, ops.Tensor): 2846 return _Identity(output_tensor, name=name) 2847 else: 2848 return ops.IndexedSlices( 2849 _Identity(output_tensor.values, name=name), output_tensor.indices, 2850 output_tensor.dense_shape) 2851 2852 2853def _GroupControlDeps(dev, deps, name=None): 2854 with ops.control_dependencies(deps): 2855 if dev is None: 2856 return no_op(name=name) 2857 else: 2858 with ops.device(dev): 2859 return no_op(name=name) 2860 2861 2862# TODO(touts): Accept "inputs" as a list. 2863@tf_export("group") 2864def group(*inputs, **kwargs): 2865 """Create an op that groups multiple operations. 2866 2867 When this op finishes, all ops in `inputs` have finished. 
This op has no 2868 output. 2869 2870 See also `tf.tuple` and 2871 `tf.control_dependencies`. 2872 2873 Args: 2874 *inputs: Zero or more tensors to group. 2875 name: A name for this operation (optional). 2876 2877 Returns: 2878 An Operation that executes all its inputs. 2879 2880 Raises: 2881 ValueError: If an unknown keyword argument is provided. 2882 """ 2883 if context.executing_eagerly(): 2884 return None 2885 name = kwargs.pop("name", None) 2886 if kwargs: 2887 raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys())) 2888 with ops.name_scope(name, "group_deps", inputs) as name: 2889 # Grouping no inputs means do nothing 2890 if not inputs: 2891 return no_op(name=name) 2892 2893 # Sorts *inputs according to their devices. 2894 ops_on_device = {} # device -> operations specified on the device. 2895 for inp in nest.flatten(inputs, expand_composites=True): 2896 if not hasattr(inp, "device"): 2897 raise TypeError("Expected tf.group() expected Tensor arguments not " 2898 "'%s' with type '%s'" % (inp, type(inp))) 2899 dev = inp.device 2900 if dev in ops_on_device: 2901 ops_on_device[dev].append(inp) 2902 else: 2903 ops_on_device[dev] = [inp] 2904 if len(ops_on_device) == 1: 2905 # 1-level tree. The root node is the returned NoOp node. 2906 (dev, deps), = ops_on_device.items() 2907 return _GroupControlDeps(dev, deps, name=name) 2908 2909 # 2-level tree. The root node is the returned NoOp node. 2910 # deps contains 1 NoOp node for each device. 2911 deps = [] 2912 2913 def device_key(dev): 2914 """A sort key that allows None to be compared to strings.""" 2915 return "" if dev is None else dev 2916 2917 for dev in sorted(ops_on_device, key=device_key): 2918 deps.append(_GroupControlDeps(dev, ops_on_device[dev])) 2919 2920 with ops.control_dependencies(deps): 2921 return no_op(name=name) 2922 2923 2924@tf_export("tuple", v1=[]) 2925def tuple_v2(tensors, control_inputs=None, name=None): 2926 """Group tensors together. 2927 2928 This creates a tuple of tensors with the same values as the `tensors` 2929 argument, except that the value of each tensor is only returned after the 2930 values of all tensors have been computed. 2931 2932 `control_inputs` contains additional ops that have to finish before this op 2933 finishes, but whose outputs are not returned. 2934 2935 This can be used as a "join" mechanism for parallel computations: all the 2936 argument tensors can be computed in parallel, but the values of any tensor 2937 returned by `tuple` are only available after all the parallel computations 2938 are done. 2939 2940 See also `tf.group` and 2941 `tf.control_dependencies`. 2942 2943 Args: 2944 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. 2945 control_inputs: List of additional ops to finish before returning. 2946 name: (optional) A name to use as a `name_scope` for the operation. 2947 2948 Returns: 2949 Same as `tensors`. 2950 2951 Raises: 2952 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 2953 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 2954 objects. 2955 2956 """ 2957 return tuple(tensors=tensors, name=name, control_inputs=control_inputs) # pylint: disable=redefined-builtin 2958 2959 2960@tf_export(v1=["tuple"]) 2961def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin 2962 """Group tensors together. 
2963 2964 This creates a tuple of tensors with the same values as the `tensors` 2965 argument, except that the value of each tensor is only returned after the 2966 values of all tensors have been computed. 2967 2968 `control_inputs` contains additional ops that have to finish before this op 2969 finishes, but whose outputs are not returned. 2970 2971 This can be used as a "join" mechanism for parallel computations: all the 2972 argument tensors can be computed in parallel, but the values of any tensor 2973 returned by `tuple` are only available after all the parallel computations 2974 are done. 2975 2976 See also `tf.group` and 2977 `tf.control_dependencies`. 2978 2979 Args: 2980 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. 2981 name: (optional) A name to use as a `name_scope` for the operation. 2982 control_inputs: List of additional ops to finish before returning. 2983 2984 Returns: 2985 Same as `tensors`. 2986 2987 Raises: 2988 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 2989 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 2990 objects. 2991 2992 """ 2993 if context.executing_eagerly(): 2994 return tensors 2995 with ops.name_scope(name, "tuple", tensors) as name: 2996 tensors = [ 2997 t if (isinstance(t, ops.Operation) or tensor_util.is_tensor(t) or 2998 t is None) else ops.convert_to_tensor(t) for t in tensors 2999 ] 3000 gating_ops = [ 3001 t if isinstance(t, ops.Operation) else t.op 3002 for t in tensors 3003 if t is not None 3004 ] 3005 if control_inputs: 3006 for c in control_inputs: 3007 if isinstance(c, ops.Tensor): 3008 c = c.op 3009 elif not isinstance(c, ops.Operation): 3010 raise TypeError("Control input must be Operation or Tensor: %s" % c) 3011 gating_ops.append(c) 3012 # Note that in order to ensure ordering in the pbtxt, we must take care to 3013 # ensure the order here. 3014 gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops. 3015 if not gating_ops: 3016 raise ValueError("Must have at least one Tensor: %s" % tensors) 3017 gate = group(*gating_ops) 3018 tpl = [] 3019 for t in tensors: 3020 if tensor_util.is_tensor(t): 3021 tpl.append(with_dependencies([gate], t)) 3022 elif isinstance(t, ops.Operation): 3023 with ops.control_dependencies([gate]): 3024 tpl.append(group(t)) 3025 else: 3026 tpl.append(None) 3027 return tpl 3028 3029 3030def _assert_at_most_n_true(predicates, n, msg): 3031 """Returns an Assert op that checks that at most n predicates are True. 3032 3033 Args: 3034 predicates: list of bool scalar tensors. 3035 n: maximum number of true predicates allowed. 3036 msg: Error message. 3037 """ 3038 preds_c = array_ops.stack(predicates, name="preds_c") 3039 num_true_conditions = math_ops.reduce_sum( 3040 math_ops.cast(preds_c, dtypes.int32), name="num_true_conds") 3041 condition = math_ops.less_equal(num_true_conditions, 3042 constant_op.constant(n, name="n_true_conds")) 3043 preds_names = ", ".join(getattr(p, "name", "?") for p in predicates) 3044 error_msg = [ 3045 "%s: more than %d conditions (%s) evaluated as True:" % 3046 (msg, n, preds_names), preds_c 3047 ] 3048 return Assert(condition, data=error_msg, summarize=len(predicates)) 3049 3050 3051def _case_create_default_action(predicates, actions): 3052 """Creates default action for a list of actions and their predicates. 3053 3054 It uses the input actions to select an arbitrary as default and makes sure 3055 that corresponding predicates have valid values. 
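  For example (illustrative), with predicates `[p0, p1, p2]` and actions
  `[f0, f1, f2]`, the last pair is chosen: the returned `default_action`
  asserts that neither `p0` nor `p1` is True (and that `p2` is) before
  calling `f2`, while `[p0, p1]` and `[f0, f1]` are returned as the
  remaining predicates and actions.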
3056 3057 Args: 3058 predicates: a list of bool scalar tensors 3059 actions: a list of callable objects which return tensors. 3060 3061 Returns: 3062 a callable 3063 """ 3064 k = len(predicates) - 1 # could pick any 3065 predicate, action = predicates[k], actions[k] 3066 other_predicates, other_actions = predicates[:k], actions[:k] 3067 3068 def default_action(): 3069 others_msg = ("Implementation error: " 3070 "selected default action #%d was called, but some of other " 3071 "predicates are True: " % k) 3072 default_msg = ("Input error: " 3073 "None of conditions evaluated as True:", 3074 array_ops.stack(predicates, name="preds_c")) 3075 with ops.control_dependencies([ 3076 _assert_at_most_n_true(other_predicates, n=0, msg=others_msg), 3077 Assert(predicate, data=default_msg) 3078 ]): 3079 return action() 3080 3081 return default_action, other_predicates, other_actions 3082 3083 3084def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name, 3085 allow_python_preds): 3086 """Verifies input arguments for the case function. 3087 3088 Args: 3089 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a 3090 callable which returns a list of tensors. 3091 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3092 name: A name for the case operation. 3093 allow_python_preds: if true, pred_fn_pairs may contain Python bools in 3094 addition to boolean Tensors 3095 3096 Raises: 3097 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3098 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3099 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3100 callable. 3101 3102 Returns: 3103 a tuple <list of scalar bool tensors, list of callables>. 3104 """ 3105 if not isinstance(pred_fn_pairs, (list, _basetuple, dict)): 3106 raise TypeError("fns must be a list, tuple, or dict") 3107 3108 if isinstance(pred_fn_pairs, collections.OrderedDict): 3109 pred_fn_pairs = pred_fn_pairs.items() 3110 elif isinstance(pred_fn_pairs, dict): 3111 if context.executing_eagerly(): 3112 # No name to sort on in eager mode. Use dictionary traversal order, 3113 # which is nondeterministic in versions of Python < 3.6 3114 if not exclusive: 3115 raise ValueError("Unordered dictionaries are not supported for the " 3116 "`pred_fn_pairs` argument when `exclusive=False` and " 3117 "eager mode is enabled.") 3118 pred_fn_pairs = list(pred_fn_pairs.items()) 3119 else: 3120 pred_fn_pairs = sorted( 3121 pred_fn_pairs.items(), key=lambda item: item[0].name) 3122 if not exclusive: 3123 logging.warn( 3124 "%s: An unordered dictionary of predicate/fn pairs was " 3125 "provided, but exclusive=False. The order of conditional " 3126 "tests is deterministic but not guaranteed.", name) 3127 for pred_fn_pair in pred_fn_pairs: 3128 if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2: 3129 raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple") 3130 pred, fn = pred_fn_pair 3131 3132 if isinstance(pred, ops.Tensor): 3133 if pred.dtype != dtypes.bool: 3134 raise TypeError("pred must be Tensor of type bool: %s" % pred.name) 3135 elif not allow_python_preds: 3136 raise TypeError("pred must be a Tensor, got: %s" % pred) 3137 elif not isinstance(pred, bool): 3138 raise TypeError("pred must be a Tensor or bool, got: %s" % pred) 3139 3140 if not callable(fn): 3141 raise TypeError("fn for pred %s must be callable." 
% pred.name) 3142 3143 predicates, actions = zip(*pred_fn_pairs) 3144 return predicates, actions 3145 3146 3147def _case_helper(cond_fn, 3148 pred_fn_pairs, 3149 default, 3150 exclusive, 3151 name, 3152 allow_python_preds=False, 3153 **cond_kwargs): 3154 """Implementation of case that allows for different cond functions. 3155 3156 Args: 3157 cond_fn: method that has signature and semantics of `cond` above. 3158 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a 3159 callable which returns a list of tensors. 3160 default: Optional callable that returns a list of tensors. 3161 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3162 name: A name for this operation (optional). 3163 allow_python_preds: if true, pred_fn_pairs may contain Python bools in 3164 addition to boolean Tensors 3165 **cond_kwargs: keyword arguments that will be passed to `cond_fn`. 3166 3167 Returns: 3168 The tensors returned by the first pair whose predicate evaluated to True, or 3169 those returned by `default` if none does. 3170 3171 Raises: 3172 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3173 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3174 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3175 callable. 3176 """ 3177 predicates, actions = _case_verify_and_canonicalize_args( 3178 pred_fn_pairs, exclusive, name, allow_python_preds) 3179 with ops.name_scope(name, "case", [predicates]): 3180 if default is None: 3181 default, predicates, actions = _case_create_default_action( 3182 predicates, actions) 3183 fn = default 3184 # To eval conditions in direct order we create nested conditions in reverse: 3185 # cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...)) 3186 for predicate, action in reversed(list(zip(predicates, actions))): 3187 fn = functools.partial( 3188 cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs) 3189 if exclusive: 3190 with ops.control_dependencies([ 3191 _assert_at_most_n_true( 3192 predicates, n=1, msg="Input error: exclusive=True") 3193 ]): 3194 return fn() 3195 else: 3196 return fn() 3197 3198 3199def _indexed_case_verify_and_canonicalize_args(branch_fns, default, 3200 branch_index): 3201 """Verifies input arguments for the case function. 3202 3203 Args: 3204 branch_fns: Dict or list of pairs of an `int` and a callable which 3205 returns a list of tensors. 3206 default: Optional callable that returns a list of tensors. 3207 branch_index: Optional int `Tensor`, which selects for the corresponding 3208 pred_fn_pair. 3209 3210 Raises: 3211 TypeError: If `branch_fns` is not a list/dictionary. 3212 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3213 callables. 3214 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3215 callable. 3216 3217 Returns: 3218 branch_fns: validated list of callables for each branch (default last). 
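    For example (illustrative), `branch_fns=[f0, f1]` with `default=fd` is
    canonicalized to `[f0, f1, fd]`; a dict such as `{1: f1, 0: f0}` yields
    the same result.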
3219 """ 3220 if not isinstance(branch_index, ops.Tensor): 3221 raise TypeError("branch_index must a Tensor, got {}".format( 3222 type(branch_index))) 3223 if not branch_index.dtype.is_integer: 3224 raise TypeError("branch_index must an integer Tensor, got {}".format( 3225 branch_index.dtype)) 3226 3227 if not branch_fns: 3228 raise ValueError("Must provide at least one item in branch_fns") 3229 if not isinstance(branch_fns, (list, _basetuple, dict)): 3230 raise TypeError("branch_fns must be a list, tuple, or dict") 3231 3232 if isinstance(branch_fns, dict): 3233 branch_fns = branch_fns.items() 3234 3235 if all(callable(fn) for fn in branch_fns): 3236 branch_fns = list(enumerate(branch_fns)) 3237 3238 for key_fn_pair in branch_fns: 3239 if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2: 3240 raise TypeError("Each entry in branch_fns must be a 2-tuple") 3241 key, branch_fn = key_fn_pair 3242 3243 if not isinstance(key, int): 3244 raise TypeError("key must be a Python `int`, got {}".format(type(key))) 3245 3246 if not callable(branch_fn): 3247 raise TypeError("fn for key {} must be callable.".format(key)) 3248 3249 keys = [p[0] for p in branch_fns] 3250 if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys): 3251 raise ValueError( 3252 "branch indices (keys) must form contiguous range of [0 to {}) but " 3253 "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys))))) 3254 actions = [p[1] for p in sorted(branch_fns)] 3255 if default is not None: 3256 actions.append(default) 3257 return actions 3258 3259 3260def _indexed_case_helper(branch_fns, default, branch_index, name): 3261 """Implementation of case that emits the n-way indexed Case op. 3262 3263 Args: 3264 branch_fns: Dict or list of pairs of a boolean scalar tensor, and a 3265 callable which returns a list of tensors. 3266 default: Optional callable that returns a list of tensors. 3267 branch_index: Optional int `Tensor`, which selects for the corresponding 3268 pred_fn_pair. 3269 name: A name for this operation (optional). 3270 3271 Returns: 3272 The tensors returned by the pair whose key matched branch_index, or 3273 those returned by `default` if none does. 3274 3275 Raises: 3276 TypeError: If `branch_fns` is not a list/dictionary. 3277 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3278 callables. 3279 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3280 callable. 3281 """ 3282 branch_fns = _indexed_case_verify_and_canonicalize_args( 3283 branch_fns, default, branch_index) 3284 with ops.name_scope(name, "case", [branch_index]): 3285 if context.executing_eagerly() and not hasattr(branch_index, "graph"): 3286 branch_index = array_ops.where( 3287 math_ops.less(branch_index, 0) 3288 | math_ops.greater_equal(branch_index, len(branch_fns)), 3289 len(branch_fns) - 1, branch_index) 3290 return branch_fns[int(branch_index)]() 3291 return cond_v2.indexed_case(branch_index, branch_fns) 3292 3293 3294@tf_export("case", v1=[]) 3295def case_v2(pred_fn_pairs, 3296 default=None, 3297 exclusive=False, 3298 strict=False, 3299 name="case"): 3300 """Create a case operation. 3301 3302 See also `tf.switch_case`. 3303 3304 The `pred_fn_pairs` parameter is a list of pairs of size N. 3305 Each pair contains a boolean scalar tensor and a python callable that 3306 creates the tensors to be returned if the boolean evaluates to True. 3307 `default` is a callable generating a list of tensors. 
All the callables 3308 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3309 number and types of tensors. 3310 3311 If `exclusive==True`, all predicates are evaluated, and an exception is 3312 thrown if more than one of the predicates evaluates to `True`. 3313 If `exclusive==False`, execution stops at the first predicate which 3314 evaluates to True, and the tensors generated by the corresponding function 3315 are returned immediately. If none of the predicates evaluate to True, this 3316 operation returns the tensors generated by `default`. 3317 3318 `tf.case` supports nested structures as implemented in 3319 `tf.contrib.framework.nest`. All of the callables must return the same 3320 (possibly nested) value structure of lists, tuples, and/or named tuples. 3321 Singleton lists and tuples form the only exceptions to this: when returned by 3322 a callable, they are implicitly unpacked to single values. This 3323 behavior is disabled by passing `strict=True`. 3324 3325 @compatibility(v2) 3326 `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and 3327 tf.Variable are no longer hashable in v2, so cannot be used as a key for a 3328 dictionary. Please use a list or a tuple instead. 3329 @end_compatibility 3330 3331 3332 **Example 1:** 3333 3334 Pseudocode: 3335 3336 ``` 3337 if (x < y) return 17; 3338 else return 23; 3339 ``` 3340 3341 Expressions: 3342 3343 ```python 3344 f1 = lambda: tf.constant(17) 3345 f2 = lambda: tf.constant(23) 3346 r = tf.case([(tf.less(x, y), f1)], default=f2) 3347 ``` 3348 3349 **Example 2:** 3350 3351 Pseudocode: 3352 3353 ``` 3354 if (x < y && x > z) raise OpError("Only one predicate may evaluate to True"); 3355 if (x < y) return 17; 3356 else if (x > z) return 23; 3357 else return -1; 3358 ``` 3359 3360 Expressions: 3361 3362 ```python 3363 def f1(): return tf.constant(17) 3364 def f2(): return tf.constant(23) 3365 def f3(): return tf.constant(-1) 3366 r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)], 3367 default=f3, exclusive=True) 3368 ``` 3369 3370 Args: 3371 pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which 3372 returns a list of tensors. 3373 default: Optional callable that returns a list of tensors. 3374 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3375 strict: A boolean that enables/disables 'strict' mode; see above. 3376 name: A name for this operation (optional). 3377 3378 Returns: 3379 The tensors returned by the first pair whose predicate evaluated to True, or 3380 those returned by `default` if none does. 3381 3382 Raises: 3383 TypeError: If `pred_fn_pairs` is not a list/tuple. 3384 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3385 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3386 callable. 3387 """ 3388 return _case_helper( 3389 cond, 3390 pred_fn_pairs, 3391 default, 3392 exclusive, 3393 name, 3394 allow_python_preds=False, 3395 strict=strict) 3396 3397 3398@tf_export(v1=["case"]) 3399def case(pred_fn_pairs, 3400 default=None, 3401 exclusive=False, 3402 strict=False, 3403 name="case"): 3404 """Create a case operation. 3405 3406 See also `tf.switch_case`. 3407 3408 The `pred_fn_pairs` parameter is a dict or list of pairs of size N. 3409 Each pair contains a boolean scalar tensor and a python callable that 3410 creates the tensors to be returned if the boolean evaluates to True. 3411 `default` is a callable generating a list of tensors. 
All the callables 3412 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3413 number and types of tensors. 3414 3415 If `exclusive==True`, all predicates are evaluated, and an exception is 3416 thrown if more than one of the predicates evaluates to `True`. 3417 If `exclusive==False`, execution stops at the first predicate which 3418 evaluates to True, and the tensors generated by the corresponding function 3419 are returned immediately. If none of the predicates evaluate to True, this 3420 operation returns the tensors generated by `default`. 3421 3422 `tf.case` supports nested structures as implemented in 3423 `tf.contrib.framework.nest`. All of the callables must return the same 3424 (possibly nested) value structure of lists, tuples, and/or named tuples. 3425 Singleton lists and tuples form the only exceptions to this: when returned by 3426 a callable, they are implicitly unpacked to single values. This 3427 behavior is disabled by passing `strict=True`. 3428 3429 If an unordered dictionary is used for `pred_fn_pairs`, the order of the 3430 conditional tests is not guaranteed. However, the order is guaranteed to be 3431 deterministic, so that variables created in conditional branches are created 3432 in fixed order across runs. 3433 3434 @compatibility(eager) 3435 Unordered dictionaries are not supported in eager mode when `exclusive=False`. 3436 Use a list of tuples instead. 3437 @end_compatibility 3438 3439 3440 **Example 1:** 3441 3442 Pseudocode: 3443 3444 ``` 3445 if (x < y) return 17; 3446 else return 23; 3447 ``` 3448 3449 Expressions: 3450 3451 ```python 3452 f1 = lambda: tf.constant(17) 3453 f2 = lambda: tf.constant(23) 3454 r = tf.case([(tf.less(x, y), f1)], default=f2) 3455 ``` 3456 3457 **Example 2:** 3458 3459 Pseudocode: 3460 3461 ``` 3462 if (x < y && x > z) raise OpError("Only one predicate may evaluate to True"); 3463 if (x < y) return 17; 3464 else if (x > z) return 23; 3465 else return -1; 3466 ``` 3467 3468 Expressions: 3469 3470 ```python 3471 def f1(): return tf.constant(17) 3472 def f2(): return tf.constant(23) 3473 def f3(): return tf.constant(-1) 3474 r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2}, 3475 default=f3, exclusive=True) 3476 ``` 3477 3478 Args: 3479 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a 3480 callable which returns a list of tensors. 3481 default: Optional callable that returns a list of tensors. 3482 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3483 strict: A boolean that enables/disables 'strict' mode; see above. 3484 name: A name for this operation (optional). 3485 3486 Returns: 3487 The tensors returned by the first pair whose predicate evaluated to True, or 3488 those returned by `default` if none does. 3489 3490 Raises: 3491 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3492 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3493 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3494 callable. 3495 """ 3496 return _case_helper( 3497 cond, 3498 pred_fn_pairs, 3499 default, 3500 exclusive, 3501 name, 3502 allow_python_preds=False, 3503 strict=strict) 3504 3505 3506@tf_export("switch_case") 3507def switch_case(branch_index, 3508 branch_fns, 3509 default=None, 3510 name="switch_case"): 3511 """Create a switch/case operation, i.e. an integer-indexed conditional. 3512 3513 See also `tf.case`. 3514 3515 This op can be substantially more efficient than `tf.case` when exactly one 3516 branch will be selected. 
`tf.switch_case` is more like a C++ switch/case 3517 statement than `tf.case`, which is more like an if/elif/elif/else chain. 3518 3519 The `branch_fns` parameter is either a dict from `int` to callables, or list 3520 of (`int`, callable) pairs, or simply a list of callables (in which case the 3521 index is implicitly the key). The `branch_index` `Tensor` is used to select an 3522 element in `branch_fns` with matching `int` key, falling back to `default` 3523 if none match, or `max(keys)` if no `default` is provided. The keys must form 3524 a contiguous set from `0` to `len(branch_fns) - 1`. 3525 3526 `tf.switch_case` supports nested structures as implemented in `tf.nest`. All 3527 callables must return the same (possibly nested) value structure of lists, 3528 tuples, and/or named tuples. 3529 3530 **Example:** 3531 3532 Pseudocode: 3533 3534 ```c++ 3535 switch (branch_index) { // c-style switch 3536 case 0: return 17; 3537 case 1: return 31; 3538 default: return -1; 3539 } 3540 ``` 3541 or 3542 ```python 3543 branches = {0: lambda: 17, 1: lambda: 31} 3544 branches.get(branch_index, lambda: -1)() 3545 ``` 3546 3547 Expressions: 3548 3549 ```python 3550 def f1(): return tf.constant(17) 3551 def f2(): return tf.constant(31) 3552 def f3(): return tf.constant(-1) 3553 r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3) 3554 # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3}) 3555 ``` 3556 3557 Args: 3558 branch_index: An int Tensor specifying which of `branch_fns` should be 3559 executed. 3560 branch_fns: A `dict` mapping `int`s to callables, or a `list` of 3561 (`int`, callable) pairs, or simply a list of callables (in which case the 3562 index serves as the key). Each callable must return a matching structure 3563 of tensors. 3564 default: Optional callable that returns a structure of tensors. 3565 name: A name for this operation (optional). 3566 3567 Returns: 3568 The tensors returned by the callable identified by `branch_index`, or those 3569 returned by `default` if no key matches and `default` was provided, or those 3570 returned by the max-keyed `branch_fn` if no `default` is provided. 3571 3572 Raises: 3573 TypeError: If `branch_fns` is not a list/dictionary. 3574 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3575 callables. 3576 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3577 callable. 3578 """ 3579 return _indexed_case_helper(branch_fns, default, branch_index, name) 3580 3581 3582class XLAControlFlowContext(ControlFlowContext): 3583 """Base class for XLA and TPU control flow contexts.""" 3584 3585 def __init__(self): 3586 super(XLAControlFlowContext, self).__init__() 3587 self._name = "XLAControlFlowContext" 3588 3589 def to_control_flow_context_def(self, context_def, export_scope=None): 3590 # pylint: disable=useless-super-delegation 3591 # NOTE(slebedev): the method is required by `ControlFlowContext`. 3592 super(XLAControlFlowContext, 3593 self).to_control_flow_context_def(context_def, export_scope) 3594 3595 def IsXLAContext(self): 3596 return True 3597 3598 def AddOp(self, _): 3599 pass 3600 3601 def AddValue(self, x): 3602 return x 3603 3604 3605def from_control_flow_context_def(context_def, import_scope=None): 3606 """Deserializes `context_def` into the appropriate ControlFlowContext. 3607 3608 Args: 3609 context_def: ControlFlowContextDef proto 3610 import_scope: Optional `string`. Name scope to add. 
3611 3612 Returns: 3613 A ControlFlowContext subclass 3614 """ 3615 if context_def.HasField("cond_ctxt"): 3616 return CondContext.from_proto( 3617 context_def.cond_ctxt, import_scope=import_scope) 3618 if context_def.HasField("while_ctxt"): 3619 return WhileContext.from_proto( 3620 context_def.while_ctxt, import_scope=import_scope) 3621 raise NotImplementedError("Unknown ControlFlowContextDef field: %s" % 3622 context_def.WhichOneof("ctxt")) 3623 3624 3625ops.register_proto_function( 3626 ops.GraphKeys.COND_CONTEXT, 3627 proto_type=control_flow_pb2.CondContextDef, 3628 to_proto=CondContext.to_proto, 3629 from_proto=CondContext.from_proto) 3630 3631ops.register_proto_function( 3632 ops.GraphKeys.WHILE_CONTEXT, 3633 proto_type=control_flow_pb2.WhileContextDef, 3634 to_proto=WhileContext.to_proto, 3635 from_proto=WhileContext.from_proto) 3636
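# Usage sketch (illustrative only; not part of this module's API surface).
# It assumes TF1-style graph execution through `tf.compat.v1` and shows how
# `tf.group` and control dependencies (the public counterpart of the
# module-level `with_dependencies` helper above) order side effects.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   counter = tf.Variable(0)
#   inc = tf.assign_add(counter, 1)
#   x = tf.constant(10)
#
#   # `done` is an Operation with no outputs; running it runs `inc`.
#   done = tf.group(inc)
#
#   # `gated_x` has the value of `x`, but evaluating it first runs `inc`.
#   with tf.control_dependencies([inc]):
#     gated_x = tf.identity(x)
#
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     sess.run(done)            # increments `counter`, returns None
#     print(sess.run(gated_x))  # increments `counter` again, prints 10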