1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Control Flow Operations. 16 17See the [autograph](https://www.tensorflow.org/guide/autograph) guide. 18""" 19# pylint: disable=g-bad-name 20import abc 21import collections 22import functools 23 24from tensorflow.core.framework import attr_value_pb2 25from tensorflow.core.protobuf import control_flow_pb2 26from tensorflow.python.eager import context 27from tensorflow.python.framework import composite_tensor 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import errors 31from tensorflow.python.framework import indexed_slices 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.framework import tensor_spec 35from tensorflow.python.framework import tensor_util 36from tensorflow.python.framework import type_spec 37from tensorflow.python.ops import array_ops 38from tensorflow.python.ops import control_flow_util as util 39from tensorflow.python.ops import gen_array_ops 40from tensorflow.python.ops import gen_control_flow_ops 41from tensorflow.python.ops import gen_functional_ops 42from tensorflow.python.ops import gen_logging_ops 43from tensorflow.python.ops import gen_math_ops 44from tensorflow.python.ops import math_ops 45from tensorflow.python.ops import tensor_array_ops 46# go/tf-wildcard-import 47# pylint: disable=wildcard-import,undefined-variable 48from tensorflow.python.ops.gen_control_flow_ops import * 49# pylint: enable=wildcard-import 50from tensorflow.python.platform import tf_logging as logging 51from tensorflow.python.util import compat 52from tensorflow.python.util import deprecation 53from tensorflow.python.util import dispatch 54from tensorflow.python.util import nest 55from tensorflow.python.util import tf_should_use 56from tensorflow.python.util import variable_utils 57from tensorflow.python.util.lazy_loader import LazyLoader 58from tensorflow.python.util.tf_export import tf_export 59 60# This is to avoid a circular dependency: 61# cond_v2 -> gradients_util -> control_flow_ops 62cond_v2 = LazyLoader("cond_v2", globals(), 63 "tensorflow.python.ops.cond_v2") 64 65# This is to avoid circular dependencies: 66# while_v2 -> control_flow_ops 67# while_v2 -> gradients_util -> control_flow_ops 68while_v2 = LazyLoader("while_v2", globals(), 69 "tensorflow.python.ops.while_v2") 70 71# def_function also uses cond 72def_function = LazyLoader( 73 "def_function", globals(), 74 "tensorflow.python.eager.def_function") 75 76 77# We override the 'tuple' for a control flow op, so we keep python's 78# existing 'tuple' for later use in this module. 79_basetuple = tuple 80 81 82def _summarize_eager(tensor, summarize=None): 83 """Returns a summarized string representation of eager `tensor`. 84 85 Args: 86 tensor: EagerTensor to summarize 87 summarize: Include these many first elements of `array` 88 """ 89 # Emulate the behavior of Tensor::SummarizeValue() 90 if summarize is None: 91 summarize = 3 92 elif summarize < 0: 93 summarize = array_ops.size(tensor) 94 95 # reshape((-1,)) is the fastest way to get a flat array view 96 if tensor._rank(): # pylint: disable=protected-access 97 flat = tensor.numpy().reshape((-1,)) 98 lst = [str(x) for x in flat[:summarize]] 99 if len(lst) < flat.size: 100 lst.append("...") 101 else: 102 # tensor.numpy() returns a scalar for zero dimensional arrays 103 if gen_math_ops.not_equal(summarize, 0): 104 lst = [str(tensor.numpy())] 105 else: 106 lst = [] 107 108 return ", ".join(lst) 109 110 111# pylint: disable=protected-access 112 113 114# Assert and Print are special symbols in python, so we must 115# use an upper-case version of them. 116@tf_export("debugging.Assert", "Assert") 117@dispatch.add_dispatch_support 118@tf_should_use.should_use_result 119def Assert(condition, data, summarize=None, name=None): 120 """Asserts that the given condition is true. 121 122 If `condition` evaluates to false, print the list of tensors in `data`. 123 `summarize` determines how many entries of the tensors to print. 124 125 Args: 126 condition: The condition to evaluate. 127 data: The tensors to print out when condition is false. 128 summarize: Print this many entries of each tensor. 129 name: A name for this operation (optional). 130 131 Returns: 132 assert_op: An `Operation` that, when executed, raises a 133 `tf.errors.InvalidArgumentError` if `condition` is not true. 134 @compatibility(eager) 135 returns None 136 @end_compatibility 137 138 Raises: 139 @compatibility(TF1) 140 When in TF V1 mode (that is, outside `tf.function`) Assert needs a control 141 dependency on the output to ensure the assertion executes: 142 143 ```python 144 # Ensure maximum element of x is smaller or equal to 1 145 assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x]) 146 with tf.control_dependencies([assert_op]): 147 ... code using x ... 148 ``` 149 150 @end_compatibility 151 """ 152 if context.executing_eagerly(): 153 if not condition: 154 xs = ops.convert_n_to_tensor(data) 155 data_str = [_summarize_eager(x, summarize) for x in xs] 156 raise errors.InvalidArgumentError( 157 node_def=None, 158 op=None, 159 message="Expected '%s' to be true. Summarized data: %s" % 160 (condition, "\n".join(data_str))) 161 return 162 163 with ops.name_scope(name, "Assert", [condition, data]) as name: 164 xs = ops.convert_n_to_tensor(data) 165 if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs): 166 # As a simple heuristic, we assume that string and int32 are 167 # on host to avoid the need to use cond. If it is not case, 168 # we will pay the price copying the tensor to host memory. 169 return gen_logging_ops._assert(condition, data, summarize, name="Assert") 170 else: 171 condition = ops.convert_to_tensor(condition, name="Condition") 172 173 def true_assert(): 174 return gen_logging_ops._assert( 175 condition, data, summarize, name="Assert") 176 177 guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard") 178 if context.executing_eagerly(): 179 return 180 return guarded_assert.op 181 182 183def _Identity(tensor, name=None): 184 """Return a tensor with the same shape and contents as the input tensor. 185 186 Args: 187 tensor: A Tensor. 188 name: A name for this operation (optional). 189 190 Returns: 191 A Tensor with the same type and value as the input Tensor. 192 """ 193 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) 194 if isinstance(tensor, ops.Tensor): 195 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access 196 return gen_array_ops.ref_identity(tensor, name=name) 197 else: 198 return array_ops.identity(tensor, name=name) 199 elif isinstance(tensor, composite_tensor.CompositeTensor): 200 return nest.map_structure(_Identity, tensor, expand_composites=True) 201 else: 202 raise TypeError("'tensor' must be a Tensor or CompositeTensor. " 203 f"Received: {type(tensor)}.") 204 205 206def _NextIteration(tensor, name=None): 207 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) 208 if isinstance(tensor, ops.Tensor): 209 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access 210 return ref_next_iteration(tensor, name=name) 211 else: 212 return next_iteration(tensor, name=name) 213 elif isinstance(tensor, composite_tensor.CompositeTensor): 214 return nest.map_structure(_NextIteration, tensor, expand_composites=True) 215 else: 216 raise TypeError("'tensor' must be a Tensor or CompositeTensor. " 217 f"Received: {type(tensor)}.") 218 219 220def _Enter(tensor, 221 frame_name, 222 is_constant=False, 223 parallel_iterations=10, 224 use_ref=True, 225 use_input_shape=True, 226 name=None): 227 """Creates or finds a child frame, and makes `tensor` available to it. 228 229 The unique `frame_name` is used by the `Executor` to identify frames. If 230 `is_constant` is true, `tensor` is a constant in the child frame; otherwise 231 it may be changed in the child frame. At most `parallel_iterations` 232 iterations are run in parallel in the child frame. 233 234 Args: 235 tensor: The tensor to be made available to the child frame. 236 frame_name: The name of the child frame. 237 is_constant: If true, the output is constant within the child frame. 238 parallel_iterations: The number of iterations allowed to run in parallel. 239 use_ref: If true, use ref_enter if tensor is of ref type. 240 use_input_shape: If true, set the result's shape based on tensor's shape. 241 name: A name for this operation (optional). 242 243 Returns: 244 The same tensor as `tensor`. 245 246 Raises: 247 ValueError: If any tensor in `tensor` has a less specific shape 248 than its corresponding shape in `shape_invariant`. 249 """ 250 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) 251 if isinstance(tensor, ops.Tensor): 252 if tensor.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access 253 result = gen_control_flow_ops.ref_enter( 254 tensor, frame_name, is_constant, parallel_iterations, name=name) 255 else: 256 result = gen_control_flow_ops.enter( 257 tensor, frame_name, is_constant, parallel_iterations, name=name) 258 if use_input_shape: 259 result.set_shape(tensor.get_shape()) 260 return result 261 elif isinstance(tensor, composite_tensor.CompositeTensor): 262 263 def enter_component(t): 264 return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref, 265 use_input_shape) 266 267 return nest.map_structure(enter_component, tensor, expand_composites=True) 268 else: 269 raise TypeError("'tensor' must be a Tensor or CompositeTensor. " 270 f"Received: {type(tensor)}.") 271 272 273def exit(tensor, name=None): # pylint: disable=redefined-builtin 274 """Exits the current frame to its parent frame. 275 276 Exit makes its input `tensor` available to the parent frame. 277 278 Args: 279 tensor: The tensor to be made available to the parent frame. 280 name: A name for this operation (optional). 281 282 Returns: 283 The same tensor as `tensor`. 284 """ 285 tensor = ops.internal_convert_to_tensor_or_composite(tensor, as_ref=True) 286 if isinstance(tensor, ops.Tensor): 287 if tensor.dtype._is_ref_dtype: # pylint: disable=protected-access 288 return gen_control_flow_ops.ref_exit(tensor, name) 289 else: 290 return gen_control_flow_ops._exit(tensor, name) 291 elif isinstance(tensor, composite_tensor.CompositeTensor): 292 return nest.map_structure(exit, tensor, expand_composites=True) 293 else: 294 raise TypeError("'tensor' must be a Tensor or CompositeTensor. " 295 f"Received: {type(tensor)}.") 296 297 298def switch(data, pred, dtype=None, name=None): 299 """Forwards `data` to an output determined by `pred`. 300 301 If `pred` is false, the `data` input is forwarded to the first output. 302 Otherwise, the data goes to the second output. 303 304 This op handles `Tensor`s and `IndexedSlices`. 305 306 Args: 307 data: The tensor to be forwarded to the appropriate output. 308 pred: A scalar that specifies which output port will receive data. 309 dtype: Optional element type for the returned tensor. If missing, the type 310 is inferred from the type of `value`. 311 name: A name for this operation (optional). 312 313 Returns: 314 `(output_false, output_true)`: If `pred` is true, data will be forwarded 315 to `output_true`, otherwise it goes to `output_false`. 316 """ 317 with ops.name_scope(name, "Switch", [data, pred]) as name: 318 data = ops.internal_convert_to_tensor_or_composite( 319 data, dtype=dtype, name="data", as_ref=True) 320 pred = ops.convert_to_tensor(pred, name="pred") 321 if isinstance(data, ops.Tensor): 322 return gen_control_flow_ops.switch(data, pred, name=name) 323 else: 324 if not isinstance(data, composite_tensor.CompositeTensor): 325 raise TypeError( 326 "'data' must be a Tensor or CompositeTensor. " 327 f"Received: {type(data)}.") 328 tensors = nest.flatten(data, expand_composites=True) 329 mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors] 330 mapped_f, mapped_t = zip(*mapped) 331 return (nest.pack_sequence_as(data, mapped_f, expand_composites=True), 332 nest.pack_sequence_as(data, mapped_t, expand_composites=True)) 333 334 335def _SwitchRefOrTensor(data, pred, name="Switch"): 336 """Forwards `data` to an output determined by `pred`. 337 338 If `pred` is false, the `data` input is forwarded to the first output. 339 Otherwise, the data goes to the second output. 340 341 This op handles `Tensor`s and `IndexedSlices`. 342 343 Args: 344 data: The tensor to be forwarded to the appropriate output. 345 pred: A scalar that specifies which output port will receive data. 346 name: A name for this operation (optional). 347 348 Returns: 349 `(output_false, output_true)`: If `pred` is true, data will be forwarded to 350 `output_true`, otherwise it goes to `output_false`. 351 352 Raises: 353 TypeError: if data is not a Tensor or IndexedSlices 354 """ 355 data = ops.convert_to_tensor_or_composite(data, name="data") 356 # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below 357 # addresses the following scenario. 358 # 359 # Assume you execute Optimizer.apply_gradients() in a branch of a cond(). 360 # 361 # 1. The update op is created inside a `with ops.colocate(var):` block 362 # 363 # 2. Some tensor `data` is captured and a switch is created in a 364 # `with ops.colocate_with(data):` block. 365 # 366 # with ops.colocate_with(var): 367 # with ops.colocate_with(data): 368 # op = ... 369 # 370 # var and data may be pinned to different devices, so we want to ops 371 # created within ops.colocate_with(data) to ignore the existing stack. 372 with ops.colocate_with(data, ignore_existing=True): 373 if isinstance(data, ops.Tensor): 374 if data.dtype._is_ref_dtype: # pylint: disable=protected-access 375 return ref_switch(data, pred, name=name) 376 return switch(data, pred, name=name) 377 378 379def merge(inputs, name=None): 380 """Returns the value of an available element of `inputs`. 381 382 This op tests each of the tensors in `inputs` in turn to determine if any of 383 them is available. If it finds an available tensor, it returns it and its 384 index in `inputs`. 385 386 It is an error if more than one tensor in `inputs` is available. If no tensor 387 in `inputs` is available, the returned tensor and index are not set. 388 389 This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of 390 `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices 391 before merging. 392 393 Args: 394 inputs: The input tensors, at most one of which is available. 395 name: A name for this operation (optional). 396 397 Returns: 398 A tuple containing the chosen input tensor and its index in `inputs`. 399 400 Raises: 401 ValueError: If any of the inputs is None, or inputs are IndexedSlices and 402 some but not all have a dense_shape property. 403 """ 404 if any(inp is None for inp in inputs): 405 raise ValueError("At least one of the merge inputs is None: %s" % inputs) 406 with ops.name_scope(name, "Merge", inputs) as name: 407 inputs = [ 408 ops.internal_convert_to_tensor_or_composite(inp, as_ref=True) 409 for inp in inputs 410 ] 411 if all(isinstance(v, ops.Tensor) for v in inputs): 412 if all(v.dtype._is_ref_dtype for v in inputs): # pylint: disable=protected-access 413 return gen_control_flow_ops.ref_merge(inputs, name) 414 else: 415 return gen_control_flow_ops.merge(inputs, name) 416 else: 417 # If there is a mix of tensors and indexed slices, then convert the 418 # tensors to indexed slices. 419 if all( 420 isinstance(v, (indexed_slices.IndexedSlices, ops.Tensor)) 421 for v in inputs): 422 inputs = math_ops._as_indexed_slices_list(inputs, optimize=False) 423 424 for v in inputs: 425 if not isinstance(v, composite_tensor.CompositeTensor): 426 raise TypeError("Type %s not supported" % type(v)) 427 428 for v in inputs[1:]: 429 nest.assert_same_structure(inputs[0], v, expand_composites=True) 430 431 flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs] 432 merged_results = [ 433 gen_control_flow_ops.merge(component) 434 for component in zip(*flat_inputs) 435 ] 436 flat_merged = [tensor for (tensor, _) in merged_results] 437 chosen_index = merged_results[0][1] 438 merged_inputs = nest.pack_sequence_as( 439 inputs[0], flat_merged, expand_composites=True) 440 return (merged_inputs, chosen_index) 441 442 443def _convert_tensorarray_to_flow(tensor_or_tensor_array): 444 if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray): 445 return tensor_or_tensor_array.flow 446 else: 447 return tensor_or_tensor_array 448 449 450def _convert_flow_to_tensorarray(tensor_or_tensor_array, tensor_or_flow): 451 if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray): 452 return tensor_array_ops.build_ta_with_new_flow(tensor_or_tensor_array, 453 tensor_or_flow) 454 else: 455 return tensor_or_flow 456 457 458def _convert_to_tensor_or_composite_or_tensorarray(var): 459 if isinstance(var, tensor_array_ops.TensorArray): 460 return var 461 return ops.convert_to_tensor_or_composite(var) 462 463 464# TODO(xjun): replace this with is_subtype_of after it is landed. 465def _ShapeLessThanOrEqual(shape1, shape2): 466 if shape2.dims is None: 467 return True 468 if shape1.ndims != shape2.ndims: 469 return False 470 for dim1, dim2 in zip(shape1.dims, shape2.dims): 471 if dim2.value is not None and dim1.value != dim2.value: 472 return False 473 return True 474 475 476def _shape_invariant_to_type_spec(var, shape=None): 477 """Converts a shape invariant to a TypeSpec. 478 479 If `var` is a TensorArray, it will first be converted to its flow. 480 481 Args: 482 var: The tensor, tensor array or composite tensor whose shape is described 483 by the shape invariant. 484 shape: A `TypeSpec` or `TensorShape`. If `shape` is already a `TypeSpec`, 485 then it is simply returned as-is. 486 487 Returns: 488 A `TypeSpec` for `var`, consistent with the given shape. 489 490 Raises: 491 TypeError: If `shape` is a TypeSpec and not compatible with `var`. 492 TypeError: If `shape` is not None, a TypeSpec, or a TensorShape. 493 TypeError: If `shape` is a TensorShape, `var` is a CompositeTensor, and 494 `var` doesn't implement the `_shape_invariant_to_type_spec` method. 495 """ 496 var = _convert_tensorarray_to_flow(var) 497 if shape is None: 498 return type_spec.type_spec_from_value(var) 499 elif isinstance(shape, type_spec.TypeSpec): 500 if not shape.is_compatible_with(var): 501 raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var)) 502 return shape 503 elif not isinstance(shape, tensor_shape.TensorShape): 504 raise TypeError( 505 "'shape' must be one of TypeSpec, TensorShape or None. " 506 f"Received: {type(shape)}") 507 508 if isinstance(var, ops.Tensor): 509 return tensor_spec.TensorSpec(shape, var.dtype) 510 else: 511 try: 512 return var._shape_invariant_to_type_spec(shape) # pylint: disable=protected-access 513 except NotImplementedError as e: 514 raise TypeError( 515 f"To describe or constrain a {type(var).__name__}, use a " 516 f"{type(var._type_spec).__name__} instead of a TensorShape.") from e # pylint: disable=protected-access 517 518 519def _EnforceShapeInvariant(merge_var, next_var): 520 """Check if the shapes of the loops variables are invariants. 521 522 Args: 523 merge_var: The tensor representing the initial values of the loop 524 variables. 525 next_var: The tensor representing the values of the loop variables 526 after one loop iteration. 527 528 Raises: 529 ValueError: If any tensor in `merge_var` has a more specific shape than 530 its corresponding tensor in `next_var`. 531 """ 532 if isinstance(merge_var, ops.Tensor): 533 m_shape = merge_var.get_shape() 534 n_shape = next_var.get_shape() 535 if not _ShapeLessThanOrEqual(n_shape, m_shape): 536 enter = merge_var.op.inputs[0].op 537 assert util.IsLoopEnter(enter) 538 input_t = enter.inputs[0] 539 raise ValueError( 540 "Input tensor '%s' enters the loop with shape %s, but has shape %s " 541 "after one iteration. To allow the shape to vary across iterations, " 542 "use the `shape_invariants` argument of tf.while_loop to specify a " 543 "less-specific shape." % (input_t.name, input_t.shape, n_shape)) 544 else: 545 raise TypeError("'merge_var' must be a Tensor. " 546 f"Received: {type(merge_var)}.") 547 548 549def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True): 550 """Add NextIteration and back edge from v to m.""" 551 if isinstance(m, ops.Tensor): 552 v = ops.convert_to_tensor(v) 553 v = _NextIteration(v) 554 if enforce_shape_invariant: 555 # Make sure the shapes of loop outputs are correct. We do this before 556 # calling _update_input, which will raise a less-helpful error message if 557 # the types don't match. 558 # TODO(skyewm): call this for other cases below (needs testing) 559 _EnforceShapeInvariant(m, v) 560 m.op._update_input(1, v) # pylint: disable=protected-access 561 elif isinstance(m, composite_tensor.CompositeTensor): 562 # pylint: disable=protected-access 563 def update_component(m_component, v_component): 564 m_component.op._update_input(1, v_component) 565 566 if isinstance(m, indexed_slices.IndexedSlices): 567 v = math_ops._as_indexed_slices(v, optimize=False) 568 # pylint: enable=protected-access 569 v = _NextIteration(v) 570 return nest.map_structure(update_component, m, v, expand_composites=True) 571 else: 572 raise TypeError("'m' must be a Tensor or CompositeTensor. " 573 f"Received: {type(m)}.") 574 return v 575 576 577class ControlFlowContext(metaclass=abc.ABCMeta): 578 """The base class for control flow context. 579 580 The usage pattern is a sequence of (Enter, Exit) followed by a final 581 ExitResult. 582 583 We maintain the following state for control flow contexts during graph 584 construction: 585 1. graph has _control_flow_context: the current context used to 586 construct new nodes. Changed by ctxt.Enter() and ctxt.Exit() 587 2. op has _control_flow_context: the context to which the op belongs. 588 Set at the time the op is created. Immutable. 589 3. A ControlFlowContext has _outer_context: the context in which this 590 context is created. Set at the time a context is created. Immutable. 591 4. A ControlFlowContext has _context_stack. 592 Pushed and popped by ctxt.Enter() and ctxt.Exit() 593 """ 594 595 def __init__(self, values_def=None, import_scope=None): 596 self._nested_contexts = [] 597 self._outer_context = ops.get_default_graph()._get_control_flow_context() 598 if self._outer_context: 599 self._outer_context._nested_contexts.append(self) # pylint: disable=protected-access 600 self._context_stack = [] 601 if values_def: 602 self._init_values_from_proto(values_def, import_scope=import_scope) 603 else: 604 # The names of tensors that have been already seen in this context. 605 self._values = set() 606 # The keys are the names of tensors referenced by but external to this 607 # context. Each value is the Tensor that should be used by this context to 608 # access the key value (e.g. a switch output guarding a cond input value). 609 self._external_values = {} 610 611 def _init_values_from_proto(self, values_def, import_scope=None): 612 """Initializes values and external_values from `ValuesDef` protocol buffer. 613 614 Args: 615 values_def: `ValuesDef` protocol buffer. 616 import_scope: Optional `string`. Name scope to add. 617 """ 618 assert isinstance(values_def, control_flow_pb2.ValuesDef) 619 self._values = set( 620 ops.prepend_name_scope(value, import_scope) 621 for value in values_def.values) 622 g = ops.get_default_graph() 623 self._external_values = {} 624 for k, v in values_def.external_values.items(): 625 k = ops.prepend_name_scope(k, import_scope) 626 self._external_values[k] = g.as_graph_element( 627 ops.prepend_name_scope(v, import_scope)) 628 op_names = set([ 629 op.split(":")[0] 630 for op in self._values - set(self._external_values.keys()) 631 ]) 632 for op in op_names: 633 # pylint: disable=protected-access 634 g.as_graph_element(op)._set_control_flow_context(self) 635 # pylint: enable=protected-access 636 637 @property 638 def name(self): 639 return self._name 640 641 @property 642 def outer_context(self): 643 """Return the context containing this context.""" 644 return self._outer_context 645 646 @property 647 def grad_state(self): 648 raise NotImplementedError("Abstract method") 649 650 @property 651 def back_prop(self): 652 raise NotImplementedError("Abstract method") 653 654 @abc.abstractmethod 655 def to_control_flow_context_def(self, context_def, export_scope=None): 656 """Serializes this into `context_def`. 657 658 Args: 659 context_def: a `ControlFlowContextDef` protocol buffer. 660 export_scope: Optional `string`. Name scope to remove. 661 """ 662 raise NotImplementedError("Abstract method") 663 664 def _to_values_def(self, export_scope=None): 665 """Converts the values to a `ValuesDef` protocol buffer. 666 667 Args: 668 export_scope: Optional `string`. Name scope to remove. 669 670 Returns: 671 A `ValuesDef` protocol buffer. 672 """ 673 values_def = control_flow_pb2.ValuesDef() 674 values_def.values.extend( 675 [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)]) 676 for k, v in self._external_values.items(): 677 k = ops.strip_name_scope(k, export_scope) 678 values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope) 679 return values_def 680 681 def AddName(self, name): 682 self._values.add(name) 683 684 # pylint: disable=protected-access 685 def Enter(self): 686 """Enter this control flow context.""" 687 graph = ops.get_default_graph() 688 self._context_stack.append(graph._get_control_flow_context()) 689 graph._set_control_flow_context(self) 690 691 def Exit(self): 692 """Exit this control flow context.""" 693 graph = ops.get_default_graph() 694 last_context = self._context_stack.pop() 695 graph._set_control_flow_context(last_context) 696 697 def EnterGradientColocation(self, op, gradient_uid): 698 """Start building a gradient colocated with an op.""" 699 if self._outer_context: 700 self._outer_context.EnterGradientColocation(op, gradient_uid) 701 702 def ExitGradientColocation(self, op, gradient_uid): 703 """Start building a gradient colocated with an op.""" 704 if self._outer_context: 705 self._outer_context.ExitGradientColocation(op, gradient_uid) 706 707 def ExitResult(self, result): 708 """Make a list of tensors available in the outer context.""" 709 if self._outer_context: 710 def fn(x): 711 self._outer_context.AddName(x.name) 712 return x 713 nest.map_structure(fn, result, expand_composites=True) 714 715 def GetWhileContext(self): 716 """Return the while context containing this context.""" 717 if self._outer_context: 718 return self._outer_context.GetWhileContext() 719 return None 720 721 def _RemoveExternalControlEdges(self, op): 722 """Remove any external control dependency on this op.""" 723 while_ctxt = self.GetWhileContext() 724 # A control input of `op` is internal if it is in the same while 725 # loop context as the enclosing while loop context of self. 726 if while_ctxt is None: 727 internal_control_inputs, external_control_inputs = op.control_inputs, [] 728 else: 729 internal_control_inputs, external_control_inputs = [], [] 730 for x in op.control_inputs: 731 ctxt = util.GetOutputContext(x) 732 if ctxt is not None and ctxt.GetWhileContext() == while_ctxt: 733 internal_control_inputs.append(x) 734 else: 735 external_control_inputs.append(x) 736 if len(internal_control_inputs) != len(op.control_inputs): 737 # TODO(mdan): perhaps there should be a replace_control_inputs() 738 op._remove_all_control_inputs() 739 op._add_control_inputs(internal_control_inputs) 740 return internal_control_inputs, external_control_inputs 741 742 # pylint: enable=protected-access 743 744 def AddInnerOp(self, op): 745 """Notifies a scope about an operator added to an inner scope.""" 746 if self._outer_context: 747 self._outer_context.AddInnerOp(op) 748 749 def GetControlPivot(self): 750 """Returns the pivot node for this context, or None.""" 751 return None 752 753 def IsWhileContext(self): 754 return False 755 756 def IsCondContext(self): 757 return False 758 759 def IsXLAContext(self): 760 return False 761 762 def __str__(self): 763 return self.name 764 765 766class CondContext(ControlFlowContext): 767 """The context for the conditional construct.""" 768 769 def __init__(self, 770 pred=None, 771 pivot=None, 772 branch=None, 773 name="cond_text", 774 context_def=None, 775 import_scope=None): 776 """Creates a `CondContext`. 777 778 Args: 779 pred: The `boolean` tensor for the conditional predicate. 780 pivot: The predicate tensor in this branch. 781 branch: 0 or 1 representing this branch. 782 name: Name of the `CondContext` python object. 783 context_def: Optional `ContextDef` protocol buffer to initialize the 784 `CondContext` object from. 785 import_scope: Optional `string`. Name scope to add. Only used when 786 initialing from protocol buffer. 787 """ 788 self._name = ops.get_default_graph().unique_name(name) 789 790 if context_def: 791 self._init_from_proto(context_def, import_scope=import_scope) 792 else: 793 # Initializes the default fields. 794 ControlFlowContext.__init__(self) 795 self._pred = pred # The boolean tensor for the cond predicate 796 self._pivot = pivot # The predicate tensor in this branch 797 self._branch = branch # 0 or 1 representing this branch 798 799 # Values considered to have been already seen in this context. pred is not 800 # included in this context. 801 self._values.add(pred.name) 802 self._external_values[pred.name] = pred 803 self._values.add(pivot.name) 804 pivot.op._set_control_flow_context(self) # pylint: disable=protected-access 805 806 def _init_from_proto(self, context_def, import_scope=None): 807 """Creates a new `CondContext` from protocol buffer. 808 809 Args: 810 context_def: `CondContextDef` protocol buffer. 811 import_scope: Optional `string`. Name scope to add. 812 """ 813 assert isinstance(context_def, control_flow_pb2.CondContextDef) 814 # Create from context_def. 815 g = ops.get_default_graph() 816 self._name = ops.prepend_name_scope(context_def.context_name, import_scope) 817 self._pred = g.as_graph_element( 818 ops.prepend_name_scope(context_def.pred_name, import_scope)) 819 self._pivot = g.as_graph_element( 820 ops.prepend_name_scope(context_def.pivot_name, import_scope)) 821 self._branch = context_def.branch 822 super(CondContext, self).__init__( 823 values_def=context_def.values_def, import_scope=import_scope) 824 825 @property 826 def pred(self): 827 return self._pred 828 829 @property 830 def pivot(self): 831 return self._pivot 832 833 @property 834 def branch(self): 835 return self._branch 836 837 @property 838 def grad_state(self): 839 if self.GetWhileContext(): 840 return self.GetWhileContext().grad_state 841 return None 842 843 @property 844 def back_prop(self): 845 if self.GetWhileContext(): 846 self.GetWhileContext().back_prop 847 return False 848 849 def GetControlPivot(self): 850 return self._pivot 851 852 def to_proto(self, export_scope=None): 853 """Converts a `CondContext` to a `CondContextDef` protocol buffer. 854 855 Args: 856 export_scope: Optional `string`. Name scope to remove. 857 858 Returns: 859 A `CondContextDef` protocol buffer. 860 """ 861 if (export_scope is None or self.name.startswith(export_scope)): 862 context_def = control_flow_pb2.CondContextDef() 863 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 864 context_def.pred_name = ops.strip_name_scope(self._pred.name, 865 export_scope) 866 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 867 export_scope) 868 context_def.branch = self._branch 869 context_def.values_def.MergeFrom( 870 super(CondContext, self)._to_values_def(export_scope)) 871 for nested in self._nested_contexts: 872 nested_def = context_def.nested_contexts.add() 873 nested.to_control_flow_context_def(nested_def) 874 875 return context_def 876 else: 877 return None 878 879 @staticmethod 880 def from_proto(context_def, import_scope=None): 881 """Returns a `CondContext` object created from `context_def`.""" 882 ret = CondContext(context_def=context_def, import_scope=import_scope) 883 884 ret.Enter() 885 for nested_def in context_def.nested_contexts: 886 from_control_flow_context_def(nested_def, import_scope=import_scope) 887 ret.Exit() 888 return ret 889 890 def to_control_flow_context_def(self, context_def, export_scope=None): 891 context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 892 893 def AddValue(self, val): 894 """Add `val` to the current context and its outer context recursively.""" 895 if val.name in self._values: 896 # Use the real value if it comes from outer context. This is needed in 897 # particular for nested conds. 898 result = self._external_values.get(val.name) 899 result = val if result is None else result 900 else: 901 result = val 902 self._values.add(val.name) 903 if self._outer_context: 904 result = self._outer_context.AddValue(val) 905 self._values.add(result.name) 906 self._external_values[result.name] = result 907 with ops.control_dependencies(None): 908 result = _SwitchRefOrTensor(result, self._pred)[self._branch] 909 if self._outer_context: 910 self._outer_context.AddInnerOp(result.op) 911 912 result.op.graph.prevent_fetching(result.op) 913 # pylint: disable=protected-access 914 result.op._set_control_flow_context(self) 915 # pylint: enable=protected-access 916 917 # Mark Switch output as seen by this context and any outer contexts, 918 # just like what we do for normal op outputs in _AddOpInternal() below. 919 ctxt = self 920 while ctxt is not None: 921 # pylint: disable=protected-access 922 ctxt._values.add(result.name) 923 ctxt = ctxt._outer_context 924 # pylint: enable=protected-access 925 926 self._external_values[val.name] = result 927 return result 928 929 def AddOp(self, op): 930 self._AddOpInternal(op) 931 932 def _AddOpInternal(self, op): 933 """Add `op` to the current context.""" 934 if not op.inputs: 935 # If we're in a while loop, remove any control inputs from outside the 936 # loop. 937 self._RemoveExternalControlEdges(op) 938 939 if not any( 940 util.OpInContext(input_op, self) for input_op in op.control_inputs): 941 # pylint: disable=protected-access 942 op._add_control_input(self._pivot.op) 943 # pylint: enable=protected-access 944 else: 945 # Make each input to 'op' available in this CondContext. If an input is 946 # already part of this context there's nothing to do, but if it's 947 # external, AddValue() will handle adding the appropriate Switch node and 948 # other bookkeeping. 949 for index in range(len(op.inputs)): 950 x = op.inputs[index] 951 if op.type == "Merge" and x.op.type == "NextIteration": 952 # Edge case: if we're importing a while loop inside this CondContext, 953 # AddValue() will not correctly handle the NextIteration inputs to 954 # Merge node. The problem is that the NextIteration should also be 955 # part of this context, but if we're importing it won't have been 956 # processed and added to the context yet, so AddValue() will try to 957 # add a Switch which results in an invalid graph. Instead, we use the 958 # NextIteration input as-is here, and it will eventually be added to 959 # the context via AddOp(). 960 real_x = x 961 else: 962 real_x = self.AddValue(x) 963 if real_x != x: 964 # pylint: disable=protected-access 965 op._update_input(index, real_x) 966 # pylint: enable=protected-access 967 # Remove any external control dependency on this op. 968 self._RemoveExternalControlEdges(op) 969 # pylint: disable=protected-access 970 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 971 op._add_control_input(self._pivot.op) 972 # pylint: enable=protected-access 973 974 # Mark op's outputs as seen by this context and any outer contexts. 975 output_names = [x.name for x in op.outputs] 976 ctxt = self 977 while ctxt is not None: 978 # pylint: disable=protected-access 979 ctxt._values.update(output_names) 980 ctxt = ctxt._outer_context 981 # pylint: enable=protected-access 982 983 if self._outer_context or not util.IsLoopExit(op): 984 op.graph.prevent_fetching(op) 985 986 if self._outer_context: 987 self._outer_context.AddInnerOp(op) 988 989 def _ProcessOutputTensor(self, val): 990 """Process an output tensor of a conditional branch.""" 991 real_val = val 992 if val.name not in self._values: 993 # Handle the special case of lambda: x 994 self._values.add(val.name) 995 if self._outer_context: 996 real_val = self._outer_context.AddValue(val) 997 self._values.add(real_val.name) 998 self._external_values[real_val.name] = real_val 999 real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch] 1000 self._external_values[val.name] = real_val 1001 else: 1002 external_val = self._external_values.get(val.name) 1003 if external_val is not None: 1004 real_val = external_val 1005 return real_val 1006 1007 def _BuildCondTensor(self, v): 1008 if isinstance(v, ops.Operation): 1009 # Use pivot as the proxy for this op. 1010 return with_dependencies([v], self._pivot) 1011 else: 1012 v = nest.map_structure( 1013 _convert_tensorarray_to_flow, v, expand_composites=True) 1014 return self._ProcessOutputTensor(ops.convert_to_tensor(v)) 1015 1016 def BuildCondBranch(self, fn): 1017 """Add the subgraph defined by fn() to the graph.""" 1018 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1019 original_result = fn() 1020 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1021 if len(post_summaries) > len(pre_summaries): 1022 new_summaries = post_summaries[len(pre_summaries):] 1023 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1024 summary_ref[:] = pre_summaries 1025 with ops.control_dependencies(new_summaries): 1026 if original_result is None: 1027 return no_op(), None 1028 elif not isinstance(original_result, ops.Operation): 1029 original_result = variable_utils.convert_variables_to_tensors( 1030 original_result) 1031 original_result = nest.map_structure( 1032 array_ops.identity, original_result, expand_composites=True) 1033 if original_result is None: 1034 return None, None 1035 1036 original_result = variable_utils.convert_variables_to_tensors( 1037 original_result) 1038 result = nest.map_structure( 1039 self._BuildCondTensor, original_result, expand_composites=True) 1040 if not isinstance(result, (list, _basetuple)): 1041 result = [result] 1042 return original_result, result 1043 1044 def IsCondContext(self): 1045 return True 1046 1047 1048def _UnpackIfSingleton(res): 1049 if isinstance(res, (list, _basetuple)) and len(res) == 1: 1050 return res[0] 1051 else: 1052 return res 1053 1054 1055def _eager_cond_implementation(pred, true_fn, false_fn, strict, name): 1056 """Special cases for `cond` when executing eagerly.""" 1057 pred = ops.convert_to_tensor(pred) 1058 pred_constant_value = tensor_util.constant_value(pred) 1059 if pred_constant_value is None: 1060 # Eager tensors from a parallel device may not have a constant 1061 # value. Running the cond op itself would work, but we don't have logic to 1062 # build cond ops without wrapping in a function first. 1063 if (not isinstance(true_fn, def_function.Function) 1064 or not isinstance(false_fn, def_function.Function)): 1065 raise TypeError("When running tf.cond on a parallel device, 'true_fn' " 1066 "and 'false_fn' must be decorated with `tf.function`.") 1067 @def_function.function 1068 def _parallel_device_cond_wrapper(): 1069 return cond_v2.cond_v2(pred, true_fn, false_fn, name) 1070 functions_run_eagerly = def_function.functions_run_eagerly() 1071 if functions_run_eagerly: 1072 # We need to use tf.function to deal with variable creation inside the 1073 # cond, and skipping it because of run_functions_eagerly would just 1074 # crash immediately. 1075 logging.warning( 1076 "It looks like tf.function behavior was disabled, perhaps using " 1077 "tf.config.run_functions_eagerly. Parallelized tf.cond requires " 1078 "tf.function to work. This primitive will override the disable.") 1079 def_function.run_functions_eagerly(False) 1080 try: 1081 return _parallel_device_cond_wrapper() 1082 finally: 1083 if functions_run_eagerly is not None: 1084 def_function.run_functions_eagerly(functions_run_eagerly) 1085 else: 1086 # For conditions which are eager tensors with a constant value (most of 1087 # them), we only call the relevant branch function and execute it eagerly. 1088 with ops.name_scope(name, "cond", [pred]): 1089 if pred_constant_value: 1090 result = true_fn() 1091 else: 1092 result = false_fn() 1093 if not strict: 1094 result = _UnpackIfSingleton(result) 1095 return result 1096 1097 1098# pylint: disable=redefined-outer-name 1099# pylint: disable=g-doc-args 1100@tf_export(v1=["cond"]) 1101@dispatch.add_dispatch_support 1102@deprecation.deprecated_args( 1103 None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", 1104 "fn1", "fn2") 1105def cond(pred, 1106 true_fn=None, 1107 false_fn=None, 1108 strict=False, 1109 name=None, 1110 fn1=None, 1111 fn2=None): 1112 """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. 1113 1114 `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and 1115 `false_fn` must have the same non-zero number and type of outputs. 1116 1117 **WARNING**: Any Tensors or Operations created outside of `true_fn` and 1118 `false_fn` will be executed regardless of which branch is selected at runtime. 1119 1120 Although this behavior is consistent with the dataflow model of TensorFlow, 1121 it has frequently surprised users who expected a lazier semantics. 1122 Consider the following simple program: 1123 1124 ```python 1125 z = tf.multiply(a, b) 1126 result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)) 1127 ``` 1128 1129 If `x < y`, the `tf.add` operation will be executed and `tf.square` 1130 operation will not be executed. Since `z` is needed for at least one 1131 branch of the `cond`, the `tf.multiply` operation is always executed, 1132 unconditionally. 1133 1134 Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the 1135 call to `cond`, and not at all during `Session.run()`). `cond` 1136 stitches together the graph fragments created during the `true_fn` and 1137 `false_fn` calls with some additional graph nodes to ensure that the right 1138 branch gets executed depending on the value of `pred`. 1139 1140 `tf.cond` supports nested structures as implemented in 1141 `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the 1142 same (possibly nested) value structure of lists, tuples, and/or named tuples. 1143 Singleton lists and tuples form the only exceptions to this: when returned by 1144 `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. 1145 This behavior is disabled by passing `strict=True`. 1146 1147 Args: 1148 pred: A scalar determining whether to return the result of `true_fn` or 1149 `false_fn`. 1150 true_fn: The callable to be performed if pred is true. 1151 false_fn: The callable to be performed if pred is false. 1152 strict: A boolean that enables/disables 'strict' mode; see above. 1153 name: Optional name prefix for the returned tensors. 1154 1155 Returns: 1156 Tensors returned by the call to either `true_fn` or `false_fn`. If the 1157 callables return a singleton list, the element is extracted from the list. 1158 1159 Raises: 1160 TypeError: if `true_fn` or `false_fn` is not callable. 1161 ValueError: if `true_fn` and `false_fn` do not return the same number of 1162 tensors, or return tensors of different types. 1163 1164 Example: 1165 1166 ```python 1167 x = tf.constant(2) 1168 y = tf.constant(5) 1169 def f1(): return tf.multiply(x, 17) 1170 def f2(): return tf.add(y, 23) 1171 r = tf.cond(tf.less(x, y), f1, f2) 1172 # r is set to f1(). 1173 # Operations in f2 (e.g., tf.add) are not executed. 1174 ``` 1175 1176 """ 1177 # We needed to make true_fn/false_fn keyword arguments for 1178 # backwards-compatibility. This check exists so that we can convert back to 1179 # having them be positional arguments. 1180 # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after 1181 # `fn1` and `fn2` are deleted. 1182 if fn1 is not None: 1183 if true_fn is not None: 1184 raise TypeError( 1185 "cond(): 'true_fn' and 'fn1' may not be set simultaneously.") 1186 true_fn = fn1 1187 elif true_fn is None: 1188 raise TypeError("cond(): 'true_fn' argument required") 1189 if fn2 is not None: 1190 if false_fn is not None: 1191 raise TypeError( 1192 "cond(): 'false_fn' and 'fn2' may not be set simultaneously.") 1193 false_fn = fn2 1194 elif false_fn is None: 1195 raise TypeError("cond(): 'false_fn' argument required") 1196 1197 if not callable(true_fn): 1198 raise TypeError("'true_fn' must be callable.") 1199 if not callable(false_fn): 1200 raise TypeError("'false_fn' must be callable.") 1201 1202 if context.executing_eagerly(): 1203 return _eager_cond_implementation(pred, true_fn, false_fn, strict, name) 1204 1205 # Always enable control flow v2 if building a function, regardless of toggle. 1206 if util.EnableControlFlowV2(ops.get_default_graph()): 1207 return cond_v2.cond_v2(pred, true_fn, false_fn, name) 1208 1209 with ops.name_scope(name, "cond", [pred]): 1210 # Add the Switch to the graph. 1211 if isinstance(pred, bool): 1212 raise TypeError("'pred' must not be a Python bool.") 1213 p_2, p_1 = switch(pred, pred) 1214 pivot_1 = array_ops.identity(p_1, name="switch_t") 1215 pivot_2 = array_ops.identity(p_2, name="switch_f") 1216 pred = array_ops.identity(pred, name="pred_id") 1217 # Disable the fetching of tensors that are only on one branch of cond. 1218 for tensor in [p_1, p_2, pivot_1, pivot_2, pred]: 1219 tensor.op.graph.prevent_fetching(tensor.op) 1220 1221 # Build the graph for the true branch in a new context. 1222 context_t = CondContext(pred, pivot_1, branch=1) 1223 try: 1224 context_t.Enter() 1225 orig_res_t, res_t = context_t.BuildCondBranch(true_fn) 1226 if orig_res_t is None: 1227 raise ValueError("'true_fn' must have a return value.") 1228 context_t.ExitResult(res_t) 1229 finally: 1230 context_t.Exit() 1231 1232 # Build the graph for the false branch in a new context. 1233 context_f = CondContext(pred, pivot_2, branch=0) 1234 try: 1235 context_f.Enter() 1236 orig_res_f, res_f = context_f.BuildCondBranch(false_fn) 1237 if orig_res_f is None: 1238 raise ValueError("'false_fn' must have a return value.") 1239 context_f.ExitResult(res_f) 1240 finally: 1241 context_f.Exit() 1242 1243 if not strict: 1244 orig_res_t = _UnpackIfSingleton(orig_res_t) 1245 orig_res_f = _UnpackIfSingleton(orig_res_f) 1246 1247 # Check that the return values of the two branches have the same structure. 1248 try: 1249 nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True) 1250 except (TypeError, ValueError): 1251 nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f) 1252 nest.map_structure(_cast_indexed_slice_indices, res_t, res_f) 1253 try: 1254 nest.assert_same_structure(orig_res_t, orig_res_f, 1255 expand_composites=True) 1256 except TypeError as e: 1257 raise TypeError( 1258 f"Incompatible return types of 'true_fn' and 'false_fn': {e}") 1259 except ValueError as e: 1260 raise ValueError( 1261 f"Incompatible return values of 'true_fn' and 'false_fn': {e}") 1262 1263 # Add the final merge to the graph. 1264 if not res_t: 1265 raise ValueError( 1266 "'true_fn' and 'false_fn' must return at least one result.") 1267 1268 res_t_flat = nest.flatten(res_t, expand_composites=True) 1269 res_f_flat = nest.flatten(res_f, expand_composites=True) 1270 1271 for (x, y) in zip(res_t_flat, res_f_flat): 1272 assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor) 1273 if x.dtype.base_dtype != y.dtype.base_dtype: 1274 raise ValueError( 1275 "Outputs of 'true_fn' and 'false_fn' must have the same type(s). " 1276 f"Received {x.dtype.name} from 'true_fn' " 1277 f"and {y.dtype.name} from 'false_fn'.") 1278 1279 merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)] 1280 merges = nest.map_structure( 1281 _convert_flow_to_tensorarray, 1282 nest.flatten(orig_res_t, expand_composites=True), 1283 merges) 1284 1285 # Only add non-nested conds to the collection. Any nested control flow will 1286 # be encapsulated in the root context. 1287 assert context_t.outer_context == context_f.outer_context 1288 if context_t.outer_context is None: 1289 ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) 1290 ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) 1291 1292 merges = nest.pack_sequence_as( 1293 structure=orig_res_t, flat_sequence=merges, expand_composites=True) 1294 1295 # Singleton lists and tuples are automatically unpacked if strict == False. 1296 if not strict: 1297 merges = _UnpackIfSingleton(merges) 1298 return merges 1299 1300 1301def _cast_indexed_slice_indices(a, b): 1302 """Cast IndexedSlice.indices from int32 to int64 where necessary. 1303 1304 If `a` and `b` are both IndexedSlices, and their indices have different 1305 dtypes, then cast both their dtypes to `int64` (modifies `a` and `b` 1306 in-place). Otherwise, does nothing. 1307 1308 Args: 1309 a: A value, which may be an IndexedSlices. 1310 b: A value, which may be an IndexedSlices. 1311 """ 1312 if (isinstance(a, indexed_slices.IndexedSlices) and 1313 isinstance(b, indexed_slices.IndexedSlices) and 1314 a.indices.dtype != b.indices.dtype): 1315 # pylint: disable=protected-access 1316 a._indices = math_ops.cast(a.indices, dtypes.int64) 1317 b._indices = math_ops.cast(b.indices, dtypes.int64) 1318 1319 1320# pylint: enable=g-doc-args 1321# pylint: enable=redefined-outer-name 1322 1323 1324@tf_export("cond", v1=[]) 1325@dispatch.add_dispatch_support 1326def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None): 1327 """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. 1328 1329 `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and 1330 `false_fn` must have the same non-zero number and type of outputs. 1331 1332 **WARNING**: Any Tensors or Operations created outside of `true_fn` and 1333 `false_fn` will be executed regardless of which branch is selected at runtime. 1334 1335 Although this behavior is consistent with the dataflow model of TensorFlow, 1336 it has frequently surprised users who expected a lazier semantics. 1337 Consider the following simple program: 1338 1339 ```python 1340 z = tf.multiply(a, b) 1341 result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)) 1342 ``` 1343 1344 If `x < y`, the `tf.add` operation will be executed and `tf.square` 1345 operation will not be executed. Since `z` is needed for at least one 1346 branch of the `cond`, the `tf.multiply` operation is always executed, 1347 unconditionally. 1348 1349 Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the 1350 call to `cond`, and not at all during `Session.run()`). `cond` 1351 stitches together the graph fragments created during the `true_fn` and 1352 `false_fn` calls with some additional graph nodes to ensure that the right 1353 branch gets executed depending on the value of `pred`. 1354 1355 `tf.cond` supports nested structures as implemented in 1356 `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the 1357 same (possibly nested) value structure of lists, tuples, and/or named tuples. 1358 Singleton lists and tuples form the only exceptions to this: when returned by 1359 `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. 1360 1361 Note: It is illegal to "directly" use tensors created inside a cond branch 1362 outside it, e.g. by storing a reference to a branch tensor in the python 1363 state. If you need to use a tensor created in a branch function you should 1364 return it as an output of the branch function and use the output from 1365 `tf.cond` instead. 1366 1367 Args: 1368 pred: A scalar determining whether to return the result of `true_fn` or 1369 `false_fn`. 1370 true_fn: The callable to be performed if pred is true. 1371 false_fn: The callable to be performed if pred is false. 1372 name: Optional name prefix for the returned tensors. 1373 1374 Returns: 1375 Tensors returned by the call to either `true_fn` or `false_fn`. If the 1376 callables return a singleton list, the element is extracted from the list. 1377 1378 Raises: 1379 TypeError: if `true_fn` or `false_fn` is not callable. 1380 ValueError: if `true_fn` and `false_fn` do not return the same number of 1381 tensors, or return tensors of different types. 1382 1383 Example: 1384 1385 ```python 1386 x = tf.constant(2) 1387 y = tf.constant(5) 1388 def f1(): return tf.multiply(x, 17) 1389 def f2(): return tf.add(y, 23) 1390 r = tf.cond(tf.less(x, y), f1, f2) 1391 # r is set to f1(). 1392 # Operations in f2 (e.g., tf.add) are not executed. 1393 ``` 1394 1395 """ 1396 return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name) 1397 1398 1399def _resource_safe_shape(t): 1400 """Returns the shape of t or the variable it points to.""" 1401 if t.dtype == dtypes.resource: 1402 while t.op.inputs: 1403 t = t.op.inputs[0] 1404 return tensor_shape.TensorShape(t.op.get_attr("shape")) 1405 return array_ops.shape_internal(t, optimize=False) 1406 1407 1408# TODO(yuanbyu): Consider having a unified notion of context for 1409# not only conditionals and loops but also control dependency and 1410# subgraphs. 1411class WhileContext(ControlFlowContext): 1412 """The context for the loop construct.""" 1413 1414 def __init__(self, 1415 maximum_iterations=None, 1416 parallel_iterations=10, 1417 back_prop=True, 1418 swap_memory=False, 1419 name="while_context", 1420 grad_state=None, 1421 context_def=None, 1422 import_scope=None): 1423 """"Creates a `WhileContext`. 1424 1425 Args: 1426 maximum_iterations: Optional upper bound on number of loop iterations. 1427 parallel_iterations: The number of iterations allowed to run in parallel. 1428 back_prop: Whether backprop is enabled for this while loop. 1429 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 1430 name: Optional name prefix for the returned tensors. 1431 grad_state: The gradient loop state. 1432 context_def: Optional `WhileContextDef` protocol buffer to initialize the 1433 `Whilecontext` python object from. 1434 import_scope: Optional `string`. Name scope to add. Only used when 1435 initialing from protocol buffer. 1436 """ 1437 if context_def: 1438 self._init_from_proto(context_def, import_scope=import_scope) 1439 else: 1440 ControlFlowContext.__init__(self) 1441 self._init_from_args(maximum_iterations, parallel_iterations, back_prop, 1442 swap_memory, name) 1443 # The gradient loop state. 1444 self._grad_state = grad_state 1445 1446 def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop, 1447 swap_memory, name): 1448 """Creates a new `WhileContext` from arguments. 1449 1450 Args: 1451 maximum_iterations: Optional upper bound on number of loop iterations. 1452 parallel_iterations: The number of iterations allowed to run in parallel. 1453 back_prop: Whether backprop is enabled for this while loop. 1454 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 1455 name: Optional name prefix for the returned tensors. 1456 1457 Raises: 1458 ValueError: If `parallel_iterations` has invalid value. 1459 """ 1460 if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0): 1461 raise ValueError("'parallel_iterations' must be a positive integer: " 1462 "%s" % parallel_iterations) 1463 self._name = ops.get_default_graph().unique_name(name) 1464 self._maximum_iterations = maximum_iterations 1465 self._parallel_iterations = parallel_iterations 1466 self._back_prop = back_prop 1467 self._swap_memory = swap_memory 1468 # We use this node to control constants created by the pred lambda. 1469 self._pivot_for_pred = None 1470 # We use this node to control constants created by the body lambda. 1471 self._pivot_for_body = None 1472 # The boolean tensor for loop termination condition. Used in code 1473 # generation for gradient computation 1474 self._pivot = None 1475 # The list of exit tensors for loop variables. 1476 self._loop_exits = [] 1477 # The list of enter tensors for loop variables. 1478 self._loop_enters = [] 1479 self._graph = ops.get_default_graph() 1480 1481 def _init_from_proto(self, context_def, import_scope=None): 1482 """Creates a new `WhileContext` from protocol buffer. 1483 1484 Args: 1485 context_def: `WhileContextDef` protocol buffer. 1486 import_scope: Optional `string`. Name scope to add. 1487 """ 1488 assert isinstance(context_def, control_flow_pb2.WhileContextDef) 1489 # Create from context_def. 1490 g = ops.get_default_graph() 1491 self._name = ops.prepend_name_scope(context_def.context_name, import_scope) 1492 if context_def.maximum_iterations_name: 1493 self._maximum_iterations = g.as_graph_element( 1494 ops.prepend_name_scope(context_def.maximum_iterations_name, 1495 import_scope)) 1496 else: 1497 self._maximum_iterations = None 1498 self._parallel_iterations = context_def.parallel_iterations 1499 self._back_prop = context_def.back_prop 1500 self._swap_memory = context_def.swap_memory 1501 self._pivot_for_pred = g.as_graph_element( 1502 ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope)) 1503 # We use this node to control constants created by the body lambda. 1504 self._pivot_for_body = g.as_graph_element( 1505 ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope)) 1506 # The boolean tensor for loop termination condition. Used in code 1507 # generation for gradient computation. 1508 self._pivot = g.as_graph_element( 1509 ops.prepend_name_scope(context_def.pivot_name, import_scope)) 1510 # The list of exit tensors for loop variables. 1511 self._loop_exits = [ 1512 g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) 1513 for exit_name in context_def.loop_exit_names 1514 ] 1515 # The list of enter tensors for loop variables. 1516 self._loop_enters = [ 1517 g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) 1518 for enter_name in context_def.loop_enter_names 1519 ] 1520 super(WhileContext, self).__init__( 1521 values_def=context_def.values_def, import_scope=import_scope) 1522 1523 # import_scope causes self.name to be different from the original serialized 1524 # context's name. Rewrite "frame_name" attrs with the new name. 1525 if import_scope: 1526 for tensor_name in self._values: 1527 op = g.as_graph_element(tensor_name).op 1528 if util.IsLoopEnter(op): 1529 # pylint: disable=protected-access 1530 op._set_attr("frame_name", 1531 attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) 1532 # pylint: enable=protected-access 1533 self._graph = ops.get_default_graph() 1534 1535 @property 1536 def maximum_iterations(self): 1537 """The maximum number of iterations that will be executed.""" 1538 return self._maximum_iterations 1539 1540 @property 1541 def parallel_iterations(self): 1542 """The number of iterations allowed to run in parallel.""" 1543 return self._parallel_iterations 1544 1545 @property 1546 def back_prop(self): 1547 """True iff backprop is enabled for this while loop.""" 1548 return self._back_prop 1549 1550 @property 1551 def swap_memory(self): 1552 """True iff GPU-CPU memory swap is enabled for this while loop.""" 1553 return self._swap_memory 1554 1555 @property 1556 def pivot(self): 1557 """The boolean tensor representing the loop termination condition.""" 1558 return self._pivot 1559 1560 @property 1561 def loop_enters(self): 1562 """The list of enter tensors for loop variables.""" 1563 return self._loop_enters 1564 1565 @property 1566 def loop_exits(self): 1567 """The list of exit tensors for loop variables.""" 1568 return self._loop_exits 1569 1570 @property 1571 def grad_state(self): 1572 """The gradient loop state.""" 1573 return self._grad_state 1574 1575 def to_proto(self, export_scope=None): 1576 """Converts a `WhileContext` to a `WhileContextDef` protocol buffer. 1577 1578 Args: 1579 export_scope: Optional `string`. Name scope to remove. 1580 1581 Returns: 1582 A `WhileContextDef` protocol buffer. 1583 """ 1584 if (export_scope is None or self.name.startswith(export_scope)): 1585 context_def = control_flow_pb2.WhileContextDef() 1586 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 1587 context_def.parallel_iterations = self._parallel_iterations 1588 if self._maximum_iterations is not None: 1589 context_def.maximum_iterations_name = ops.strip_name_scope( 1590 self._maximum_iterations.name, export_scope) 1591 context_def.back_prop = self._back_prop 1592 context_def.swap_memory = self._swap_memory 1593 context_def.pivot_for_pred_name = ops.strip_name_scope( 1594 self._pivot_for_pred.name, export_scope) 1595 context_def.pivot_for_body_name = ops.strip_name_scope( 1596 self._pivot_for_body.name, export_scope) 1597 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 1598 export_scope) 1599 context_def.loop_exit_names.extend([ 1600 ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits 1601 ]) 1602 context_def.loop_enter_names.extend([ 1603 ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters 1604 ]) 1605 context_def.values_def.MergeFrom( 1606 super(WhileContext, self)._to_values_def(export_scope=export_scope)) 1607 for nested in self._nested_contexts: 1608 nested_def = context_def.nested_contexts.add() 1609 nested.to_control_flow_context_def(nested_def) 1610 1611 return context_def 1612 else: 1613 return None 1614 1615 def to_control_flow_context_def(self, context_def, export_scope=None): 1616 context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 1617 1618 @staticmethod 1619 def from_proto(context_def, import_scope=None): 1620 """Returns a `WhileContext` object created from `context_def`. 1621 1622 Args: 1623 context_def: A `WhileContextDef` protocol buffer. 1624 import_scope: Optional `string`. Name scope to add. 1625 1626 Returns: 1627 A `WhileContext` Python object. 1628 """ 1629 ret = WhileContext(context_def=context_def, import_scope=import_scope) 1630 ret.Enter() 1631 for nested_def in context_def.nested_contexts: 1632 from_control_flow_context_def(nested_def, import_scope=import_scope) 1633 ret.Exit() 1634 return ret 1635 1636 def GetWhileContext(self): 1637 return self 1638 1639 def GetControlPivot(self): 1640 if self._pivot_for_body is not None: 1641 return self._pivot_for_body 1642 return self._pivot_for_pred 1643 1644 def AddValue(self, val): 1645 """Add `val` to the current context and its outer context recursively.""" 1646 result = val 1647 new_value = val.name not in self._values 1648 # Don't treat ops in this context as new values. Usually all known values 1649 # are in self._values, except when we're importing a while loop inside this 1650 # WhileContext. Since there's a cycle in this case, `val` may be part of the 1651 # imported while loop but not yet processed by this context and added to 1652 # self._values in _AddOpInternal. We only want to process external input 1653 # tensors to the while loop here. 1654 new_value &= val.op._control_flow_context is not self # pylint: disable=protected-access 1655 if new_value: 1656 self._values.add(val.name) 1657 1658 # If we are in a grad context and val is from its forward context, 1659 # use GetRealValue(), which adds the logic to save the history of 1660 # val in forward. 1661 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 1662 if grad_ctxt: 1663 grad_ctxt = grad_ctxt.GetWhileContext() 1664 if grad_ctxt.grad_state: 1665 forward_ctxt = util.GetWhileContext(val.op) 1666 if util.IsLoopExit(val.op): 1667 forward_ctxt = forward_ctxt.outer_context 1668 if forward_ctxt: 1669 forward_ctxt = forward_ctxt.GetWhileContext() 1670 if forward_ctxt == grad_ctxt.grad_state.forward_context: 1671 real_val = grad_ctxt.grad_state.GetRealValue(val) 1672 self._external_values[val.name] = real_val 1673 return real_val 1674 1675 if self._outer_context is not None: 1676 result = self._outer_context.AddValue(val) 1677 # Create an Enter to make `result` known to this loop context. 1678 with ops.control_dependencies(None): 1679 enter = _Enter( 1680 result, 1681 self._name, 1682 is_constant=True, 1683 parallel_iterations=self._parallel_iterations) 1684 enter.graph.prevent_feeding(enter) 1685 if self._outer_context: 1686 self._outer_context.AddInnerOp(enter.op) 1687 # Fix the control inputs and control flow context of these enter ops. 1688 self._FixControlInputsAndContext([enter]) 1689 1690 # Add `enter` in this context. 1691 self._values.add(enter.name) 1692 self._external_values[val.name] = enter 1693 result = enter 1694 else: 1695 actual_val = self._external_values.get(val.name) 1696 if actual_val is not None: 1697 result = actual_val 1698 return result 1699 1700 def AddOp(self, op): 1701 """Add `op` to the current context.""" 1702 # For a reduction op, if op is in a grad context and its input is from 1703 # its forward context, moving op to the forward context means we would 1704 # store the tensor after the reduction as opposed to the tensor before 1705 # reduction, and therefore could significantly reduce memory consumption. 1706 # For now, we do this only for a few ops. 1707 # 1708 # If in XLA context, do not move constant ops to forward pass as pushing to 1709 # and popping from a stack removes the constant property of an op and breaks 1710 # XLA compilation, which requires certain inputs to be constant for certain 1711 # ops. 1712 if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}: 1713 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 1714 if grad_ctxt: 1715 grad_ctxt = grad_ctxt.GetWhileContext() 1716 if grad_ctxt.grad_state: 1717 op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op) 1718 if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context: 1719 op_input_ctxt = op.inputs[0].op._get_control_flow_context() 1720 op._set_control_flow_context(op_input_ctxt) 1721 op_input_ctxt._AddOpInternal(op) 1722 return 1723 self._AddOpInternal(op) 1724 1725 def _AddOpInternal(self, op): 1726 """Add `op` to the current context. 1727 1728 We move any external control dependencies of the op to the loop pivot, to 1729 ensure they get executed. 1730 """ 1731 # This is needed to prevent frame mismatch errors where there are Const 1732 # nodes inside tf.function in v1 while_loop and inlining is turned on. 1733 if op.type in ["PartitionedCall", "StatefulPartitionedCall"]: 1734 op._add_control_input(self.GetControlPivot().op) # pylint: disable=protected-access 1735 if not op.inputs: 1736 # Remove any external control dependency on this op 1737 control_inputs, external_inputs = self._RemoveExternalControlEdges(op) 1738 # Add a control edge from the control pivot to this op. 1739 if not control_inputs: 1740 # pylint: disable=protected-access 1741 op._add_control_input(self.GetControlPivot().op) 1742 # pylint: enable=protected-access 1743 for x in op.outputs: 1744 self._values.add(x.name) 1745 else: 1746 for index in range(len(op.inputs)): 1747 x = op.inputs[index] 1748 real_x = self.AddValue(x) 1749 if real_x != x: 1750 op._update_input(index, real_x) # pylint: disable=protected-access 1751 # Remove any external control dependency on this op. 1752 _, external_inputs = self._RemoveExternalControlEdges(op) 1753 # Add a control dependency to prevent loop invariants from 1754 # enabling ops that should not be executed. 1755 self._MaybeAddControlDependency(op) 1756 for x in op.outputs: 1757 self._values.add(x.name) 1758 if external_inputs: 1759 # Use an identity to pull control inputs as data inputs. Note that we 1760 # ignore ops which don't have outputs. TODO(apassos): fix that 1761 with ops.control_dependencies(None): 1762 self.Enter() 1763 external_inputs = [ 1764 array_ops.identity(x.outputs[0]).op 1765 for x in external_inputs 1766 if x.outputs 1767 ] 1768 self.Exit() 1769 op._add_control_inputs(external_inputs) # pylint: disable=protected-access 1770 if self._outer_context or not util.IsLoopExit(op): 1771 op.graph.prevent_fetching(op) 1772 for x in op.outputs: 1773 op.graph.prevent_feeding(x) 1774 1775 if self._outer_context: 1776 self._outer_context.AddInnerOp(op) 1777 1778 def _MaybeAddControlDependency(self, op): 1779 """Add a control input to the op if it only depends on loop invariants.""" 1780 1781 def _IsOpFree(op): 1782 """Determines if `op` needs a control dependency.""" 1783 if op.control_inputs: 1784 return False 1785 # pylint: disable=protected-access 1786 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 1787 return True 1788 # pylint: enable=protected-access 1789 for x in op.inputs: 1790 if not util.IsLoopConstantEnter(x.op): 1791 return False 1792 return True 1793 1794 if _IsOpFree(op): 1795 # pylint: disable=protected-access 1796 op._add_control_input(self.GetControlPivot().op) 1797 # pylint: enable=protected-access 1798 1799 def AddForwardLoopCounter(self, outer_grad_state): 1800 """Adds a loop that counts the number of iterations. 1801 1802 This is added to the forward loop at the time when we start to 1803 create the loop for backprop gradient computation. Called in 1804 the outer context of this forward context. 1805 1806 The pseudocode is: 1807 `n = 0; while (_pivot) { n++; }` 1808 1809 Note that a control dependency is added to `n` to ensure the correct 1810 execution order of stack push ops. 1811 1812 Args: 1813 outer_grad_state: The outer grad state. None if not nested. 1814 1815 Returns: 1816 The number of iterations taken by the forward loop and the loop index. 1817 """ 1818 n = constant_op.constant(0, name="f_count") 1819 if outer_grad_state is not None: 1820 # Force the stack pushes of i-th execution of an inner loop to be ordered 1821 # before the pushes of (i+1)-th execution of the same inner loop. 1822 outer_add_op = outer_grad_state.forward_index.op.inputs[0].op 1823 n.op._add_control_input(outer_add_op) # pylint: disable=protected-access 1824 1825 self.Enter() 1826 self.AddName(n.name) 1827 enter_n = _Enter( 1828 n, 1829 self._name, 1830 is_constant=False, 1831 parallel_iterations=self._parallel_iterations, 1832 name="f_count") 1833 self.loop_enters.append(enter_n) 1834 1835 merge_n = merge([enter_n, enter_n])[0] 1836 switch_n = switch(merge_n, self._pivot) 1837 1838 index = math_ops.add(switch_n[1], 1) 1839 next_n = _NextIteration(index) 1840 merge_n.op._update_input(1, next_n) 1841 1842 total_iterations = exit(switch_n[0], name="f_count") 1843 self.loop_exits.append(total_iterations) 1844 self.ExitResult([total_iterations]) 1845 self.Exit() 1846 return total_iterations, next_n 1847 1848 def AddBackpropLoopCounter(self, count, outer_grad_state): 1849 """Add the backprop loop that controls the iterations. 1850 1851 This is added to the backprop loop. It is used to control the loop 1852 termination of the backprop loop. Called in the outer context of 1853 this grad context. 1854 1855 The pseudocode is: 1856 `n = count; while (n >= 1) { n--; }` 1857 1858 Note that a control dependency is added to `final_zero` to ensure the 1859 correct execution order of stack pop ops. 1860 1861 Args: 1862 count: The number of iterations for backprop. 1863 outer_grad_state: The outer grad state. None if not nested. 1864 1865 Returns: 1866 The loop index. 1867 """ 1868 in_separate_functions = count.graph is not ops.get_default_graph() 1869 if in_separate_functions: 1870 # Brings the count into this graph 1871 count = array_ops.identity(count) 1872 else: 1873 # TODO(apassos) XLA expects this constant to be created outside the loop, 1874 # so doing that for now. 1875 one = constant_op.constant(1, name="b_count") 1876 1877 self.Enter() 1878 self.AddName(count.name) 1879 enter_count = _Enter( 1880 count, 1881 self._name, 1882 is_constant=False, 1883 parallel_iterations=self._parallel_iterations, 1884 name="b_count") 1885 self.loop_enters.append(enter_count) 1886 1887 merge_count = merge([enter_count, enter_count])[0] 1888 self._pivot_for_pred = merge_count 1889 1890 if in_separate_functions: 1891 one = constant_op.constant(1, name="b_count") 1892 pred = math_ops.greater_equal(merge_count, one) 1893 self._pivot = loop_cond(pred, name="b_count") 1894 switch_count = switch(merge_count, self._pivot) 1895 1896 index = math_ops.subtract(switch_count[1], one) 1897 self._pivot_for_body = index 1898 next_count = _NextIteration(index) 1899 merge_count.op._update_input(1, next_count) 1900 1901 final_zero = exit(switch_count[0], name="b_count") 1902 self.loop_exits.append(final_zero) 1903 if outer_grad_state is not None: 1904 # Force the stack pops of i-th execution of an inner loop to be ordered 1905 # before the pops of (i+1)-th execution of the same inner loop. 1906 # pylint: disable=protected-access 1907 outer_grad_state.grad_sync._add_control_input(final_zero.op) 1908 # pylint: enable=protected-access 1909 1910 self.ExitResult([final_zero]) 1911 self.Exit() 1912 return next_count 1913 1914 def AddBackpropAccumulator(self, op, grad): 1915 """Add an accumulation loop for every loop invariant. 1916 1917 This is added to the backprop loop. It is used to accumulate partial 1918 gradients within each loop iteration. Called when in the gradient while 1919 context. 1920 1921 The pseudocode is: 1922 ``` 1923 acc = 0.0; 1924 while (_pivot) { 1925 acc += grad; 1926 } 1927 ``` 1928 1929 Args: 1930 op: The Enter op for a loop invariant. 1931 grad: The partial gradient of an iteration for a loop invariant. 1932 1933 Returns: 1934 The gradient for a loop invariant. 1935 """ 1936 self.Exit() 1937 # Create a zeros tensor with the right shape for acc. If we don't 1938 # know the full shape statically, we will have to get the shape 1939 # dynamically from the forward inference. Getting the shape right 1940 # for the zeros is only needed for the base case when the loop exits 1941 # without running any iterations. 1942 shape = grad.get_shape() 1943 if shape.is_fully_defined(): 1944 if self.outer_context: 1945 self.outer_context.Enter() 1946 acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc") 1947 if self.outer_context: 1948 self.outer_context.Exit() 1949 else: 1950 value = op.inputs[0] 1951 if (isinstance(self.outer_context, WhileContext) and 1952 self.outer_context.grad_state is not None): 1953 # We are in a nested while loop. 1954 forward_ctxt = self.grad_state.forward_context 1955 forward_ctxt.outer_context.Enter() 1956 zeros_shape = array_ops.shape_internal(value, optimize=False) 1957 forward_ctxt.outer_context.Exit() 1958 outer_grad_state = self.grad_state.outer_grad_state 1959 history_zeros_shape = outer_grad_state.AddForwardAccumulator( 1960 zeros_shape) 1961 self.outer_context.Enter() 1962 real_shape = outer_grad_state.AddBackpropAccumulatedValue( 1963 history_zeros_shape, zeros_shape) 1964 acc = array_ops.zeros(real_shape, grad.dtype) 1965 self.outer_context.Exit() 1966 else: 1967 if self.outer_context: 1968 self.outer_context.Enter() 1969 zeros_shape = array_ops.shape_internal(value, optimize=False) 1970 acc = array_ops.zeros(zeros_shape, grad.dtype) 1971 if self.outer_context: 1972 self.outer_context.Exit() 1973 1974 self.Enter() 1975 self.AddName(acc.name) 1976 enter_acc = _Enter( 1977 acc, 1978 self._name, 1979 is_constant=False, 1980 parallel_iterations=self._parallel_iterations, 1981 name="b_acc") 1982 self.loop_enters.append(enter_acc) 1983 1984 merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0] 1985 switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot) 1986 1987 add_acc = math_ops.add(switch_acc_true, grad) 1988 next_acc = _NextIteration(add_acc) 1989 merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access 1990 1991 result_acc = exit(switch_acc_false, name="b_acc") 1992 self.loop_exits.append(result_acc) 1993 self.ExitResult([result_acc]) 1994 return result_acc 1995 1996 def AddBackpropIndexedSlicesAccumulator(self, op, grad): 1997 """This is used for accumulating gradients that are IndexedSlices. 1998 1999 This is essentially the equivalent of AddBackpropAccumulator but optimized 2000 for things like updating embeddings from within a while loop. 2001 2002 Args: 2003 op: The Enter op for a loop invariant. 2004 grad: The partial gradients represented as an IndexedSlices. 2005 2006 Returns: 2007 The accumulated IndexedSlices gradient of the loop invariant. 2008 """ 2009 values = grad.values 2010 indices = grad.indices 2011 dense_shape = grad.dense_shape 2012 2013 self.Exit() 2014 if self.outer_context: 2015 self.outer_context.Enter() 2016 if values.get_shape().is_fully_defined(): 2017 values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] + 2018 values.get_shape().dims[1:]) 2019 if self.outer_context: 2020 self.outer_context.Enter() 2021 values_acc = constant_op.constant( 2022 0, values.dtype, shape=values_shape, name="b_acc") 2023 if self.outer_context: 2024 self.outer_context.Exit() 2025 else: 2026 values_shape = _resource_safe_shape(op.inputs[0])[1:] 2027 values_shape = array_ops.concat([[1], values_shape], 0) 2028 values_acc = array_ops.zeros(values_shape, dtype=values.dtype) 2029 indices_acc = constant_op.constant([0], indices.dtype) 2030 shape_acc = None 2031 if dense_shape is not None: 2032 if dense_shape.get_shape().is_fully_defined(): 2033 if self.outer_context: 2034 self.outer_context.Enter() 2035 shape_acc = constant_op.constant( 2036 0, dense_shape.dtype, shape=dense_shape.get_shape()) 2037 if self.outer_context: 2038 self.outer_context.Exit() 2039 else: 2040 shape_acc = array_ops.zeros_like( 2041 array_ops.shape_internal( 2042 op.inputs[0], optimize=False, out_type=dense_shape.dtype), 2043 optimize=False) 2044 2045 if self.outer_context: 2046 self.outer_context.Exit() 2047 2048 self.Enter() 2049 self.AddName(values_acc.name) 2050 self.AddName(indices_acc.name) 2051 init_acc = [indices_acc, values_acc] 2052 if shape_acc is not None: 2053 self.AddName(shape_acc.name) 2054 init_acc.append(shape_acc) 2055 2056 # Set use_input_shape=False since the accumulator tensors will grow in 2057 # size. If use_input_shape=True, the _update_input call below will result in 2058 # incompatible shapes. 2059 enter_acc = [ 2060 _Enter( 2061 x, 2062 self._name, 2063 is_constant=False, 2064 parallel_iterations=self._parallel_iterations, 2065 use_input_shape=False, 2066 name="b_acc") for x in init_acc 2067 ] 2068 # Manually set appropriate partial shapes. 2069 enter_acc[0].set_shape([None]) 2070 if values_acc.shape.dims is not None: 2071 enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:]) 2072 self.loop_enters.extend(enter_acc) 2073 2074 merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc] 2075 switch_acc = [switch(x, self._pivot) for x in merge_acc] 2076 2077 # The actual accumulation. 2078 acc_indexed_slices = [ 2079 array_ops.concat([xa[1], xv], 0) 2080 for xa, xv in zip(switch_acc[:2], [indices, values]) 2081 ] 2082 if shape_acc is not None: 2083 # For the shape we just keep the maximum 2084 acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1])) 2085 2086 next_acc = [_NextIteration(x) for x in acc_indexed_slices] 2087 for xm, xn in zip(merge_acc, next_acc): 2088 xm.op._update_input(1, xn) # pylint: disable=protected-access 2089 2090 exit_acc = [exit(x[0], name="b_acc") for x in switch_acc] 2091 self.loop_exits.extend(exit_acc) 2092 2093 self.ExitResult(exit_acc) 2094 return indexed_slices.IndexedSlices( 2095 indices=exit_acc[0], 2096 values=exit_acc[1], 2097 dense_shape=exit_acc[2] if shape_acc is not None else None) 2098 2099 def _InitializeValues(self, values): 2100 """Makes the values known to this context.""" 2101 self._values = set() 2102 for x in values: 2103 if isinstance(x, ops.Tensor): 2104 self._values.add(x.name) 2105 else: 2106 raise TypeError("'values' must be a list of Tensors. " 2107 f"Received: {type(x)}.") 2108 2109 def _BuildLoop(self, pred, body, flat_orig_loop_vars, flat_loop_vars, 2110 loop_vars_signature): 2111 """Core: Add the loop termination condition and body to the graph.""" 2112 flat_shape_invariants = nest.map_structure( 2113 lambda spec: spec.shape, 2114 nest.flatten(loop_vars_signature, expand_composites=True)) 2115 2116 # Let the context know the loop variables so the loop variables 2117 # would be added in the outer contexts properly. 2118 self._InitializeValues(flat_loop_vars) 2119 if self._outer_context: 2120 real_vars = [self._outer_context.AddValue(x) for x in flat_loop_vars] 2121 else: 2122 real_vars = flat_loop_vars 2123 2124 enter_vars = [] 2125 with ops.control_dependencies(None): 2126 for real_var, shape_invariant in zip(real_vars, flat_shape_invariants): 2127 enter_var = _Enter( 2128 real_var, 2129 self._name, 2130 is_constant=False, 2131 parallel_iterations=self._parallel_iterations, 2132 use_input_shape=False) 2133 2134 if _ShapeLessThanOrEqual(real_var.get_shape(), shape_invariant): 2135 enter_var.set_shape(shape_invariant) 2136 else: 2137 raise ValueError( 2138 f"The shape invariant specified for {real_var.name} is not " 2139 "compatible with the initial shape of the loop variable. It " 2140 f"enters the loop with shape {real_var.get_shape()}, but the " 2141 f"specified shape invariant is {shape_invariant}.") 2142 2143 enter_var.graph.prevent_feeding(enter_var) 2144 if self._outer_context: 2145 self._outer_context.AddInnerOp(enter_var.op) 2146 enter_vars.append(enter_var) 2147 2148 # Finds the closest enclosing non-None control pivot. 2149 outer_context = self._outer_context 2150 control_pivot = None 2151 while outer_context is not None and control_pivot is None: 2152 control_pivot = outer_context.GetControlPivot() 2153 # pylint: disable=protected-access 2154 outer_context = outer_context._outer_context 2155 # pylint: enable=protected-access 2156 2157 if control_pivot is not None: 2158 for var in enter_vars: 2159 if util.IsLoopConstantEnter(var.op.inputs[0].op): 2160 # pylint: disable=protected-access 2161 var.op._add_control_input(control_pivot.op) 2162 # pylint: enable=protected-access 2163 2164 # Fix the control inputs and control flow context of these enter ops. 2165 self._FixControlInputsAndContext(enter_vars) 2166 self._InitializeValues(enter_vars) 2167 self._loop_enters = enter_vars 2168 2169 merge_vars = [merge([x, x])[0] for x in enter_vars] 2170 self._pivot_for_pred = merge_vars[0] 2171 2172 merge_vars_with_tensorarrays = nest.map_structure( 2173 _convert_flow_to_tensorarray, flat_orig_loop_vars, merge_vars) 2174 # Build the graph for pred. 2175 packed_vars = nest.pack_sequence_as( 2176 structure=loop_vars_signature, 2177 flat_sequence=merge_vars_with_tensorarrays, 2178 expand_composites=True) 2179 c = ops.convert_to_tensor(pred(*packed_vars)) 2180 self._pivot = loop_cond(c, name="LoopCond") 2181 switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars] 2182 2183 # Build the graph for body. 2184 vars_for_body = [_Identity(x[1]) for x in switch_vars] 2185 self._pivot_for_body = vars_for_body[0] 2186 # Convert TensorArray flow variables inside the context back into 2187 # their associated TensorArrays for calling the body. 2188 vars_for_body_with_tensorarrays = nest.map_structure( 2189 _convert_flow_to_tensorarray, flat_orig_loop_vars, vars_for_body) 2190 packed_vars_for_body = nest.pack_sequence_as( 2191 structure=loop_vars_signature, 2192 flat_sequence=vars_for_body_with_tensorarrays, 2193 expand_composites=True) 2194 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2195 body_result = body(*packed_vars_for_body) 2196 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2197 if not nest.is_nested(body_result): 2198 body_result = [body_result] 2199 if len(post_summaries) > len(pre_summaries): 2200 new_summaries = post_summaries[len(pre_summaries):] 2201 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2202 summary_ref[:] = pre_summaries 2203 with ops.control_dependencies(new_summaries): 2204 2205 def map_fn(x): 2206 # TODO(apassos) figure out how to trigger with tensor arrays as well 2207 if isinstance(x, tensor_array_ops.TensorArray): 2208 return x 2209 return array_ops.identity(x) 2210 2211 body_result = nest.map_structure( 2212 map_fn, body_result, expand_composites=True) 2213 2214 body_result = variable_utils.convert_variables_to_tensors(body_result) 2215 # Compare the structure types of input and output of body. 2216 # For backwards compatibility, the first layer is forced to a list 2217 # during this comparison, because inputs are typically lists and 2218 # outputs of the body are typically tuples. 2219 nest.assert_same_structure( 2220 list(packed_vars_for_body), list(body_result), expand_composites=True) 2221 2222 # Store body_result to keep track of TensorArrays returned by body 2223 original_body_result = body_result 2224 # Convert TensorArrays returned by body into their flow variables 2225 result = nest.map_structure( 2226 _convert_tensorarray_to_flow, 2227 nest.flatten(body_result, expand_composites=True), 2228 expand_composites=True) 2229 result = ops.convert_n_to_tensor_or_composite(result) 2230 2231 # Add NextIteration and the back edges to complete the loop. 2232 if len(merge_vars) != len(result): 2233 raise ValueError("Number of inputs and outputs of 'body' must match " 2234 f"'loop_vars'. Got {len(merge_vars)} for the number of " 2235 f"inputs/outputs, and {len(result)} for 'loop_vars'.") 2236 next_vars = [] 2237 for m, v in zip(merge_vars, result): 2238 next_vars.append(_AddNextAndBackEdge(m, v)) 2239 2240 # Add the exit ops. 2241 exit_vars = [exit(x[0]) for x in switch_vars] 2242 self._loop_exits = exit_vars 2243 2244 # Exit the loop. 2245 self.ExitResult(exit_vars) 2246 2247 return original_body_result, exit_vars 2248 2249 def BuildLoop(self, pred, body, loop_vars, shape_invariants, 2250 return_same_structure): 2251 """Add the loop termination condition and body to the graph.""" 2252 2253 # Keep flat_orig_loop_vars to identify which are TensorArrays 2254 flat_orig_loop_vars = nest.flatten(loop_vars, expand_composites=True) 2255 2256 loop_vars = nest.map_structure( 2257 _convert_to_tensor_or_composite_or_tensorarray, loop_vars) 2258 # Convert TensorArrays to their flow variables 2259 flat_loop_vars = nest.map_structure( 2260 _convert_tensorarray_to_flow, 2261 nest.flatten(loop_vars, expand_composites=True)) 2262 2263 if shape_invariants is not None: 2264 loop_vars_signature = nest.map_structure( 2265 _shape_invariant_to_type_spec, loop_vars, shape_invariants) 2266 else: 2267 loop_vars_signature = nest.map_structure( 2268 _shape_invariant_to_type_spec, loop_vars) 2269 2270 try: 2271 self.Enter() 2272 # _BuildLoop calls _update_input in several places. _mutation_lock() 2273 # ensures a Session.run call cannot occur between creating and mutating 2274 # new ops. 2275 with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access 2276 original_body_result, exit_vars = self._BuildLoop( 2277 pred, body, flat_orig_loop_vars, flat_loop_vars, 2278 loop_vars_signature) 2279 finally: 2280 self.Exit() 2281 2282 flat_result = nest.flatten(original_body_result, expand_composites=True) 2283 # Convert TensorArray flow variables outside the context back into 2284 # their associated TensorArrays for returning to caller. 2285 exit_vars_with_tensorarrays = nest.map_structure( 2286 _convert_flow_to_tensorarray, flat_result, exit_vars) 2287 2288 packed_exit_vars = nest.pack_sequence_as( 2289 structure=original_body_result, 2290 flat_sequence=exit_vars_with_tensorarrays, 2291 expand_composites=True) 2292 2293 if return_same_structure: 2294 return packed_exit_vars 2295 else: 2296 return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars 2297 2298 def _FixControlInputsAndContext(self, enters): 2299 graph = ops.get_default_graph() 2300 # pylint: disable=protected-access 2301 for e in enters: 2302 if isinstance(e, ops.Tensor): 2303 xs = [e] 2304 else: 2305 raise TypeError("'enters' must be a list of Tensors. " 2306 f"Received: {type(e)}.") 2307 for x in xs: 2308 inp_op = x.op.inputs[0].op 2309 control_inputs = graph._control_dependencies_for_inputs([inp_op]) 2310 outer_control_inputs = [] 2311 for op in control_inputs: 2312 # We need to keep control inputs that are in any ancestor 2313 # ControlFlowContext, and within outer WhileContext. 2314 keep_as_control_input = True 2315 op_ctxt = util.GetOutputContext(op) 2316 outer_ctxt = self.outer_context 2317 outer_while_context = (None if outer_ctxt is None else 2318 outer_ctxt.GetWhileContext()) 2319 while outer_ctxt != op_ctxt: 2320 if outer_ctxt is None or outer_ctxt == outer_while_context: 2321 keep_as_control_input = False 2322 break 2323 outer_ctxt = outer_ctxt.outer_context 2324 if keep_as_control_input: 2325 outer_control_inputs.append(op) 2326 x.op._set_control_flow_context(self) 2327 x.op._add_control_inputs(outer_control_inputs) 2328 graph._record_op_seen_by_control_dependencies(x.op) 2329 # pylint: enable=protected-access 2330 2331 def IsWhileContext(self): 2332 return True 2333 2334 2335# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature". 2336# pylint: disable=redefined-outer-name 2337@tf_export("while_loop", v1=[]) 2338@deprecation.deprecated_arg_values( 2339 None, 2340 """back_prop=False is deprecated. Consider using tf.stop_gradient instead. 2341Instead of: 2342results = tf.while_loop(c, b, vars, back_prop=False) 2343Use: 2344results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""", 2345 warn_once=True, 2346 back_prop=False) 2347def while_loop_v2(cond, 2348 body, 2349 loop_vars, 2350 shape_invariants=None, 2351 parallel_iterations=10, 2352 back_prop=True, 2353 swap_memory=False, 2354 maximum_iterations=None, 2355 name=None): 2356 """Repeat `body` while the condition `cond` is true. 2357 2358 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 2359 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 2360 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 2361 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 2362 `cond` and `body`. `cond` and `body` both take as many arguments as there are 2363 `loop_vars`. 2364 2365 In addition to regular Tensors or IndexedSlices, the body may accept and 2366 return TensorArray objects. The flows of the TensorArray objects will 2367 be appropriately forwarded between loops and during gradient calculations. 2368 2369 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 2370 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 2371 stitches together the graph fragments created during the `cond` and `body` 2372 calls with some additional graph nodes to create the graph flow that 2373 repeats `body` until `cond` returns false. 2374 2375 For correctness, `tf.while_loop()` strictly enforces shape invariants for 2376 the loop variables. A shape invariant is a (possibly partial) shape that 2377 is unchanged across the iterations of the loop. An error will be raised 2378 if the shape of a loop variable after an iteration is determined to be more 2379 general than or incompatible with its shape invariant. For example, a shape 2380 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 2381 compatible with [11, 17]. By default (if the argument `shape_invariants` is 2382 not specified), it is assumed that the initial shape of each tensor in 2383 `loop_vars` is the same in every iteration. The `shape_invariants` argument 2384 allows the caller to specify a less specific shape invariant for each loop 2385 variable, which is needed if the shape varies between iterations. The 2386 `tf.Tensor.set_shape` 2387 function may also be used in the `body` function to indicate that 2388 the output loop variable has a particular shape. The shape invariant for 2389 SparseTensor and IndexedSlices are treated specially as follows: 2390 2391 a) If a loop variable is a SparseTensor, the shape invariant must be 2392 TensorShape([r]) where r is the rank of the dense tensor represented 2393 by the sparse tensor. It means the shapes of the three tensors of the 2394 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 2395 is the shape of the SparseTensor.dense_shape property. It must be the shape of 2396 a vector. 2397 2398 b) If a loop variable is an IndexedSlices, the shape invariant must be 2399 a shape invariant of the values tensor of the IndexedSlices. It means 2400 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 2401 [shape.ndims]). 2402 2403 `while_loop` implements non-strict semantics, enabling multiple iterations 2404 to run in parallel. The maximum number of parallel iterations can be 2405 controlled by `parallel_iterations`, which gives users some control over 2406 memory consumption and execution order. For correct programs, `while_loop` 2407 should return the same result for any parallel_iterations > 0. 2408 2409 For training, TensorFlow stores the tensors that are produced in the 2410 forward inference and are needed in back propagation. These tensors are a 2411 main source of memory consumption and often cause OOM errors when training 2412 on GPUs. When the flag swap_memory is true, we swap out these tensors from 2413 GPU to CPU. This for example allows us to train RNN models with very long 2414 sequences and large batches. 2415 2416 Args: 2417 cond: A callable that represents the termination condition of the loop. 2418 body: A callable that represents the loop body. 2419 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 2420 `Tensor`, and `TensorArray` objects. 2421 shape_invariants: The shape invariants for the loop variables. 2422 parallel_iterations: The number of iterations allowed to run in parallel. It 2423 must be a positive integer. 2424 back_prop: (optional) Deprecated. False disables support for back 2425 propagation. Prefer using `tf.stop_gradient` instead. 2426 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2427 maximum_iterations: Optional maximum number of iterations of the while loop 2428 to run. If provided, the `cond` output is AND-ed with an additional 2429 condition ensuring the number of iterations executed is no greater than 2430 `maximum_iterations`. 2431 name: Optional name prefix for the returned tensors. 2432 2433 Returns: 2434 The output tensors for the loop variables after the loop. The return value 2435 has the same structure as `loop_vars`. 2436 2437 Raises: 2438 TypeError: if `cond` or `body` is not callable. 2439 ValueError: if `loop_vars` is empty. 2440 2441 Example: 2442 2443 ```python 2444 i = tf.constant(0) 2445 c = lambda i: tf.less(i, 10) 2446 b = lambda i: (tf.add(i, 1), ) 2447 r = tf.while_loop(c, b, [i]) 2448 ``` 2449 2450 Example with nesting and a namedtuple: 2451 2452 ```python 2453 import collections 2454 Pair = collections.namedtuple('Pair', 'j, k') 2455 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 2456 c = lambda i, p: i < 10 2457 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 2458 ijk_final = tf.while_loop(c, b, ijk_0) 2459 ``` 2460 2461 Example using shape_invariants: 2462 2463 ```python 2464 i0 = tf.constant(0) 2465 m0 = tf.ones([2, 2]) 2466 c = lambda i, m: i < 10 2467 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 2468 tf.while_loop( 2469 c, b, loop_vars=[i0, m0], 2470 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 2471 ``` 2472 2473 Example which demonstrates non-strict semantics: In the following 2474 example, the final value of the counter `i` does not depend on `x`. So 2475 the `while_loop` can increment the counter parallel to updates of `x`. 2476 However, because the loop counter at one loop iteration depends 2477 on the value at the previous iteration, the loop counter itself cannot 2478 be incremented in parallel. Hence if we just want the final value of the 2479 counter (which we print on the line `print(sess.run(i))`), then 2480 `x` will never be incremented, but the counter will be updated on a 2481 single thread. Conversely, if we want the value of the output (which we 2482 print on the line `print(sess.run(out).shape)`), then the counter may be 2483 incremented on its own thread, while `x` can be incremented in 2484 parallel on a separate thread. In the extreme case, it is conceivable 2485 that the thread incrementing the counter runs until completion before 2486 `x` is incremented even a single time. The only thing that can never 2487 happen is that the thread updating `x` can never get ahead of the 2488 counter thread because the thread incrementing `x` depends on the value 2489 of the counter. 2490 2491 ```python 2492 import tensorflow as tf 2493 2494 n = 10000 2495 x = tf.constant(list(range(n))) 2496 c = lambda i, x: i < n 2497 b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1, 2498 [i], "x:")) 2499 i, out = tf.while_loop(c, b, (0, x)) 2500 with tf.compat.v1.Session() as sess: 2501 print(sess.run(i)) # prints [0] ... [9999] 2502 2503 # The following line may increment the counter and x in parallel. 2504 # The counter thread may get ahead of the other thread, but not the 2505 # other way around. So you may see things like 2506 # [9996] x:[9987] 2507 # meaning that the counter thread is on iteration 9996, 2508 # while the other thread is on iteration 9987 2509 print(sess.run(out).shape) 2510 ``` 2511 2512 """ 2513 return while_loop( 2514 cond=cond, 2515 body=body, 2516 loop_vars=loop_vars, 2517 shape_invariants=shape_invariants, 2518 parallel_iterations=parallel_iterations, 2519 back_prop=back_prop, 2520 swap_memory=swap_memory, 2521 name=name, 2522 maximum_iterations=maximum_iterations, 2523 return_same_structure=True) 2524 2525 2526# pylint: disable=redefined-outer-name 2527@tf_export(v1=["while_loop"]) 2528def while_loop(cond, 2529 body, 2530 loop_vars, 2531 shape_invariants=None, 2532 parallel_iterations=10, 2533 back_prop=True, 2534 swap_memory=False, 2535 name=None, 2536 maximum_iterations=None, 2537 return_same_structure=False): 2538 """Repeat `body` while the condition `cond` is true. 2539 2540 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 2541 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 2542 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 2543 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 2544 `cond` and `body`. `cond` and `body` both take as many arguments as there are 2545 `loop_vars`. 2546 2547 In addition to regular Tensors or IndexedSlices, the body may accept and 2548 return TensorArray objects. The flows of the TensorArray objects will 2549 be appropriately forwarded between loops and during gradient calculations. 2550 2551 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 2552 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 2553 stitches together the graph fragments created during the `cond` and `body` 2554 calls with some additional graph nodes to create the graph flow that 2555 repeats `body` until `cond` returns false. 2556 2557 For correctness, `tf.while_loop()` strictly enforces shape invariants for 2558 the loop variables. A shape invariant is a (possibly partial) shape that 2559 is unchanged across the iterations of the loop. An error will be raised 2560 if the shape of a loop variable after an iteration is determined to be more 2561 general than or incompatible with its shape invariant. For example, a shape 2562 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 2563 compatible with [11, 17]. By default (if the argument `shape_invariants` is 2564 not specified), it is assumed that the initial shape of each tensor in 2565 `loop_vars` is the same in every iteration. The `shape_invariants` argument 2566 allows the caller to specify a less specific shape invariant for each loop 2567 variable, which is needed if the shape varies between iterations. The 2568 `tf.Tensor.set_shape` 2569 function may also be used in the `body` function to indicate that 2570 the output loop variable has a particular shape. The shape invariant for 2571 SparseTensor and IndexedSlices are treated specially as follows: 2572 2573 a) If a loop variable is a SparseTensor, the shape invariant must be 2574 TensorShape([r]) where r is the rank of the dense tensor represented 2575 by the sparse tensor. It means the shapes of the three tensors of the 2576 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 2577 is the shape of the SparseTensor.dense_shape property. It must be the shape of 2578 a vector. 2579 2580 b) If a loop variable is an IndexedSlices, the shape invariant must be 2581 a shape invariant of the values tensor of the IndexedSlices. It means 2582 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 2583 [shape.ndims]). 2584 2585 `while_loop` implements non-strict semantics, enabling multiple iterations 2586 to run in parallel. The maximum number of parallel iterations can be 2587 controlled by `parallel_iterations`, which gives users some control over 2588 memory consumption and execution order. For correct programs, `while_loop` 2589 should return the same result for any parallel_iterations > 0. 2590 2591 For training, TensorFlow stores the tensors that are produced in the 2592 forward inference and are needed in back propagation. These tensors are a 2593 main source of memory consumption and often cause OOM errors when training 2594 on GPUs. When the flag swap_memory is true, we swap out these tensors from 2595 GPU to CPU. This for example allows us to train RNN models with very long 2596 sequences and large batches. 2597 2598 Args: 2599 cond: A callable that represents the termination condition of the loop. 2600 body: A callable that represents the loop body. 2601 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 2602 `Tensor`, and `TensorArray` objects. 2603 shape_invariants: The shape invariants for the loop variables. 2604 parallel_iterations: The number of iterations allowed to run in parallel. It 2605 must be a positive integer. 2606 back_prop: Whether backprop is enabled for this while loop. 2607 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 2608 name: Optional name prefix for the returned tensors. 2609 maximum_iterations: Optional maximum number of iterations of the while loop 2610 to run. If provided, the `cond` output is AND-ed with an additional 2611 condition ensuring the number of iterations executed is no greater than 2612 `maximum_iterations`. 2613 return_same_structure: If True, output has same structure as `loop_vars`. If 2614 eager execution is enabled, this is ignored (and always treated as True). 2615 2616 Returns: 2617 The output tensors for the loop variables after the loop. 2618 If `return_same_structure` is True, the return value has the same 2619 structure as `loop_vars`. 2620 If `return_same_structure` is False, the return value is a Tensor, 2621 TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list 2622 otherwise. 2623 2624 Raises: 2625 TypeError: if `cond` or `body` is not callable. 2626 ValueError: if `loop_vars` is empty. 2627 2628 Example: 2629 2630 ```python 2631 i = tf.constant(0) 2632 c = lambda i: tf.less(i, 10) 2633 b = lambda i: tf.add(i, 1) 2634 r = tf.while_loop(c, b, [i]) 2635 ``` 2636 2637 Example with nesting and a namedtuple: 2638 2639 ```python 2640 import collections 2641 Pair = collections.namedtuple('Pair', 'j, k') 2642 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 2643 c = lambda i, p: i < 10 2644 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 2645 ijk_final = tf.while_loop(c, b, ijk_0) 2646 ``` 2647 2648 Example using shape_invariants: 2649 2650 ```python 2651 i0 = tf.constant(0) 2652 m0 = tf.ones([2, 2]) 2653 c = lambda i, m: i < 10 2654 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 2655 tf.while_loop( 2656 c, b, loop_vars=[i0, m0], 2657 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 2658 ``` 2659 2660 Example which demonstrates non-strict semantics: In the following 2661 example, the final value of the counter `i` does not depend on `x`. So 2662 the `while_loop` can increment the counter parallel to updates of `x`. 2663 However, because the loop counter at one loop iteration depends 2664 on the value at the previous iteration, the loop counter itself cannot 2665 be incremented in parallel. Hence if we just want the final value of the 2666 counter (which we print on the line `print(sess.run(i))`), then 2667 `x` will never be incremented, but the counter will be updated on a 2668 single thread. Conversely, if we want the value of the output (which we 2669 print on the line `print(sess.run(out).shape)`), then the counter may be 2670 incremented on its own thread, while `x` can be incremented in 2671 parallel on a separate thread. In the extreme case, it is conceivable 2672 that the thread incrementing the counter runs until completion before 2673 `x` is incremented even a single time. The only thing that can never 2674 happen is that the thread updating `x` can never get ahead of the 2675 counter thread because the thread incrementing `x` depends on the value 2676 of the counter. 2677 2678 ```python 2679 import tensorflow as tf 2680 2681 n = 10000 2682 x = tf.constant(list(range(n))) 2683 c = lambda i, x: i < n 2684 b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1, 2685 [i], "x:")) 2686 i, out = tf.while_loop(c, b, (0, x)) 2687 with tf.compat.v1.Session() as sess: 2688 print(sess.run(i)) # prints [0] ... [9999] 2689 2690 # The following line may increment the counter and x in parallel. 2691 # The counter thread may get ahead of the other thread, but not the 2692 # other way around. So you may see things like 2693 # [9996] x:[9987] 2694 # meaning that the counter thread is on iteration 9996, 2695 # while the other thread is on iteration 9987 2696 print(sess.run(out).shape) 2697 ``` 2698 2699 """ 2700 if not callable(cond): 2701 raise TypeError("'cond' must be callable.") 2702 if not callable(body): 2703 raise TypeError("'body' must be callable.") 2704 if parallel_iterations < 1: 2705 raise TypeError("'parallel_iterations' must be a positive integer.") 2706 2707 loop_vars = variable_utils.convert_variables_to_tensors(loop_vars) 2708 2709 # Always enable control flow v2 if building a function, regardless of toggle. 2710 executing_eagerly = context.executing_eagerly() 2711 if (util.EnableControlFlowV2(ops.get_default_graph()) and 2712 not executing_eagerly): 2713 return while_v2.while_loop( 2714 cond, 2715 body, 2716 loop_vars, 2717 shape_invariants=shape_invariants, 2718 parallel_iterations=parallel_iterations, 2719 maximum_iterations=maximum_iterations, 2720 name=name, 2721 return_same_structure=return_same_structure, 2722 back_prop=back_prop) 2723 2724 with ops.name_scope(name, "while", loop_vars): 2725 if not loop_vars: 2726 raise ValueError("'loop_vars' must be provided.") 2727 try_to_pack = (len(loop_vars) == 1 and not return_same_structure) 2728 if maximum_iterations is not None: 2729 maximum_iterations = ops.convert_to_tensor( 2730 maximum_iterations, name="maximum_iterations") 2731 if maximum_iterations.shape.ndims != 0: 2732 raise ValueError( 2733 "'maximum_iterations' must be a scalar. " 2734 f"Received shape: {maximum_iterations.shape}") 2735 2736 if executing_eagerly: 2737 counter = 0 2738 maximum_iterations = int(maximum_iterations.numpy()) 2739 else: 2740 counter = constant_op.constant( 2741 0, dtype=maximum_iterations.dtype, name="iteration_counter") 2742 orig_cond = cond 2743 orig_body = body 2744 if try_to_pack: 2745 loop_vars = (counter, loop_vars[0]) 2746 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 2747 math_ops.logical_and(i < maximum_iterations, orig_cond(lv))) 2748 body = lambda i, lv: (i + 1, orig_body(lv)) 2749 else: 2750 loop_vars = (counter, loop_vars) 2751 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 2752 math_ops.logical_and(i < maximum_iterations, orig_cond(*lv))) 2753 body = lambda i, lv: (i + 1, orig_body(*lv)) 2754 try_to_pack = False 2755 2756 if executing_eagerly: 2757 packed = False # whether the body result was packed into a 1-item tuple 2758 2759 loop_var_structure = nest.map_structure(type_spec.type_spec_from_value, 2760 list(loop_vars)) 2761 while cond(*loop_vars): 2762 loop_vars = body(*loop_vars) 2763 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)): 2764 packed = True 2765 loop_vars = (loop_vars,) 2766 nest.assert_same_structure(loop_var_structure, list(loop_vars)) 2767 2768 def convert(x): 2769 if isinstance(x, tensor_array_ops.TensorArray): 2770 return x 2771 return ops.convert_to_tensor(x) 2772 2773 loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True) 2774 if maximum_iterations is not None: 2775 return loop_vars[1] 2776 else: 2777 return loop_vars[0] if packed else loop_vars 2778 2779 if shape_invariants is not None: 2780 if maximum_iterations is not None: 2781 shape_invariants = (tensor_shape.TensorShape([]), shape_invariants) 2782 2783 loop_context = WhileContext( 2784 maximum_iterations=maximum_iterations, 2785 parallel_iterations=parallel_iterations, 2786 back_prop=back_prop, 2787 swap_memory=swap_memory) 2788 # Only add non-nested loops to the collection. Any nested control flow will 2789 # be encapsulated in the root context. 2790 if loop_context.outer_context is None: 2791 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) 2792 result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, 2793 return_same_structure) 2794 if maximum_iterations is not None: 2795 return result[1] 2796 else: 2797 return result 2798 2799 2800# pylint: enable=redefined-outer-name 2801 2802 2803def _AsTensorList(x, p): 2804 """Return x as a list of Tensors or IndexedSlices. 2805 2806 For entries of `x` that are Operations, this returns an Identity of `p` 2807 with a dependency on the operation. 2808 2809 Args: 2810 x: A Tensor/IndexedSlices/Operation or a list or tuple of them. 2811 p: A Tensor to return for entries in `x` that are Operations. 2812 2813 Returns: 2814 A list of Tensors or IndexedSlices. 2815 """ 2816 if not isinstance(x, (list, _basetuple)): 2817 x = [x] 2818 2819 l = [] 2820 for v in x: 2821 if isinstance(v, ops.Operation): 2822 v = with_dependencies([v], p) 2823 v = ops.convert_to_tensor_or_composite(v) 2824 if isinstance(v, ops.Tensor): 2825 l.append(array_ops.identity(v)) 2826 else: 2827 l.append( 2828 indexed_slices.IndexedSlices( 2829 array_ops.identity(v.values), array_ops.identity(v.indices))) 2830 return l 2831 2832 2833def _CheckResults(a, b): 2834 assert len(a) == len(b), ( 2835 "Values returned by a() and b() must have the same length.") 2836 for x, y in zip(a, b): 2837 assert x.dtype == y.dtype, ( 2838 "Values returned by a() [%s] and b() [%s] must have " 2839 "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name)) 2840 2841 2842def with_dependencies(dependencies, output_tensor, name=None): 2843 """Produces the content of `output_tensor` only after `dependencies`. 2844 2845 In some cases, a user may want the output of an operation to be 2846 consumed externally only after some other dependencies have run 2847 first. This function ensures returns `output_tensor`, but only after all 2848 operations in `dependencies` have run. Note that this means that there is 2849 no guarantee that `output_tensor` will be evaluated after any `dependencies` 2850 have run. 2851 2852 See also `tf.tuple` and `tf.group`. 2853 2854 Args: 2855 dependencies: Iterable of operations to run before this op finishes. 2856 output_tensor: A `Tensor` or `IndexedSlices` that will be returned. 2857 name: (Optional) A name for this operation. 2858 2859 Returns: 2860 Same as `output_tensor`. 2861 2862 Raises: 2863 TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`. 2864 """ 2865 if context.executing_eagerly(): 2866 return output_tensor 2867 with ops.name_scope(name, "control_dependency", 2868 list(dependencies) + [output_tensor]) as name: 2869 with ops.colocate_with(output_tensor): 2870 with ops.control_dependencies(dependencies): 2871 output_tensor = ops.convert_to_tensor_or_composite(output_tensor) 2872 if isinstance(output_tensor, indexed_slices.IndexedSlices): 2873 return indexed_slices.IndexedSlices( 2874 _Identity(output_tensor.values, name=name), output_tensor.indices, 2875 output_tensor.dense_shape) 2876 else: 2877 return _Identity(output_tensor, name=name) 2878 2879 2880def _GroupControlDeps(dev, deps, name=None): 2881 with ops.control_dependencies(deps): 2882 if dev is None: 2883 return no_op(name=name) 2884 else: 2885 with ops.device(dev): 2886 return no_op(name=name) 2887 2888 2889# TODO(touts): Accept "inputs" as a list. 2890@tf_export("group") 2891def group(*inputs, **kwargs): 2892 """Create an op that groups multiple operations. 2893 2894 When this op finishes, all ops in `inputs` have finished. This op has no 2895 output. 2896 2897 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require 2898 this method, as ops execute in the expected order thanks to automatic control 2899 dependencies.* Only use `tf.group` when working with v1 2900 `tf.Graph` code. 2901 2902 When operating in a v1-style graph context, ops are not executed in the same 2903 order as specified in the code; TensorFlow will attempt to execute ops in 2904 parallel or in an order convenient to the result it is computing. `tf.group` 2905 allows you to request that one or more results finish before execution 2906 continues. 2907 2908 `tf.group` creates a single op (of type `NoOp`), and then adds appropriate 2909 control dependencies. Thus, `c = tf.group(a, b)` will compute the same graph 2910 as this: 2911 2912 with tf.control_dependencies([a, b]): 2913 c = tf.no_op() 2914 2915 See also `tf.tuple` and 2916 `tf.control_dependencies`. 2917 2918 Args: 2919 *inputs: Zero or more tensors to group. 2920 name: A name for this operation (optional). 2921 2922 Returns: 2923 An Operation that executes all its inputs. 2924 2925 Raises: 2926 ValueError: If an unknown keyword argument is provided. 2927 """ 2928 if context.executing_eagerly(): 2929 return None 2930 name = kwargs.pop("name", None) 2931 if kwargs: 2932 raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys())) 2933 with ops.name_scope(name, "group_deps", inputs) as name: 2934 # Grouping no inputs means do nothing 2935 if not inputs: 2936 return no_op(name=name) 2937 2938 # Sorts *inputs according to their devices. 2939 ops_on_device = {} # device -> operations specified on the device. 2940 for inp in nest.flatten(inputs, expand_composites=True): 2941 if not hasattr(inp, "device"): 2942 raise TypeError("'inputs' should be zero or more (nested) Tensors. " 2943 f"Received '{inp}' with type '{type(inp)}'.") 2944 dev = inp.device 2945 if dev in ops_on_device: 2946 ops_on_device[dev].append(inp) 2947 else: 2948 ops_on_device[dev] = [inp] 2949 if len(ops_on_device) == 1: 2950 # 1-level tree. The root node is the returned NoOp node. 2951 (dev, deps), = ops_on_device.items() 2952 return _GroupControlDeps(dev, deps, name=name) 2953 2954 # 2-level tree. The root node is the returned NoOp node. 2955 # deps contains 1 NoOp node for each device. 2956 deps = [] 2957 2958 def device_key(dev): 2959 """A sort key that allows None to be compared to strings.""" 2960 return "" if dev is None else dev 2961 2962 for dev in sorted(ops_on_device, key=device_key): 2963 deps.append(_GroupControlDeps(dev, ops_on_device[dev])) 2964 2965 with ops.control_dependencies(deps): 2966 return no_op(name=name) 2967 2968 2969@tf_export("tuple", v1=[]) 2970@dispatch.add_dispatch_support 2971def tuple_v2(tensors, control_inputs=None, name=None): 2972 """Groups tensors together. 2973 2974 The returned tensors have the same value as the input tensors, but they 2975 are computed only after all the input tensors have been computed. 2976 2977 Note: *In TensorFlow 2 with eager and/or Autograph, you should not require 2978 this method, as ops execute in the expected order thanks to automatic control 2979 dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code. 2980 2981 See also `tf.group` and `tf.control_dependencies`. 2982 2983 Args: 2984 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. 2985 control_inputs: List of additional ops to finish before returning. 2986 name: (optional) A name to use as a `name_scope` for the operation. 2987 2988 Returns: 2989 Same as `tensors`. 2990 2991 Raises: 2992 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 2993 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 2994 objects. 2995 2996 """ 2997 return tuple(tensors=tensors, name=name, control_inputs=control_inputs) # pylint: disable=redefined-builtin 2998 2999 3000@tf_export(v1=["tuple"]) 3001@dispatch.add_dispatch_support 3002def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin 3003 """Group tensors together. 3004 3005 This creates a tuple of tensors with the same values as the `tensors` 3006 argument, except that the value of each tensor is only returned after the 3007 values of all tensors have been computed. 3008 3009 `control_inputs` contains additional ops that have to finish before this op 3010 finishes, but whose outputs are not returned. 3011 3012 This can be used as a "join" mechanism for parallel computations: all the 3013 argument tensors can be computed in parallel, but the values of any tensor 3014 returned by `tuple` are only available after all the parallel computations 3015 are done. 3016 3017 See also `tf.group` and 3018 `tf.control_dependencies`. 3019 3020 Args: 3021 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. 3022 name: (optional) A name to use as a `name_scope` for the operation. 3023 control_inputs: List of additional ops to finish before returning. 3024 3025 Returns: 3026 Same as `tensors`. 3027 3028 Raises: 3029 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 3030 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 3031 objects. 3032 3033 """ 3034 if context.executing_eagerly(): 3035 return tensors 3036 with ops.name_scope(name, "tuple", tensors) as name: 3037 tensors = [ 3038 t if (isinstance(t, ops.Operation) or tensor_util.is_tf_type(t) or 3039 t is None) else ops.convert_to_tensor(t) for t in tensors 3040 ] 3041 gating_ops = [ 3042 t if isinstance(t, ops.Operation) else t.op 3043 for t in tensors 3044 if t is not None 3045 ] 3046 if control_inputs: 3047 for c in control_inputs: 3048 if isinstance(c, ops.Tensor): 3049 c = c.op 3050 elif not isinstance(c, ops.Operation): 3051 raise TypeError( 3052 "'control_inputs' must only contain Operation or Tensor. " 3053 f"Received: {type(c)}") 3054 gating_ops.append(c) 3055 # Note that in order to ensure ordering in the pbtxt, we must take care to 3056 # ensure the order here. 3057 gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops. 3058 if not gating_ops: 3059 raise ValueError("'tensors' must have at least one Tensor. " 3060 f"Received: {tensors}.") 3061 gate = group(*gating_ops) 3062 tpl = [] 3063 for t in tensors: 3064 if tensor_util.is_tf_type(t): 3065 tpl.append(with_dependencies([gate], t)) 3066 elif isinstance(t, ops.Operation): 3067 with ops.control_dependencies([gate]): 3068 tpl.append(group(t)) 3069 else: 3070 tpl.append(None) 3071 return tpl 3072 3073 3074def _assert_at_most_n_true(predicates, n, msg): 3075 """Returns an Assert op that checks that at most n predicates are True. 3076 3077 Args: 3078 predicates: list of bool scalar tensors. 3079 n: maximum number of true predicates allowed. 3080 msg: Error message. 3081 """ 3082 preds_c = array_ops.stack(predicates, name="preds_c") 3083 num_true_conditions = math_ops.reduce_sum( 3084 math_ops.cast(preds_c, dtypes.int32), name="num_true_conds") 3085 condition = math_ops.less_equal(num_true_conditions, 3086 constant_op.constant(n, name="n_true_conds")) 3087 preds_names = ", ".join(getattr(p, "name", "?") for p in predicates) 3088 error_msg = [ 3089 "%s: more than %d conditions (%s) evaluated as True:" % 3090 (msg, n, preds_names), preds_c 3091 ] 3092 return Assert(condition, data=error_msg, summarize=len(predicates)) 3093 3094 3095def _case_create_default_action(predicates, actions): 3096 """Creates default action for a list of actions and their predicates. 3097 3098 It uses the input actions to select an arbitrary as default and makes sure 3099 that corresponding predicates have valid values. 3100 3101 Args: 3102 predicates: a list of bool scalar tensors 3103 actions: a list of callable objects which return tensors. 3104 3105 Returns: 3106 a callable 3107 """ 3108 k = len(predicates) - 1 # could pick any 3109 predicate, action = predicates[k], actions[k] 3110 other_predicates, other_actions = predicates[:k], actions[:k] 3111 3112 def default_action(): 3113 others_msg = ("Implementation error: " 3114 "selected default action #%d was called, but some of other " 3115 "predicates are True: " % k) 3116 default_msg = ("Input error: " 3117 "None of conditions evaluated as True:", 3118 array_ops.stack(predicates, name="preds_c")) 3119 with ops.control_dependencies([ 3120 _assert_at_most_n_true(other_predicates, n=0, msg=others_msg), 3121 Assert(predicate, data=default_msg) 3122 ]): 3123 return action() 3124 3125 return default_action, other_predicates, other_actions 3126 3127 3128def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name, 3129 allow_python_preds): 3130 """Verifies input arguments for the case function. 3131 3132 Args: 3133 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a 3134 callable which returns a list of tensors. 3135 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3136 name: A name for the case operation. 3137 allow_python_preds: if true, pred_fn_pairs may contain Python bools in 3138 addition to boolean Tensors 3139 3140 Raises: 3141 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3142 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3143 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3144 callable. 3145 3146 Returns: 3147 a tuple <list of scalar bool tensors, list of callables>. 3148 """ 3149 if not isinstance(pred_fn_pairs, (list, _basetuple, dict)): 3150 raise TypeError("'pred_fn_pairs' must be a list, tuple, or dict. " 3151 f"Received: {type(pred_fn_pairs)}") 3152 3153 if isinstance(pred_fn_pairs, collections.OrderedDict): 3154 pred_fn_pairs = pred_fn_pairs.items() 3155 elif isinstance(pred_fn_pairs, dict): 3156 if context.executing_eagerly(): 3157 # No name to sort on in eager mode. Use dictionary traversal order, 3158 # which is nondeterministic in versions of Python < 3.6 3159 if not exclusive: 3160 raise ValueError("Unordered dictionaries are not supported for the " 3161 "'pred_fn_pairs' argument when `exclusive=False` and " 3162 "eager mode is enabled.") 3163 pred_fn_pairs = list(pred_fn_pairs.items()) 3164 else: 3165 pred_fn_pairs = sorted( 3166 pred_fn_pairs.items(), key=lambda item: item[0].name) 3167 if not exclusive: 3168 logging.warn( 3169 "%s: An unordered dictionary of predicate/fn pairs was " 3170 "provided, but exclusive=False. The order of conditional " 3171 "tests is deterministic but not guaranteed.", name) 3172 for pred_fn_pair in pred_fn_pairs: 3173 if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2: 3174 raise TypeError("Each entry in 'pred_fn_pairs' must be a 2-tuple. " 3175 f"Received {pred_fn_pair}.") 3176 pred, fn = pred_fn_pair 3177 3178 if isinstance(pred, ops.Tensor): 3179 if pred.dtype != dtypes.bool: 3180 raise TypeError("pred must be Tensor of type bool: %s" % pred.name) 3181 elif not allow_python_preds: 3182 raise TypeError("pred must be a Tensor, got: %s" % pred) 3183 elif not isinstance(pred, bool): 3184 raise TypeError("pred must be a Tensor or bool, got: %s" % pred) 3185 3186 if not callable(fn): 3187 raise TypeError("fn for pred %s must be callable." % pred.name) 3188 3189 predicates, actions = zip(*pred_fn_pairs) 3190 return predicates, actions 3191 3192 3193def _case_helper(cond_fn, 3194 pred_fn_pairs, 3195 default, 3196 exclusive, 3197 name, 3198 allow_python_preds=False, 3199 **cond_kwargs): 3200 """Implementation of case that allows for different cond functions. 3201 3202 Args: 3203 cond_fn: method that has signature and semantics of `cond` above. 3204 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a 3205 callable which returns a list of tensors. 3206 default: Optional callable that returns a list of tensors. 3207 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3208 name: A name for this operation (optional). 3209 allow_python_preds: if true, pred_fn_pairs may contain Python bools in 3210 addition to boolean Tensors 3211 **cond_kwargs: keyword arguments that will be passed to `cond_fn`. 3212 3213 Returns: 3214 The tensors returned by the first pair whose predicate evaluated to True, or 3215 those returned by `default` if none does. 3216 3217 Raises: 3218 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3219 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3220 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3221 callable. 3222 """ 3223 predicates, actions = _case_verify_and_canonicalize_args( 3224 pred_fn_pairs, exclusive, name, allow_python_preds) 3225 with ops.name_scope(name, "case", [predicates]): 3226 if default is None: 3227 default, predicates, actions = _case_create_default_action( 3228 predicates, actions) 3229 fn = default 3230 # To eval conditions in direct order we create nested conditions in reverse: 3231 # cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...)) 3232 for predicate, action in reversed(list(zip(predicates, actions))): 3233 fn = functools.partial( 3234 cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs) 3235 if exclusive: 3236 with ops.control_dependencies([ 3237 _assert_at_most_n_true( 3238 predicates, n=1, msg="Input error: exclusive=True") 3239 ]): 3240 return fn() 3241 else: 3242 return fn() 3243 3244 3245def _indexed_case_verify_and_canonicalize_args(branch_fns, default, 3246 branch_index): 3247 """Verifies input arguments for the case function. 3248 3249 Args: 3250 branch_fns: Dict or list of pairs of an `int` and a callable which 3251 returns a list of tensors. 3252 default: Optional callable that returns a list of tensors. 3253 branch_index: Optional int `Tensor`, which selects for the corresponding 3254 pred_fn_pair. 3255 3256 Raises: 3257 TypeError: If `branch_fns` is not a list/dictionary. 3258 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3259 callables. 3260 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3261 callable. 3262 3263 Returns: 3264 branch_fns: validated list of callables for each branch (default last). 3265 """ 3266 if not isinstance(branch_index, ops.Tensor): 3267 raise TypeError("'branch_index' must be a Tensor, got {}".format( 3268 type(branch_index))) 3269 if not branch_index.dtype.is_integer: 3270 raise TypeError("'branch_index' must be an integer Tensor, got {}".format( 3271 branch_index.dtype)) 3272 3273 if not branch_fns: 3274 raise ValueError("Must provide at least one item in 'branch_fns'") 3275 if not isinstance(branch_fns, (list, _basetuple, dict)): 3276 raise TypeError("'branch_fns' must be a list, tuple, or dict") 3277 3278 if isinstance(branch_fns, dict): 3279 branch_fns = branch_fns.items() 3280 3281 if all(callable(fn) for fn in branch_fns): 3282 branch_fns = list(enumerate(branch_fns)) 3283 3284 for key_fn_pair in branch_fns: 3285 if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2: 3286 raise TypeError("Each entry in 'branch_fns' must be a 2-tuple. " 3287 f"Received {key_fn_pair}.") 3288 key, branch_fn = key_fn_pair 3289 3290 if not isinstance(key, int): 3291 raise TypeError("key must be a Python `int`, got {}".format(type(key))) 3292 3293 if not callable(branch_fn): 3294 raise TypeError("fn for key {} must be callable.".format(key)) 3295 3296 keys = [p[0] for p in branch_fns] 3297 if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys): 3298 raise ValueError( 3299 "branch indices (keys) must form contiguous range of [0 to {}) but " 3300 "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys))))) 3301 actions = [p[1] for p in sorted(branch_fns)] 3302 if default is not None: 3303 actions.append(default) 3304 return actions 3305 3306 3307def _indexed_case_helper(branch_fns, 3308 default, 3309 branch_index, 3310 name, 3311 lower_using_switch_merge=None): 3312 """Implementation of case that emits the n-way indexed Case op. 3313 3314 Args: 3315 branch_fns: Dict or list of pairs of a boolean scalar tensor, and a 3316 callable which returns a list of tensors. 3317 default: Optional callable that returns a list of tensors. 3318 branch_index: Optional int `Tensor`, which selects for the corresponding 3319 pred_fn_pair. 3320 name: A name for this operation (optional). 3321 lower_using_switch_merge: Lower this op using switch merge ops (optional). 3322 3323 Returns: 3324 The tensors returned by the pair whose key matched branch_index, or 3325 those returned by `default` if none does. 3326 3327 Raises: 3328 TypeError: If `branch_fns` is not a list/dictionary. 3329 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3330 callables. 3331 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3332 callable. 3333 """ 3334 branch_fns = _indexed_case_verify_and_canonicalize_args( 3335 branch_fns, default, branch_index) 3336 with ops.name_scope(name, "case", [branch_index]): 3337 if context.executing_eagerly() and not hasattr(branch_index, "graph"): 3338 branch_index = array_ops.where( 3339 math_ops.less(branch_index, 0) 3340 | math_ops.greater_equal(branch_index, len(branch_fns)), 3341 len(branch_fns) - 1, branch_index) 3342 return branch_fns[int(branch_index)]() 3343 return cond_v2.indexed_case( 3344 branch_index, 3345 branch_fns, 3346 lower_using_switch_merge=lower_using_switch_merge) 3347 3348 3349@tf_export("case", v1=[]) 3350@dispatch.add_dispatch_support 3351def case_v2(pred_fn_pairs, 3352 default=None, 3353 exclusive=False, 3354 strict=False, 3355 name="case"): 3356 """Create a case operation. 3357 3358 See also `tf.switch_case`. 3359 3360 The `pred_fn_pairs` parameter is a list of pairs of size N. 3361 Each pair contains a boolean scalar tensor and a python callable that 3362 creates the tensors to be returned if the boolean evaluates to True. 3363 `default` is a callable generating a list of tensors. All the callables 3364 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3365 number and types of tensors. 3366 3367 If `exclusive==True`, all predicates are evaluated, and an exception is 3368 thrown if more than one of the predicates evaluates to `True`. 3369 If `exclusive==False`, execution stops at the first predicate which 3370 evaluates to True, and the tensors generated by the corresponding function 3371 are returned immediately. If none of the predicates evaluate to True, this 3372 operation returns the tensors generated by `default`. 3373 3374 `tf.case` supports nested structures as implemented in 3375 `tf.nest`. All of the callables must return the same (possibly nested) value 3376 structure of lists, tuples, and/or named tuples. Singleton lists and tuples 3377 form the only exceptions to this: when returned by a callable, they are 3378 implicitly unpacked to single values. This behavior is disabled by passing 3379 `strict=True`. 3380 3381 @compatibility(v2) 3382 `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and 3383 tf.Variable are no longer hashable in v2, so cannot be used as a key for a 3384 dictionary. Please use a list or a tuple instead. 3385 @end_compatibility 3386 3387 3388 **Example 1:** 3389 3390 Pseudocode: 3391 3392 ``` 3393 if (x < y) return 17; 3394 else return 23; 3395 ``` 3396 3397 Expressions: 3398 3399 ```python 3400 f1 = lambda: tf.constant(17) 3401 f2 = lambda: tf.constant(23) 3402 r = tf.case([(tf.less(x, y), f1)], default=f2) 3403 ``` 3404 3405 **Example 2:** 3406 3407 Pseudocode: 3408 3409 ``` 3410 if (x < y && x > z) raise OpError("Only one predicate may evaluate to True"); 3411 if (x < y) return 17; 3412 else if (x > z) return 23; 3413 else return -1; 3414 ``` 3415 3416 Expressions: 3417 3418 ```python 3419 def f1(): return tf.constant(17) 3420 def f2(): return tf.constant(23) 3421 def f3(): return tf.constant(-1) 3422 r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)], 3423 default=f3, exclusive=True) 3424 ``` 3425 3426 Args: 3427 pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which 3428 returns a list of tensors. 3429 default: Optional callable that returns a list of tensors. 3430 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3431 strict: A boolean that enables/disables 'strict' mode; see above. 3432 name: A name for this operation (optional). 3433 3434 Returns: 3435 The tensors returned by the first pair whose predicate evaluated to True, or 3436 those returned by `default` if none does. 3437 3438 Raises: 3439 TypeError: If `pred_fn_pairs` is not a list/tuple. 3440 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3441 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3442 callable. 3443 """ 3444 return _case_helper( 3445 cond, 3446 pred_fn_pairs, 3447 default, 3448 exclusive, 3449 name, 3450 allow_python_preds=False, 3451 strict=strict) 3452 3453 3454@tf_export(v1=["case"]) 3455@dispatch.add_dispatch_support 3456def case(pred_fn_pairs, 3457 default=None, 3458 exclusive=False, 3459 strict=False, 3460 name="case"): 3461 """Create a case operation. 3462 3463 See also `tf.switch_case`. 3464 3465 The `pred_fn_pairs` parameter is a dict or list of pairs of size N. 3466 Each pair contains a boolean scalar tensor and a python callable that 3467 creates the tensors to be returned if the boolean evaluates to True. 3468 `default` is a callable generating a list of tensors. All the callables 3469 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3470 number and types of tensors. 3471 3472 If `exclusive==True`, all predicates are evaluated, and an exception is 3473 thrown if more than one of the predicates evaluates to `True`. 3474 If `exclusive==False`, execution stops at the first predicate which 3475 evaluates to True, and the tensors generated by the corresponding function 3476 are returned immediately. If none of the predicates evaluate to True, this 3477 operation returns the tensors generated by `default`. 3478 3479 `tf.case` supports nested structures as implemented in 3480 `tf.nest`. All of the callables must return the same (possibly nested) value 3481 structure of lists, tuples, and/or named tuples. Singleton lists and tuples 3482 form the only exceptions to this: when returned by a callable, they are 3483 implicitly unpacked to single values. This behavior is disabled by passing 3484 `strict=True`. 3485 3486 If an unordered dictionary is used for `pred_fn_pairs`, the order of the 3487 conditional tests is not guaranteed. However, the order is guaranteed to be 3488 deterministic, so that variables created in conditional branches are created 3489 in fixed order across runs. 3490 3491 @compatibility(eager) 3492 Unordered dictionaries are not supported in eager mode when `exclusive=False`. 3493 Use a list of tuples instead. 3494 @end_compatibility 3495 3496 3497 **Example 1:** 3498 3499 Pseudocode: 3500 3501 ``` 3502 if (x < y) return 17; 3503 else return 23; 3504 ``` 3505 3506 Expressions: 3507 3508 ```python 3509 f1 = lambda: tf.constant(17) 3510 f2 = lambda: tf.constant(23) 3511 r = tf.case([(tf.less(x, y), f1)], default=f2) 3512 ``` 3513 3514 **Example 2:** 3515 3516 Pseudocode: 3517 3518 ``` 3519 if (x < y && x > z) raise OpError("Only one predicate may evaluate to True"); 3520 if (x < y) return 17; 3521 else if (x > z) return 23; 3522 else return -1; 3523 ``` 3524 3525 Expressions: 3526 3527 ```python 3528 def f1(): return tf.constant(17) 3529 def f2(): return tf.constant(23) 3530 def f3(): return tf.constant(-1) 3531 r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2}, 3532 default=f3, exclusive=True) 3533 ``` 3534 3535 Args: 3536 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a 3537 callable which returns a list of tensors. 3538 default: Optional callable that returns a list of tensors. 3539 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3540 strict: A boolean that enables/disables 'strict' mode; see above. 3541 name: A name for this operation (optional). 3542 3543 Returns: 3544 The tensors returned by the first pair whose predicate evaluated to True, or 3545 those returned by `default` if none does. 3546 3547 Raises: 3548 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3549 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3550 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3551 callable. 3552 """ 3553 return _case_helper( 3554 cond, 3555 pred_fn_pairs, 3556 default, 3557 exclusive, 3558 name, 3559 allow_python_preds=False, 3560 strict=strict) 3561 3562 3563@tf_export("switch_case") 3564def switch_case(branch_index, 3565 branch_fns, 3566 default=None, 3567 name="switch_case"): 3568 """Create a switch/case operation, i.e. an integer-indexed conditional. 3569 3570 See also `tf.case`. 3571 3572 This op can be substantially more efficient than `tf.case` when exactly one 3573 branch will be selected. `tf.switch_case` is more like a C++ switch/case 3574 statement than `tf.case`, which is more like an if/elif/elif/else chain. 3575 3576 The `branch_fns` parameter is either a dict from `int` to callables, or list 3577 of (`int`, callable) pairs, or simply a list of callables (in which case the 3578 index is implicitly the key). The `branch_index` `Tensor` is used to select an 3579 element in `branch_fns` with matching `int` key, falling back to `default` 3580 if none match, or `max(keys)` if no `default` is provided. The keys must form 3581 a contiguous set from `0` to `len(branch_fns) - 1`. 3582 3583 `tf.switch_case` supports nested structures as implemented in `tf.nest`. All 3584 callables must return the same (possibly nested) value structure of lists, 3585 tuples, and/or named tuples. 3586 3587 **Example:** 3588 3589 Pseudocode: 3590 3591 ```c++ 3592 switch (branch_index) { // c-style switch 3593 case 0: return 17; 3594 case 1: return 31; 3595 default: return -1; 3596 } 3597 ``` 3598 or 3599 ```python 3600 branches = {0: lambda: 17, 1: lambda: 31} 3601 branches.get(branch_index, lambda: -1)() 3602 ``` 3603 3604 Expressions: 3605 3606 ```python 3607 def f1(): return tf.constant(17) 3608 def f2(): return tf.constant(31) 3609 def f3(): return tf.constant(-1) 3610 r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3) 3611 # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3}) 3612 ``` 3613 3614 Args: 3615 branch_index: An int Tensor specifying which of `branch_fns` should be 3616 executed. 3617 branch_fns: A `dict` mapping `int`s to callables, or a `list` of 3618 (`int`, callable) pairs, or simply a list of callables (in which case the 3619 index serves as the key). Each callable must return a matching structure 3620 of tensors. 3621 default: Optional callable that returns a structure of tensors. 3622 name: A name for this operation (optional). 3623 3624 Returns: 3625 The tensors returned by the callable identified by `branch_index`, or those 3626 returned by `default` if no key matches and `default` was provided, or those 3627 returned by the max-keyed `branch_fn` if no `default` is provided. 3628 3629 Raises: 3630 TypeError: If `branch_fns` is not a list/dictionary. 3631 TypeError: If `branch_fns` is a list but does not contain 2-tuples or 3632 callables. 3633 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3634 callable. 3635 """ 3636 return _indexed_case_helper(branch_fns, default, branch_index, name) 3637 3638 3639@tf_export("__internal__.execute_fn_for_device", v1=[]) 3640def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"): 3641 """Executes one of the provided callables based on the device placement. 3642 3643 This API is used when the implementations for high level function depend on 3644 the underlying device placement. It takes a dictionary of device type to 3645 callables. The device type includes "CPU", "GPU", "TPU", etc. When the type of 3646 the device where to run this op matches the key in 'device_branch_fns', 3647 the corresponding callable is executed, falling back to 'default_fn' if none 3648 matches. 3649 3650 **Example:** 3651 ```python 3652 def f1(): return tf.constant(1) 3653 def f2(): return tf.constant(2) 3654 r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1) 3655 ``` 3656 'r' is evaluated as 1 when it runs on CPU, 2 running on GPU, 1 running on 3657 any other device types. 3658 3659 3660 Args: 3661 device_branch_fns: a dictionary of device types to the callables. Each 3662 callable must return a matching structure of tensors. 3663 default_fn: fallback callable when the underlying device does not match any 3664 key in the 'device_branch_fns'. 3665 name: A name for this operation (optional). 3666 3667 Returns: 3668 The tensors returned by the callable identified by device type during 3669 execution, or those returned by 'default_fn' if no key matches. 3670 """ 3671 # Always execute the default fn for XLA to avoid complicated graph by case op. 3672 # see more discussions in b/167276293. 3673 is_in_xla = util.GraphOrParentsInXlaContext(ops.get_default_graph()) 3674 if is_in_xla: 3675 return default_fn() 3676 device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()} 3677 branch_fns = list(device_branch_fns_upper.values()) 3678 devices = list(device_branch_fns_upper.keys()) 3679 device_index = gen_functional_ops.device_index(device_names=devices) 3680 return _indexed_case_helper( 3681 branch_fns, 3682 default_fn, 3683 device_index, 3684 name, 3685 lower_using_switch_merge=False) 3686 3687 3688class XLAControlFlowContext(ControlFlowContext): 3689 """Base class for XLA and TPU control flow contexts.""" 3690 3691 def __init__(self): 3692 super(XLAControlFlowContext, self).__init__() 3693 self._name = "XLAControlFlowContext" 3694 3695 def to_control_flow_context_def(self, context_def, export_scope=None): 3696 # pylint: disable=useless-super-delegation 3697 # NOTE(slebedev): the method is required by `ControlFlowContext`. 3698 super(XLAControlFlowContext, 3699 self).to_control_flow_context_def(context_def, export_scope) 3700 3701 def IsXLAContext(self): 3702 return True 3703 3704 def AddOp(self, _): 3705 pass 3706 3707 def AddValue(self, x): 3708 return x 3709 3710 def RequiresUniqueFunctionRetracing(self): 3711 """Returns whether the tf.function should be retraced if the context changes. 3712 """ 3713 return False 3714 3715 3716@tf_export("__internal__.get_enclosing_xla_context", v1=[]) 3717def get_enclosing_xla_context(): 3718 """Recursively find and return the XLAControlFlowContext.""" 3719 graph = ops.get_default_graph() 3720 while graph is not None: 3721 # pylint: disable=protected-access 3722 context_ = graph._get_control_flow_context() 3723 # pylint: enable=protected-access 3724 while context_ is not None: 3725 if isinstance(context_, XLAControlFlowContext): 3726 return context_ 3727 context_ = context_.outer_context 3728 # This may be a FuncGraph due to defuns or v2 control flow. We need to 3729 # find the original graph with the XLAControlFlowContext. 3730 graph = getattr(graph, "outer_graph", None) 3731 return None 3732 3733 3734def from_control_flow_context_def(context_def, import_scope=None): 3735 """Deserializes `context_def` into the appropriate ControlFlowContext. 3736 3737 Args: 3738 context_def: ControlFlowContextDef proto 3739 import_scope: Optional `string`. Name scope to add. 3740 3741 Returns: 3742 A ControlFlowContext subclass 3743 """ 3744 if context_def.HasField("cond_ctxt"): 3745 return CondContext.from_proto( 3746 context_def.cond_ctxt, import_scope=import_scope) 3747 if context_def.HasField("while_ctxt"): 3748 return WhileContext.from_proto( 3749 context_def.while_ctxt, import_scope=import_scope) 3750 raise NotImplementedError("Unknown ControlFlowContextDef field: %s" % 3751 context_def.WhichOneof("ctxt")) 3752 3753 3754ops.register_proto_function( 3755 ops.GraphKeys.COND_CONTEXT, 3756 proto_type=control_flow_pb2.CondContextDef, 3757 to_proto=CondContext.to_proto, 3758 from_proto=CondContext.from_proto) 3759 3760ops.register_proto_function( 3761 ops.GraphKeys.WHILE_CONTEXT, 3762 proto_type=control_flow_pb2.WhileContextDef, 3763 to_proto=WhileContext.to_proto, 3764 from_proto=WhileContext.from_proto) 3765