# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control Flow Operations.

See the [autograph](https://www.tensorflow.org/guide/autographs) guide.
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import functools

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# This is to avoid a circular dependency:
# cond_v2 -> gradients_util -> control_flow_ops
cond_v2 = LazyLoader("cond_v2", globals(),
                     "tensorflow.python.ops.cond_v2")

# This is to avoid circular dependencies:
# while_v2 -> control_flow_ops
# while_v2 -> gradients_util -> control_flow_ops
while_v2 = LazyLoader("while_v2", globals(),
                      "tensorflow.python.ops.while_v2")

# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple


def _summarize_eager(tensor, summarize=None):
  """Returns a summarized string representation of eager `tensor`.

  Args:
    tensor: EagerTensor to summarize.
    summarize: Include this many of the first elements of `tensor`.
  """
  # Emulate the behavior of Tensor::SummarizeValue()
  if summarize is None:
    summarize = 3
  elif summarize < 0:
    summarize = array_ops.size(tensor)
  # reshape((-1,)) is the fastest way to get a flat array view
  if tensor._rank():  # pylint: disable=protected-access
    flat = tensor.numpy().reshape((-1,))
    lst = [str(x) for x in flat[:summarize]]
    if len(lst) < flat.size:
      lst.append("...")
  else:
    # tensor.numpy() returns a scalar for zero-dimensional arrays
    if summarize != 0:
      lst = [str(tensor.numpy())]
    else:
      lst = []

  return ", ".join(lst)


# pylint: disable=protected-access


# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  NOTE: In graph mode, to ensure that Assert executes, one usually attaches
  a dependency:

  ```python
  # Ensure maximum element of x is smaller or equal to 1
  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  with tf.control_dependencies([assert_op]):
    ... code using x ...
  ```

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    @compatibility{eager} returns None.

  Raises:
    @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
    is not true
  """
  if context.executing_eagerly():
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return

  with ops.name_scope(name, "Assert", [condition, data]) as name:
    xs = ops.convert_n_to_tensor(data)
    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
      # As a simple heuristic, we assume that string and int32 tensors are
      # on host memory to avoid the need to use cond. If that is not the
      # case, we will pay the price of copying the tensor to host memory.
      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
    else:
      condition = ops.convert_to_tensor(condition, name="Condition")

      def true_assert():
        return gen_logging_ops._assert(
            condition, data, summarize, name="Assert")

      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
      if context.executing_eagerly():
        return
      return guarded_assert.op
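

# Illustrative sketch (hypothetical values, eager mode assumed): when
# executing eagerly, `Assert` raises immediately rather than returning an op:
#
#   x = constant_op.constant([2.0, 3.0])
#   Assert(math_ops.reduce_all(math_ops.less(x, 1.0)), [x], summarize=2)
#   # -> raises InvalidArgumentError("Expected ... to be true.
#   #    Summarized data: 2.0, 3.0")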


def _Identity(data, name=None):
  """Return a tensor with the same shape and contents as the input tensor.

  Args:
    data: A Tensor.
    name: A name for this operation (optional).

  Returns:
    A Tensor with the same type and value as the input Tensor.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return gen_array_ops.ref_identity(data, name=name)
    else:
      return array_ops.identity(data, name=name)
  elif isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(_Identity, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))


def _NextIteration(data, name=None):
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return ref_next_iteration(data, name=name)
    else:
      return next_iteration(data, name=name)
  elif isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(_NextIteration, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))


def _Enter(data,
           frame_name,
           is_constant=False,
           parallel_iterations=10,
           use_ref=True,
           use_input_shape=True,
           name=None):
  """Creates or finds a child frame, and makes `data` available to it.

  The unique `frame_name` is used by the `Executor` to identify frames. If
  `is_constant` is true, `data` is a constant in the child frame; otherwise
  it may be changed in the child frame. At most `parallel_iterations`
  iterations are run in parallel in the child frame.

  Args:
    data: The tensor to be made available to the child frame.
    frame_name: The name of the child frame.
    is_constant: If true, the output is constant within the child frame.
    parallel_iterations: The number of iterations allowed to run in parallel.
    use_ref: If true, use ref_enter if data is of ref type.
    use_input_shape: If true, set the result's shape based on data's shape.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
      result = gen_control_flow_ops.ref_enter(
          data, frame_name, is_constant, parallel_iterations, name=name)
    else:
      result = gen_control_flow_ops.enter(
          data, frame_name, is_constant, parallel_iterations, name=name)
    if use_input_shape:
      result.set_shape(data.get_shape())
    return result
  elif isinstance(data, composite_tensor.CompositeTensor):
    def enter_component(t):
      return _Enter(t, frame_name, is_constant, parallel_iterations,
                    use_ref, use_input_shape)
    return nest.map_structure(enter_component, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))


def exit(data, name=None):  # pylint: disable=redefined-builtin
  """Exits the current frame to its parent frame.

  Exit makes its input `data` available to the parent frame.

  Args:
    data: The tensor to be made available to the parent frame.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return gen_control_flow_ops.ref_exit(data, name)
    else:
      return gen_control_flow_ops._exit(data, name)
  elif isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(exit, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))
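

# Illustrative sketch (not executed here; hypothetical tensors `x`, `pred`):
# `switch` and `merge` below are the low-level primitives from which the
# graph-mode `cond` is assembled. A hand-rolled conditional looks roughly
# like:
#
#   x_false, x_true = switch(x, pred)
#   true_out = math_ops.square(x_true)      # runs only when pred is True
#   false_out = math_ops.negative(x_false)  # runs only when pred is False
#   result, chosen_index = merge([false_out, true_out])
#
# At runtime only the taken branch receives a live input, and `merge`
# forwards whichever branch output becomes available.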


def switch(data, pred, dtype=None, name=None):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and `IndexedSlices`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `data`.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded
    to `output_true`, otherwise it goes to `output_false`.
  """
  with ops.name_scope(name, "Switch", [data, pred]) as name:
    data = ops.internal_convert_to_tensor_or_composite(
        data, dtype=dtype, name="data", as_ref=True)
    pred = ops.convert_to_tensor(pred, name="pred")
    if isinstance(data, ops.Tensor):
      return gen_control_flow_ops.switch(data, pred, name=name)
    else:
      if not isinstance(data, composite_tensor.CompositeTensor):
        raise TypeError("Type %s not supported" % type(data))
      tensors = nest.flatten(data, expand_composites=True)
      mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors]
      mapped_f, mapped_t = zip(*mapped)
      return (nest.pack_sequence_as(data, mapped_f, expand_composites=True),
              nest.pack_sequence_as(data, mapped_t, expand_composites=True))


def _SwitchRefOrTensor(data, pred, name="Switch"):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and `IndexedSlices`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
    `output_true`, otherwise it goes to `output_false`.

  Raises:
    TypeError: if data is not a Tensor or IndexedSlices
  """
  data = ops.convert_to_tensor_or_composite(data, name="data")
  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
  # addresses the following scenario.
  #
  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
  #
  # 1. The update op is created inside a `with ops.colocate(var):` block
  #
  # 2. Some tensor `data` is captured and a switch is created in a
  #    `with ops.colocate_with(data):` block.
  #
  # with ops.colocate_with(var):
  #   with ops.colocate_with(data):
  #     op = ...
  #
  # var and data may be pinned to different devices, so we want the ops
  # created within ops.colocate_with(data) to ignore the existing stack.
  with ops.colocate_with(data, ignore_existing=True):
    if isinstance(data, ops.Tensor):
      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
        return ref_switch(data, pred, name=name)
    return switch(data, pred, name=name)


def merge(inputs, name=None):
  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of
  them is available. If it finds an available tensor, it returns it and its
  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor
  in `inputs` is available, the returned tensor and index are not set.

  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
  before merging.

  Args:
    inputs: The input tensors, at most one of which is available.
    name: A name for this operation (optional).

  Returns:
    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:
    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
      some but not all have a dense_shape property.
  """
  if any(inp is None for inp in inputs):
    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
  with ops.name_scope(name, "Merge", inputs) as name:
    inputs = [
        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
        for inp in inputs
    ]
    if all(isinstance(v, ops.Tensor) for v in inputs):
      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
        return gen_control_flow_ops.ref_merge(inputs, name)
      else:
        return gen_control_flow_ops.merge(inputs, name)
    else:
      # If there is a mix of tensors and indexed slices, then convert the
      # tensors to indexed slices.
      if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs):
        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)

      for v in inputs:
        if not isinstance(v, composite_tensor.CompositeTensor):
          raise TypeError("Type %s not supported" % type(v))

      for v in inputs[1:]:
        nest.assert_same_structure(inputs[0], v, expand_composites=True)

      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
      merged_results = [gen_control_flow_ops.merge(component)
                        for component in zip(*flat_inputs)]
      flat_merged = [tensor for (tensor, _) in merged_results]
      chosen_index = merged_results[0][1]
      merged_inputs = nest.pack_sequence_as(inputs[0], flat_merged,
                                            expand_composites=True)
      return (merged_inputs, chosen_index)


# pylint: enable=protected-access


def _convert_tensorarray_to_flow(tensor_or_tensor_array):
  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
    return tensor_or_tensor_array.flow
  else:
    return tensor_or_tensor_array


def _make_tensor_array(ta, t_or_flow):
  # pylint: disable=protected-access
  new_ta = tensor_array_ops.TensorArray(
      dtype=ta.dtype,
      handle=ta.handle,
      flow=t_or_flow,
      infer_shape=ta._infer_shape,
      colocate_with_first_write_call=ta._colocate_with_first_write_call)
  new_ta._colocate_with = ta._colocate_with
  new_ta._element_shape = ta._element_shape
  # pylint: enable=protected-access
  return new_ta


def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
  if len(tensors_or_tensorarrays) != len(tensors_or_flows):
    raise ValueError(
        "Lengths of original Tensor list and new list do not match: %d vs. %d"
        % (len(tensors_or_tensorarrays), len(tensors_or_flows)))
  return [
      _make_tensor_array(ta, t_or_flow) if isinstance(
          ta, tensor_array_ops.TensorArray) else t_or_flow
      for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
  ]


def _ShapeLessThanOrEqual(shape1, shape2):
  if shape2.dims is None:
    return True
  if shape1.ndims != shape2.ndims:
    return False
  for dim1, dim2 in zip(shape1.dims, shape2.dims):
    if dim2.value is not None and dim1.value != dim2.value:
      return False
  return True


def _get_shape_invariant(var, shape=None):
  """Returns a shape invariant for the given variable.

  If `var` is a `CompositeTensor`, then this uses
  `_shape_invariant_to_components()` to get shape invariants for the
  component tensors.

  Args:
    var: The tensor whose shape is described.
    shape: The shape invariant for the tensor. If not specified, then a default
      shape invariant for `var` is returned.

  Returns:
    The shape invariant for `var` (if it is a `Tensor`), or the shape invariants
    for the components that comprise `var` (if it is a `CompositeTensor`).
  """
  if isinstance(var, composite_tensor.CompositeTensor):
    return var._shape_invariant_to_components(shape)  # pylint: disable=protected-access
  elif shape is None:
    return var.shape
  else:
    return shape
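

# Illustrative sketch (assumes the graph-mode `while_loop` defined later in
# this module): the shape-invariant helpers below back the `shape_invariants`
# argument of `tf.while_loop`. A loop variable whose shape changes across
# iterations must declare a less specific invariant, e.g.:
#
#   i0 = constant_op.constant(0)
#   x0 = constant_op.constant([1])
#   _, x = while_loop(
#       lambda i, x: i < 5,
#       lambda i, x: (i + 1, array_ops.concat([x, x], axis=0)),
#       [i0, x0],
#       shape_invariants=[i0.get_shape(), tensor_shape.TensorShape([None])])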


def _SetShapeInvariants(input_vars, enter_vars, shapes):
  """Set the shapes of the tensors in `enter_vars` to `shapes`.

  Args:
    input_vars: A list of tensors that are inputs to `enter_vars`.
    enter_vars: A list of tensors whose shapes will be set.
    shapes: A (possibly nested) list of shapes.

  Raises:
    ValueError: If any tensor in `enter_vars` has a less specific shape
      than its corresponding shape in `shapes`.
  """
  if shapes is None:
    return
  flat_shapes = nest.flatten(shapes)
  if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
  # Check that the shapes of the inputs are less than the shape invariants,
  # and set the shapes of `enter_vars` to the shape invariants.
  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
    if isinstance(var, ops.Tensor):
      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
        raise ValueError(
            "The shape invariant specified for %s is not compatible with "
            "the initial shape of the loop variable. It enters the loop "
            "with shape %s, but the specified shape invariant is %s." %
            (inp.name, inp.get_shape(), shape))
      var.set_shape(shape)
    else:
      raise TypeError("Type %s not supported" % type(var))


def _EnforceShapeInvariant(merge_var, next_var):
  """Check if the shapes of the loop variables are invariant.

  Args:
    merge_var: The list of tensors representing the initial values of the loop
      variables.
    next_var: The list of tensors representing the values of the loop variables
      after one loop iteration.

  Raises:
    ValueError: If any tensor in `merge_var` has a more specific shape than
      its corresponding tensor in `next_var`.
  """
  if isinstance(merge_var, ops.Tensor):
    m_shape = merge_var.get_shape()
    n_shape = next_var.get_shape()
    if not _ShapeLessThanOrEqual(n_shape, m_shape):
      enter = merge_var.op.inputs[0].op
      assert util.IsLoopEnter(enter)
      input_t = enter.inputs[0]
      raise ValueError(
          "Input tensor '%s' enters the loop with shape %s, but has shape %s "
          "after one iteration. To allow the shape to vary across iterations, "
          "use the `shape_invariants` argument of tf.while_loop to specify a "
          "less-specific shape." % (input_t.name, input_t.shape, n_shape))
  else:
    raise TypeError("Type %s not supported" % type(merge_var))


def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
  """Add NextIteration and back edge from v to m."""
  if isinstance(m, ops.Tensor):
    v = ops.convert_to_tensor(v)
    v = _NextIteration(v)
    if enforce_shape_invariant:
      # Make sure the shapes of loop outputs are correct. We do this before
      # calling _update_input, which will raise a less-helpful error message if
      # the types don't match.
      # TODO(skyewm): call this for other cases below (needs testing)
      _EnforceShapeInvariant(m, v)
    m.op._update_input(1, v)  # pylint: disable=protected-access
  elif isinstance(m, composite_tensor.CompositeTensor):
    # pylint: disable=protected-access
    def update_component(m_component, v_component):
      m_component.op._update_input(1, v_component)
    if isinstance(m, ops.IndexedSlices):
      v = math_ops._as_indexed_slices(v, optimize=False)
    # pylint: enable=protected-access
    v = _NextIteration(v)
    return nest.map_structure(update_component, m, v, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(m))
  return v
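

# Illustrative sketch (hypothetical loop, XLA compilation assumed): building
# gradients for a while_loop compiled with XLA requires a statically known
# stack size for the accumulators created below, which is why
# `maximum_iterations` must be supplied on the forward loop, e.g.:
#
#   r = while_loop(lambda i: i < n, lambda i: i + 1, [i0],
#                  maximum_iterations=10)
#
# Without it, GetMaxSizeFromNestedMaximumIterations raises a ValueError when
# tf.gradients() builds the backward pass.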


def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
  """Calculate a max_size for use by stack ops inside an XLA while_loop.

  Args:
    value: The value inside the while_loop forward context. Used for printing
      error messages.
    while_ctxt: The forward context inside which value resides. This does not
      always match the value's immediate context, as `value` may be inside e.g.
      a cond context inside the while_loop.

  Returns:
    A tensor containing the `max_size` to feed to a Stack initializer.

  Raises:
    ValueError: If `value` is nested inside a `while_loop` that either
      lacks a `maximum_iterations` parameter, or whose `maximum_iterations`
      parameter:

      - is inside a `while_loop` that is a parent of the calling context, and
      - cannot be evaluated at graph build time to a constant.
  """
  value_name = value.name
  # curr_ctxt is the context that tf.gradients was called in.
  curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access

  curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
  max_size = constant_op.constant(1)

  # Loop through all containing while contexts between value and the
  # current context, multiplying together each context's
  # max_iterations to get the maximum stack size.
  while while_ctxt not in (None, curr_ctxt):
    max_iter = while_ctxt.maximum_iterations
    if max_iter is None:
      raise ValueError(
          "Cannot create a gradient accumulator for tensor '%s' inside "
          "XLA while_loop because maximum_iterations was not passed to "
          "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))

    # pylint: disable=protected-access
    max_iter_ctxt = max_iter.op._get_control_flow_context()
    # pylint: enable=protected-access

    # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
    if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
      max_size *= max_iter
    else:
      # We cannot use max_iter because it's defined in a nested while
      # or cond context, so it will fail if we try to use it as input to
      # any ops in curr_ctxt (e.g. max_size or the final accumulator
      # stack). Attempt to get a constant value out to use instead.
      const_max_iter = tensor_util.constant_value(max_iter)
      if const_max_iter is None:
        raise ValueError(
            "Cannot create a gradient accumulator for tensor '%s' inside XLA "
            "while_loop. maximum_iterations tensor '%s' for while_loop context "
            "'%s' must be statically known (e.g. a constant value or known "
            "shape dimension), or be defined at or outside the while loop "
            "context '%s' (currently defined in '%s')." %
            (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
             max_iter_ctxt.name))
      max_size *= const_max_iter

    # Find the next outer WhileContext (or stop if we reach the
    # tf.gradient's context).
    while_ctxt = util.GetContainingWhileContext(
        while_ctxt.outer_context, stop_ctxt=curr_ctxt)

  return max_size


class GradLoopState(object):
  """The state used for constructing the gradient graph for a while loop.

  We create a GradLoopState for each while loop in forward and its
  corresponding while loop in backprop. This gives us access to both
  the forward and the backprop WhileContexts.

  During the construction of the gradient graph, whenever we detect
  a forward value that is needed for backprop, we create a history
  accumulator and add it to `history_map`. Whenever we backprop
  a loop switch op (in _SwitchGrad), we add the grad merge op in
  `switch_map`.
  """

  def __init__(self, forward_ctxt, outer_grad_state):
    # The grad loop state for the outer while loop.
    self._outer_grad_state = None

    # The while loop context for forward.
    self._forward_context = None

    # The loop counter added by AddForwardLoopCounter. It is the value
    # of the loop counter for the next iteration.
    self._forward_index = None

    # A sync op for forward.
    self._forward_sync = None

    # The while loop context for backprop.
    self._grad_context = None

    # The loop counter added by AddBackpropLoopCounter. It is the value
    # of the loop counter for the current iteration.
    self._grad_index = None

    # A sync op for backprop.
    self._grad_sync = None

    # Information needed by backprop.
    self._history_map = {}
    self._switch_map = {}
    self._unused_exits = []
    self._deferred_exits = []
    self._forward_loop_exits = list(forward_ctxt.loop_exits)
    self._pending_exits_count = len(forward_ctxt.loop_exits)

    self._outer_grad_state = outer_grad_state
    if outer_grad_state:
      outer_forward_ctxt = outer_grad_state.forward_context
    else:
      if not hasattr(forward_ctxt, "outer_context"):
        raise ValueError("Failed to call gradients on a while loop without "
                         "properly serializing graph via MetaGraphDef")
      outer_forward_ctxt = forward_ctxt.outer_context

    # Add the forward loop counter.
    with forward_ctxt._graph.as_default():  # pylint: disable=protected-access
      if outer_forward_ctxt:
        outer_forward_ctxt.Enter()
      cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
      if outer_forward_ctxt:
        outer_forward_ctxt.Exit()
    self._forward_context = forward_ctxt
    self._forward_index = forward_index

    # Add the backprop WhileContext, and the backprop loop counter.
    if outer_grad_state:
      # This is a nested loop. Remember the iteration counts for each
      # execution of this inner loop.
      outer_forward_ctxt.AddName(cnt.name)
      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)

      outer_grad_ctxt = outer_grad_state.grad_context
      outer_grad_ctxt.Enter()
      self._grad_context = WhileContext(
          maximum_iterations=forward_ctxt.maximum_iterations,
          parallel_iterations=forward_ctxt.parallel_iterations,
          back_prop=forward_ctxt.back_prop,
          swap_memory=forward_ctxt.swap_memory,
          name=forward_ctxt.name,
          grad_state=self)
      real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
      self._grad_index = self._grad_context.AddBackpropLoopCounter(
          real_cnt, outer_grad_state)
      outer_grad_ctxt.Exit()
    else:
      if outer_forward_ctxt:
        outer_forward_ctxt.Enter()
      self._grad_context = WhileContext(
          maximum_iterations=forward_ctxt.maximum_iterations,
          parallel_iterations=forward_ctxt.parallel_iterations,
          back_prop=forward_ctxt.back_prop,
          swap_memory=forward_ctxt.swap_memory,
          name=forward_ctxt.name,
          grad_state=self)
      self._grad_index = self._grad_context.AddBackpropLoopCounter(
          cnt, outer_grad_state)
      if outer_forward_ctxt:
        outer_forward_ctxt.Exit()

  @property
  def outer_grad_state(self):
    """The grad loop state for outer loop."""
    return self._outer_grad_state

  @property
  def forward_context(self):
    """The while loop context for forward."""
    return self._forward_context

  @property
  def forward_index(self):
    """The loop index of forward loop."""
    return self._forward_index

  @property
  def forward_sync(self):
    """A control trigger node for synchronization in the forward loop.

    One main use is to keep the push ops of a stack executed in the
    iteration order.
    """
    if self._forward_sync is None:
      with ops.control_dependencies(None):
        self._forward_sync = control_trigger(name="f_sync")
      self._forward_sync._set_control_flow_context(self._forward_context)
      self._forward_index.op._add_control_input(self._forward_sync)
    return self._forward_sync

  @property
  def grad_context(self):
    """The corresponding WhileContext for gradient."""
    return self._grad_context

  @property
  def grad_index(self):
    """The loop index of backprop loop."""
    return self._grad_index

  @property
  def grad_sync(self):
    """A control trigger node for synchronization in the grad loop.

    One main use is to keep the pop ops of a stack executed in the
    iteration order.
    """
    if self._grad_sync is None:
      with ops.control_dependencies(None):
        self._grad_sync = control_trigger(name="b_sync")
      self._grad_sync._set_control_flow_context(self._grad_context)
      self._grad_index.op._add_control_input(self._grad_sync)
      if self._grad_context.outer_context:
        self._grad_context.outer_context.AddInnerOp(self._grad_sync)
    return self._grad_sync

  @property
  def history_map(self):
    """The map that records all the tensors needed for backprop."""
    return self._history_map

  @property
  def switch_map(self):
    """The map that records all the Switch ops for the while loop."""
    return self._switch_map

  @property
  def unused_exits(self):
    """The list of "unused" exits."""
    return self._unused_exits

  @property
  def deferred_exits(self):
    """The list of "deferred" exits."""
    return self._deferred_exits

  @property
  def forward_loop_exits(self):
    """The list of exits of the forward loop."""
    return self._forward_loop_exits

  @property
  def pending_exits_count(self):
    """The number of exits we expect to see but haven't."""
    return self._pending_exits_count

  @pending_exits_count.setter
  def pending_exits_count(self, cnt):
    """Set the pending count to cnt."""
    self._pending_exits_count = cnt

  def AddForwardAccumulator(self, value, dead_branch=False):
    """Add an accumulator for each forward tensor that is needed in backprop.

    This is added to the forward loop at the first time when a tensor
    in the forward loop is used by the backprop gradient computation loop.
    We create an accumulator that accumulates the value of the tensor at each
    iteration. Called in the control flow context where gradients() is called.

    The pseudocode is:
    ```
      acc = stack();
      while (_pivot) {
        acc = stack_push(acc, value);
      }
    ```

    We make sure that the stack push op in one iteration is executed before
    the next iteration. This is achieved by adding a control edge from
    `forward_index.op.inputs[0].op` to the push op, and another control
    edge from the push op to either `forward_index.op` or `forward_sync`.

    Args:
      value: The source tensor in forward that is to be accumulated.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The stack that contains the accumulated history of the tensor.

    Raises:
      TypeError: For internal errors involving the value condition context.
      ValueError: If `value` is inside a XLA scope and a valid max size
        for the stack can't be found.
    """
    # curr_ctxt is the context that tf.gradients was called in.
    with self._forward_index.graph.as_default():
      curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
      with ops.control_dependencies(None):
        if curr_ctxt:
          curr_ctxt.Enter()
        with ops.colocate_with(value):
          # We only need to pass maximum_iterations to the stack if
          # we're inside an XLA context.
          if not util.IsInXLAContext(value.op):
            max_size = constant_op.constant(-1, dtypes.int32)
          else:
            max_size = GetMaxSizeFromNestedMaximumIterations(
                value, self.forward_context)
          acc = gen_data_flow_ops.stack_v2(
              max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
        if curr_ctxt:
          curr_ctxt.Exit()

        # Make acc available in the forward context.
        enter_acc = self.forward_context.AddValue(acc)

        # Add the stack_push op in the context of value.op.
        swap_enabled = self.forward_context.swap_memory
        value_ctxt = util.GetOutputContext(value.op)
        if value_ctxt == self.forward_context:
          # value is not nested in the forward context.
          self.forward_context.Enter()
          push = gen_data_flow_ops.stack_push_v2(
              enter_acc, value, swap_memory=swap_enabled)
          self.forward_context.Exit()
          # Protect stack push and order it before forward_index.
          self.forward_index.op._add_control_input(push.op)
        else:
          # value is in a cond context within the forward context.
          if not isinstance(value_ctxt, CondContext):
            raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
          if dead_branch:
            # The special case for creating a zero tensor for a dead
            # branch of a switch. See ControlFlowState.ZerosLike().
            value_ctxt.outer_context.Enter()
            push = gen_data_flow_ops.stack_push_v2(
                enter_acc, value, swap_memory=swap_enabled)
            value_ctxt.outer_context.Exit()
            push.op._set_control_flow_context(value_ctxt)
          else:
            value_ctxt.Enter()
            push = gen_data_flow_ops.stack_push_v2(
                enter_acc, value, swap_memory=swap_enabled)
            value_ctxt.Exit()
          # Protect stack push and order it before forward_sync.
          self.forward_sync._add_control_input(push.op)
        # Order stack push after the successor of forward_index
        add_op = self.forward_index.op.inputs[0].op
        push.op._add_control_input(add_op)
        return acc

  def AddBackpropAccumulatedValue(self, history_value, value,
                                  dead_branch=False):
    """Add the getter for an accumulated value in the grad context.

    This is added to the backprop loop. Called in the grad context to
    get the value of an accumulated value. The stack pop op must be guarded
    by the pred of the controlling cond.

    Args:
      history_value: The history (a stack) of a value.
      value: The value that is pushed onto the stack.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The current value (the top of the stack).
    """
    history_ctxt = history_value.op._get_control_flow_context()
    # Find the cond context that controls history_value if any.
    cond_ctxt = None
    value_ctxt = value.op._get_control_flow_context()
    while value_ctxt and value_ctxt != history_ctxt:
      if isinstance(value_ctxt, CondContext):
        cond_ctxt = value_ctxt
        break
      value_ctxt = value_ctxt.outer_context
    with ops.control_dependencies(None):
      self.grad_context.Enter()
      if cond_ctxt:
        # Guard stack pop with a switch if it is controlled by a cond.
        grad_state = self
        pred = None
        while pred is None and grad_state:
          pred = grad_state.history_map.get(cond_ctxt.pred.name)
          grad_state = grad_state.outer_grad_state
        if pred is None:
          pred = cond_ctxt.pred
        branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
        history_value = _SwitchRefOrTensor(history_value, pred)[branch]
      pop = gen_data_flow_ops.stack_pop_v2(history_value,
                                           value.dtype.base_dtype)
      pop.set_shape(value.get_shape())
      self.grad_context.Exit()
    parallel_iterations = self.grad_context.parallel_iterations
    if parallel_iterations > 1:
      # All pops are ordered after pivot_for_body and before grad_sync.
      self.grad_sync._add_control_input(pop.op)
    return pop

  def GetRealValue(self, value):
    """Get the real value of `value`.

    If backprop "uses" a value produced by forward inference, an accumulator
    is added in the forward loop to accumulate its values. We use the
    accumulated value. This method must be called in the grad loop context.
    `value` must be in forward and needed for backprop.

    Args:
      value: A tensor to be captured.

    Returns:
      The same tensor obtained from the saved history.
    """
    assert value.op.type not in ["Variable", "VariableV2"]
    real_value = self._history_map.get(value.name)
    if real_value is None:
      cur_value = value
      cur_grad_state = self
      while True:
        enter_op = util.GetLoopConstantEnter(cur_value)
        if enter_op:
          # Special case: cur_value comes from a constant Enter node.
          cur_value = enter_op.inputs[0]
          cur_grad_state = cur_grad_state.outer_grad_state
          if cur_grad_state is None:
            # We are now outside all nested loops for this gradient(),
            # so `value` is a loop invariant and there is no need to
            # save the history of value. Just make cur_value enter
            # the right control flow context.
            real_value = self._grad_context.AddValue(cur_value)
            break
        elif constant_op.is_constant(cur_value):
          # If the value to be forwarded is a constant, clone the constant in
          # the gradient loop rather than using a stack.
          # TODO(phawkins): consider hoisting the constant out of the loop
          # instead.
          real_value = constant_op.constant(
              tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
          break
        else:
          # Record the history of this value in forward_ctxt.
          self._grad_context.Exit()
          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
          self._grad_context.Enter()
          break

      if real_value is None:
        # Add the stack pop op in the grad context.
        real_value = cur_grad_state.AddBackpropAccumulatedValue(
            history_value, cur_value)
        if cur_grad_state != self:
          real_value = self._grad_context.AddValue(real_value)
      self._history_map[value.name] = real_value
    return real_value


def _GetWhileContext(op):
  """Get the WhileContext to which this op belongs."""
  ctxt = op._get_control_flow_context()
  if ctxt:
    ctxt = ctxt.GetWhileContext()
  return ctxt


class ControlFlowState(object):
  """Maintain the mapping from the loops to their grad states."""

  def __init__(self):
    self._map = {}  # maps forward loop context to GradLoopState

  def GetGradState(self, op, before):
    """Return the grad state for this op if it's in a forward loop context."""
    if before and util.IsLoopExit(op):
      forward_ctxt = op._get_control_flow_context()
      forward_ctxt = forward_ctxt.outer_context
      if forward_ctxt:
        forward_ctxt = forward_ctxt.GetWhileContext()
    else:
      forward_ctxt = _GetWhileContext(op)
    if forward_ctxt:
      return self._map.get(forward_ctxt)
    return None

  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
    """Process all the "unused" loop exits.

    The "unused" exits of the loops are added to `unused_exits`. An exit is
    unused if its pending_count is 0. If there is an exit with real gradient,
    all these deferred exits will enter the backprop loop with zero gradient.
    Otherwise, they will enter the backprop loop with None. As an example,
    people often write:

    ```python
    v1, _ = tf.while_loop(p, b, [x1, x2])
    result = gradients(v1, x1)
    ```

    The exit node for x2 is not included by the betweenness analysis. But we
    need to backprop x2 if x2 is involved in computing v1.

    Args:
      pending_count: The number of backprop inputs for every op.
      to_ops_set: The set of ops for ys in gradients(ys, xs)

    Returns:
      The set of unused loop exits that we know at this point we need
      to backprop.
    """
    loop_exits = []
    for grad_state in self._map.values():
      for y in grad_state.forward_loop_exits:
        if pending_count[y.op] == 0:
          grad_state.pending_exits_count -= 1
          if y.op not in to_ops_set:
            grad_state.unused_exits.append(y)
          if grad_state.pending_exits_count == 0:
            loop_exits.extend(grad_state.unused_exits)
      # Need to include Enters in backprop for higher-order gradients.
      for y in grad_state.forward_context.loop_enters:
        if pending_count[y.op] == 0:
          pending_count[y.op] = 1
    return loop_exits

  def EnterGradWhileContext(self, op, before):
    """Enter the WhileContext for gradient computation."""
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Enter()

  def ExitGradWhileContext(self, op, before):
    """Exit the WhileContext for gradient computation."""
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Exit()

  def AddWhileContext(self, op, between_op_list, between_ops):
    """Add the grad state for the while loop that op belongs to.

    Note that op is an Exit, and this method must be called in
    the control flow context where gradients() is called.

    Note that this method modifies `between_op_list` and `between_ops`.
    """
    forward_ctxt = _GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # This is a new while loop so create a grad state for it.
      outer_forward_ctxt = forward_ctxt.outer_context
      if outer_forward_ctxt:
        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
      outer_grad_state = None
      if outer_forward_ctxt:
        outer_grad_state = self._map.get(outer_forward_ctxt)
      grad_state = GradLoopState(forward_ctxt, outer_grad_state)
      self._map[forward_ctxt] = grad_state

      # We need to include all exits of a loop for backprop.
      for loop_exit in grad_state.forward_loop_exits:
        if loop_exit.op not in between_ops:
          between_ops.add(loop_exit.op)
          between_op_list.append(loop_exit.op)

  def ZerosLikeForExit(self, val):
    """Create zeros_like gradient for a loop exit.

    If the result of a loop variable is not used but is involved in
    computing the result of some needed loop variable, we create a
    zero-valued tensor that is fed as gradient for the Exit node of that
    loop variable. Note that val.op is an Exit, and this method must be
    called in the control flow context where gradients() is called.

    Args:
      val: The output tensor of an Exit op.

    Returns:
      A zero tensor of the same shape as val.
    """
    val_shape = val.get_shape()
    forward_ctxt = val.op._get_control_flow_context()
    outer_forward_ctxt = forward_ctxt.outer_context
    if outer_forward_ctxt:
      outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
    outer_grad_state = None
    if outer_forward_ctxt:
      outer_grad_state = self._map.get(outer_forward_ctxt)
    if outer_grad_state:
      # This is a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape in the right context.
        outer_grad_state.grad_context.Enter()
        result = array_ops.zeros(val_shape.dims, val.dtype)
        outer_grad_state.grad_context.Exit()
      else:
        # Only the shape of value is needed for backprop.
        forward_ctxt.outer_context.Enter()
        shape = array_ops.shape_internal(val, optimize=False)
        forward_ctxt.outer_context.Exit()
        # Save the shape to a stack.
        history_shape = outer_grad_state.AddForwardAccumulator(shape)
        # Get the shape back from the stack.
        outer_grad_ctxt = outer_grad_state.grad_context
        outer_grad_ctxt.Enter()
        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
            history_shape, shape)
        result = array_ops.zeros(real_shape, val.dtype)
        outer_grad_ctxt.Exit()
    else:
      # This is not a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape.
        result = array_ops.zeros(val_shape.dims, val.dtype)
      else:
        result = array_ops.zeros_like(val, optimize=False)
    return result

  def ZerosLike(self, op, index):
    """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method
    must be called in its grad loop context.

    Args:
      op: A tensorflow operation.
      index: the index for a specific output of the op.

    Returns:
      A zero tensor of the same shape as op.outputs[index].
    """
    if util.IsLoopSwitch(op):
      return None
    if op.graph._building_function:  # pylint: disable=protected-access
      # The optimization here is tricky to apply to functions
      return array_ops.zeros_like(op.outputs[index])
    dead_branch = util.IsSwitch(op)
    forward_ctxt = _GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # op is not in a while loop that is part of gradients().
      return ZerosLikeOutsideLoop(op, index)
    op_ctxt = op._get_control_flow_context()
    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
    shape = val.get_shape()
    if shape.is_fully_defined():
      # If the shape is known statically, just create a zero tensor with
      # the right shape in the grad loop context.
      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
      if dead_branch:
        # op is a cond switch. Guard the zero tensor with a switch.
        pred = grad_state.history_map.get(op_ctxt.pred.name)
        branch = op_ctxt.branch
        result = _SwitchRefOrTensor(result, pred)[1 - branch]
    else:
      # Unknown shape so keep a history of the shape at runtime.
      if dead_branch:
        # Need to add a special switch to guard the value.
        pred = op_ctxt.pred
        branch = op_ctxt.branch
        op_ctxt.outer_context.Enter()
        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.outer_context.Exit()
        val.op._set_control_flow_context(op_ctxt)
        zeros_shape.op._set_control_flow_context(op_ctxt)
      else:
        op_ctxt.Enter()
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.Exit()

      # Add forward accumulator for shape.
      grad_state.grad_context.Exit()
      history_zeros_shape = grad_state.AddForwardAccumulator(
          zeros_shape, dead_branch=dead_branch)
      grad_state.grad_context.Enter()

      # Create a zero tensor with the right shape.
      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
                                                     zeros_shape, dead_branch)
      result = array_ops.zeros(shape, val.dtype)
    return result

  def PostProcessing(self):
    """Perform postprocessing at the end of gradients().

    We have created the gradient graph at this point. So this function
    can be used to perform any postprocessing on the gradient graph.
    We currently perform the following postprocessing:
      1. Patch the gradient graph if the output of a loop variable
         doesn't depend on its input.
    """
    for _, grad_state in self._map.items():
      for _, b_merge in grad_state.switch_map.items():
        if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
          # The value of this loop variable at iteration i+1 doesn't
          # depend on its value at iteration i. So use zeros as the
          # gradients for all iterations > 0.
          dtype = b_merge.op.inputs[0].dtype
          shape = b_merge.op.inputs[0].get_shape()
          # pylint: disable=protected-access
          if shape.is_fully_defined():
            grad_state.grad_context.Enter()
            # Create a zeros and use it for iterations > 0.
            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
            next_grad_val = _NextIteration(grad_val)
            grad_state.grad_context.Exit()
          else:
            # Create a zeros in the outer grad context.
            outer_grad_ctxt = grad_state.grad_context.outer_context
            if outer_grad_ctxt:
              outer_grad_ctxt.Enter()
            enter_grad_op = b_merge.op.inputs[0].op
            enter_grad = enter_grad_op.inputs[0]
            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
            grad_val = array_ops.zeros(grad_shape)
            if outer_grad_ctxt:
              outer_grad_ctxt.Exit()
            # Use the zeros for iterations > 0.
            grad_state.grad_context.Enter()
            next_grad_val = _NextIteration(grad_val)
            grad_state.grad_context.Exit()
          b_merge.op._update_input(1, next_grad_val)
          # pylint: enable=protected-access


def MaybeCreateControlFlowState(between_op_list, between_ops,
                                colocate_gradients_with_ops):
  """Create the state for all the while loops involved in one gradients().

  We create a ControlFlowState when there are while loops involved in
  gradients(). In gradients(), control flow logic is only invoked when
  the ControlFlowState is not None.

  Note that this method modifies `between_op_list` and `between_ops`.
  """
  loop_state = None
  for op in between_op_list:
    if util.IsLoopExit(op):
      if loop_state is None:
        loop_state = ControlFlowState()
      if colocate_gradients_with_ops:
        with ops.colocate_with(op):
          loop_state.AddWhileContext(op, between_op_list, between_ops)
      else:
        loop_state.AddWhileContext(op, between_op_list, between_ops)
  return loop_state


def ZerosLikeOutsideLoop(op, index):
  """Create zeros_like for the specified output of an op."""
  val = op.outputs[index]
  if not util.IsSwitch(op):
    if val.dtype == dtypes.resource:
      return array_ops.zeros(gen_resource_variable_ops.variable_shape(val))
    return array_ops.zeros_like(val, optimize=False)
  else:
    op_ctxt = op._get_control_flow_context()
    if op_ctxt:
      # We are in a cond context. Use a switch to create zeros only when needed.
      pred = op_ctxt.pred
      branch = op_ctxt.branch
      switch_val = switch(op.inputs[0], pred)[1 - branch]
      # An op is created along the taken branch because control dependencies
      # apply to the whole op and not to individual tensor outputs.
      pivot = array_ops.identity(switch_val)
      if val.dtype == dtypes.resource:
        with ops.control_dependencies([pivot]):
          return array_ops.zeros(
              gen_resource_variable_ops.variable_shape(switch_val))
      zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
      # Ensure ops created within array_ops.zeros are dominated by switch in
      # cond context.
      with ops.control_dependencies([pivot]):
        return array_ops.zeros(zeros_shape, dtype=val.dtype)
    else:
      return array_ops.zeros_like(val, optimize=False)


@six.add_metaclass(abc.ABCMeta)
class ControlFlowContext(object):
  """The base class for control flow context.

  The usage pattern is a sequence of (Enter, Exit) followed by a final
  ExitResult.

  We maintain the following state for control flow contexts during graph
  construction:
    1. graph has _control_flow_context: the current context used to
       construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
    2. op has _control_flow_context: the context to which the op belongs.
       Set at the time the op is created. Immutable.
    3. A ControlFlowContext has _outer_context: the context in which this
       context is created. Set at the time a context is created. Immutable.
    4. A ControlFlowContext has _context_stack.
       Pushed and popped by ctxt.Enter() and ctxt.Exit()
  """

  def __init__(self, values_def=None, import_scope=None):
    self._nested_contexts = []
    self._outer_context = ops.get_default_graph()._get_control_flow_context()
    if self._outer_context:
      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
    self._context_stack = []
    if values_def:
      self._init_values_from_proto(values_def, import_scope=import_scope)
    else:
      # The names of tensors that have been already seen in this context.
      self._values = set()
      # The keys are the names of tensors referenced by but external to this
      # context. Each value is the Tensor that should be used by this context to
      # access the key value (e.g. a switch output guarding a cond input value).
      self._external_values = {}

  def _init_values_from_proto(self, values_def, import_scope=None):
    """Initializes values and external_values from `ValuesDef` protocol buffer.

    Args:
      values_def: `ValuesDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(values_def, control_flow_pb2.ValuesDef)
    self._values = set(
        ops.prepend_name_scope(value, import_scope)
        for value in values_def.values)
    g = ops.get_default_graph()
    self._external_values = {}
    for k, v in values_def.external_values.items():
      k = ops.prepend_name_scope(k, import_scope)
      self._external_values[k] = g.as_graph_element(
          ops.prepend_name_scope(v, import_scope))
    op_names = set([
        op.split(":")[0]
        for op in self._values - set(self._external_values.keys())
    ])
    for op in op_names:
      # pylint: disable=protected-access
      g.as_graph_element(op)._set_control_flow_context(self)
      # pylint: enable=protected-access

  @property
  def name(self):
    return self._name

  @property
  def outer_context(self):
    """Return the context containing this context."""
    return self._outer_context

  @property
  def grad_state(self):
    raise NotImplementedError("Abstract method")

  @property
  def back_prop(self):
    raise NotImplementedError("Abstract method")

  @abc.abstractmethod
  def to_control_flow_context_def(self, context_def, export_scope=None):
    """Serializes this into `context_def`.

    Args:
      context_def: a `ControlFlowContextDef` protocol buffer.
      export_scope: Optional `string`. Name scope to remove.
    """
    raise NotImplementedError("Abstract method")

  def _to_values_def(self, export_scope=None):
    """Converts the values to a `ValuesDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `ValuesDef` protocol buffer.
    """
    values_def = control_flow_pb2.ValuesDef()
    values_def.values.extend(
        [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
    for k, v in self._external_values.items():
      k = ops.strip_name_scope(k, export_scope)
      values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
    return values_def

  def AddName(self, name):
    self._values.add(name)

  # pylint: disable=protected-access
  def Enter(self):
    """Enter this control flow context."""
    graph = ops.get_default_graph()
    self._context_stack.append(graph._get_control_flow_context())
    graph._set_control_flow_context(self)

  def Exit(self):
    """Exit this control flow context."""
    graph = ops.get_default_graph()
    last_context = self._context_stack.pop()
    graph._set_control_flow_context(last_context)

  def EnterGradientColocation(self, op, gradient_uid):
    """Start building a gradient colocated with an op."""
    if self._outer_context:
      self._outer_context.EnterGradientColocation(op, gradient_uid)

  def ExitGradientColocation(self, op, gradient_uid):
    """Stop building a gradient colocated with an op."""
    if self._outer_context:
      self._outer_context.ExitGradientColocation(op, gradient_uid)

  def ExitResult(self, result):
    """Make a list of tensors available in the outer context."""
    if self._outer_context:
      nest.map_structure(lambda x: self._outer_context.AddName(x.name), result,
                         expand_composites=True)

  def GetWhileContext(self):
    """Return the while context containing this context."""
    if self._outer_context:
      return self._outer_context.GetWhileContext()
    return None

  def _IsInOuterContext(self, op):
    op_ctxt = util.GetOutputContext(op)
    outer_ctxt = self.outer_context
    while outer_ctxt != op_ctxt:
      if outer_ctxt is None:
        return False
      outer_ctxt = outer_ctxt.outer_context
    return True

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op."""
    while_ctxt = self.GetWhileContext()
    # A control input of `op` is internal if it is in the same while
    # loop context as the enclosing while loop context of self.
1511 if while_ctxt is None:
1512 internal_control_inputs = op.control_inputs
1513 else:
1514 internal_control_inputs = []
1515 for x in op.control_inputs:
1516 ctxt = util.GetOutputContext(x)
1517 if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
1518 internal_control_inputs.append(x)
1519 external_control_inputs = []
1520 if len(internal_control_inputs) != len(op.control_inputs):
1521 external_control_inputs = list(
1522 set(op.control_inputs) - set(internal_control_inputs))
1523 op._remove_all_control_inputs()
1524 op._add_control_inputs(internal_control_inputs)
1525 return internal_control_inputs, external_control_inputs
1526
1527 # pylint: enable=protected-access
1528
1529 def AddInnerOp(self, op):
1530 """Notifies a scope about an operator added to an inner scope."""
1531 if self._outer_context:
1532 self._outer_context.AddInnerOp(op)
1533
1534 def GetControlPivot(self):
1535 """Returns the pivot node for this context, or None."""
1536 return None
1537
1538 def IsWhileContext(self):
1539 return False
1540
1541 def IsCondContext(self):
1542 return False
1543
1544 def IsXLAContext(self):
1545 return False
1546
1547 def __str__(self):
1548 return self.name
1549
1550
1551class CondContext(ControlFlowContext):
1552 """The context for the conditional construct."""
1553
1554 def __init__(self,
1555 pred=None,
1556 pivot=None,
1557 branch=None,
1558 name="cond_text",
1559 context_def=None,
1560 import_scope=None):
1561 """Creates a `CondContext`.
1562
1563 Args:
1564 pred: The `boolean` tensor for the conditional predicate.
1565 pivot: The predicate tensor in this branch.
1566 branch: 0 or 1 representing this branch.
1567 name: Name of the `CondContext` python object.
1568 context_def: Optional `CondContextDef` protocol buffer to initialize the
1569 `CondContext` object from.
1570 import_scope: Optional `string`. Name scope to add. Only used when
1571 initializing from protocol buffer.
1572 """
1573 self._name = ops.get_default_graph().unique_name(name)
1574
1575 if context_def:
1576 self._init_from_proto(context_def, import_scope=import_scope)
1577 else:
1578 # Initializes the default fields.
1579 ControlFlowContext.__init__(self)
1580 self._pred = pred # The boolean tensor for the cond predicate
1581 self._pivot = pivot # The predicate tensor in this branch
1582 self._branch = branch # 0 or 1 representing this branch
1583
1584 # Values considered to have been already seen in this context. `pred` is
1585 # produced outside this branch, so it is recorded as an external value.
1586 self._values.add(pred.name)
1587 self._external_values[pred.name] = pred
1588 self._values.add(pivot.name)
1589 pivot.op._set_control_flow_context(self) # pylint: disable=protected-access
1590
1591 def _init_from_proto(self, context_def, import_scope=None):
1592 """Creates a new `CondContext` from protocol buffer.
1593
1594 Args:
1595 context_def: `CondContextDef` protocol buffer.
1596 import_scope: Optional `string`. Name scope to add.
1597 """
1598 assert isinstance(context_def, control_flow_pb2.CondContextDef)
1599 # Create from context_def.
1600 g = ops.get_default_graph()
1601 self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
1602 self._pred = g.as_graph_element(
1603 ops.prepend_name_scope(context_def.pred_name, import_scope))
1604 self._pivot = g.as_graph_element(
1605 ops.prepend_name_scope(context_def.pivot_name, import_scope))
1606 self._branch = context_def.branch
1607 super(CondContext, self).__init__(
1608 values_def=context_def.values_def, import_scope=import_scope)
1609
1610 @property
1611 def pred(self):
1612 return self._pred
1613
1614 @property
1615 def pivot(self):
1616 return self._pivot
1617
1618 @property
1619 def branch(self):
1620 return self._branch
1621
1622 @property
1623 def grad_state(self):
1624 if self.GetWhileContext():
1625 return self.GetWhileContext().grad_state
1626 return None
1627
1628 @property
1629 def back_prop(self):
1630 if self.GetWhileContext():
1631 return self.GetWhileContext().back_prop
1632 return False
1633
1634 def GetControlPivot(self):
1635 return self._pivot
1636
1637 def to_proto(self, export_scope=None):
1638 """Converts a `CondContext` to a `CondContextDef` protocol buffer.
1639
1640 Args:
1641 export_scope: Optional `string`. Name scope to remove.
1642
1643 Returns:
1644 A `CondContextDef` protocol buffer.
1645 """
1646 if (export_scope is None or self.name.startswith(export_scope)):
1647 context_def = control_flow_pb2.CondContextDef()
1648 context_def.context_name = ops.strip_name_scope(self.name, export_scope)
1649 context_def.pred_name = ops.strip_name_scope(self._pred.name,
1650 export_scope)
1651 context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
1652 export_scope)
1653 context_def.branch = self._branch
1654 context_def.values_def.MergeFrom(
1655 super(CondContext, self)._to_values_def(export_scope))
1656 for nested in self._nested_contexts:
1657 nested_def = context_def.nested_contexts.add()
1658 nested.to_control_flow_context_def(nested_def)
1659
1660 return context_def
1661 else:
1662 return None
1663
1664 @staticmethod
1665 def from_proto(context_def, import_scope=None):
1666 """Returns a `CondContext` object created from `context_def`."""
1667 ret = CondContext(context_def=context_def, import_scope=import_scope)
1668
1669 ret.Enter()
1670 for nested_def in context_def.nested_contexts:
1671 from_control_flow_context_def(nested_def, import_scope=import_scope)
1672 ret.Exit()
1673 return ret
1674
1675 def to_control_flow_context_def(self, context_def, export_scope=None):
1676 context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
1677
1678 def AddValue(self, val):
1679 """Add `val` to the current context and its outer context recursively."""
1680 if val.name in self._values:
1681 # Use the real value if it comes from outer context. This is needed in
1682 # particular for nested conds.
1683 result = self._external_values.get(val.name) 1684 result = val if result is None else result 1685 else: 1686 result = val 1687 self._values.add(val.name) 1688 if self._outer_context: 1689 result = self._outer_context.AddValue(val) 1690 self._values.add(result.name) 1691 self._external_values[result.name] = result 1692 with ops.control_dependencies(None): 1693 result = _SwitchRefOrTensor(result, self._pred)[self._branch] 1694 if self._outer_context: 1695 self._outer_context.AddInnerOp(result.op) 1696 1697 result.op.graph.prevent_fetching(result.op) 1698 # pylint: disable=protected-access 1699 result.op._set_control_flow_context(self) 1700 # pylint: enable=protected-access 1701 1702 # Mark Switch output as seen by this context and any outer contexts, 1703 # just like what we do for normal op outputs in _AddOpInternal() below. 1704 ctxt = self 1705 while ctxt is not None: 1706 # pylint: disable=protected-access 1707 ctxt._values.add(result.name) 1708 ctxt = ctxt._outer_context 1709 # pylint: enable=protected-access 1710 1711 self._external_values[val.name] = result 1712 return result 1713 1714 def AddOp(self, op): 1715 self._AddOpInternal(op) 1716 1717 def _AddOpInternal(self, op): 1718 """Add `op` to the current context.""" 1719 if not op.inputs: 1720 # If we're in a while loop, remove any control inputs from outside the 1721 # loop. 1722 self._RemoveExternalControlEdges(op) 1723 1724 if not any( 1725 util.OpInContext(input_op, self) for input_op in op.control_inputs): 1726 # pylint: disable=protected-access 1727 op._add_control_input(self._pivot.op) 1728 # pylint: enable=protected-access 1729 else: 1730 # Make each input to 'op' available in this CondContext. If an input is 1731 # already part of this context there's nothing to do, but if it's 1732 # external, AddValue() will handle adding the appropriate Switch node and 1733 # other bookkeeping. 1734 for index in range(len(op.inputs)): 1735 x = op.inputs[index] 1736 if op.type == "Merge" and x.op.type == "NextIteration": 1737 # Edge case: if we're importing a while loop inside this CondContext, 1738 # AddValue() will not correctly handle the NextIteration inputs to 1739 # Merge node. The problem is that the NextIteration should also be 1740 # part of this context, but if we're importing it won't have been 1741 # processed and added to the context yet, so AddValue() will try to 1742 # add a Switch which results in an invalid graph. Instead, we use the 1743 # NextIteration input as-is here, and it will eventually be added to 1744 # the context via AddOp(). 1745 real_x = x 1746 else: 1747 real_x = self.AddValue(x) 1748 if real_x != x: 1749 # pylint: disable=protected-access 1750 op._update_input(index, real_x) 1751 # pylint: enable=protected-access 1752 # Remove any external control dependency on this op. 1753 self._RemoveExternalControlEdges(op) 1754 # pylint: disable=protected-access 1755 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 1756 op._add_control_input(self._pivot.op) 1757 # pylint: enable=protected-access 1758 1759 # Mark op's outputs as seen by this context and any outer contexts. 
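    # Recording the outputs in _values keeps later AddValue() calls from
    # treating them as external inputs and inserting redundant Switch nodes.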
1760 output_names = [x.name for x in op.outputs] 1761 ctxt = self 1762 while ctxt is not None: 1763 # pylint: disable=protected-access 1764 ctxt._values.update(output_names) 1765 ctxt = ctxt._outer_context 1766 # pylint: enable=protected-access 1767 1768 if self._outer_context or not util.IsLoopExit(op): 1769 op.graph.prevent_fetching(op) 1770 1771 if self._outer_context: 1772 self._outer_context.AddInnerOp(op) 1773 1774 def _ProcessOutputTensor(self, val): 1775 """Process an output tensor of a conditional branch.""" 1776 real_val = val 1777 if val.name not in self._values: 1778 # Handle the special case of lambda: x 1779 self._values.add(val.name) 1780 if self._outer_context: 1781 real_val = self._outer_context.AddValue(val) 1782 self._values.add(real_val.name) 1783 self._external_values[real_val.name] = real_val 1784 real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch] 1785 self._external_values[val.name] = real_val 1786 else: 1787 external_val = self._external_values.get(val.name) 1788 if external_val is not None: 1789 real_val = external_val 1790 return real_val 1791 1792 def _BuildCondTensor(self, v): 1793 if isinstance(v, ops.Operation): 1794 # Use pivot as the proxy for this op. 1795 return with_dependencies([v], self._pivot) 1796 else: 1797 v = nest.map_structure(_convert_tensorarray_to_flow, v, 1798 expand_composites=True) 1799 return self._ProcessOutputTensor(ops.convert_to_tensor(v)) 1800 1801 def BuildCondBranch(self, fn): 1802 """Add the subgraph defined by fn() to the graph.""" 1803 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1804 original_result = fn() 1805 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1806 if len(post_summaries) > len(pre_summaries): 1807 new_summaries = post_summaries[len(pre_summaries):] 1808 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 1809 summary_ref[:] = pre_summaries 1810 with ops.control_dependencies(new_summaries): 1811 if original_result is None: 1812 return no_op(), None 1813 else: 1814 original_result = nest.map_structure(array_ops.identity, 1815 original_result, 1816 expand_composites=True) 1817 if original_result is None: 1818 return None, None 1819 1820 result = nest.map_structure(self._BuildCondTensor, original_result, 1821 expand_composites=True) 1822 if not isinstance(result, (list, _basetuple)): 1823 result = [result] 1824 return original_result, result 1825 1826 def IsCondContext(self): 1827 return True 1828 1829 1830def _UnpackIfSingleton(res): 1831 if isinstance(res, (list, _basetuple)) and len(res) == 1: 1832 return res[0] 1833 else: 1834 return res 1835 1836 1837# pylint: disable=redefined-outer-name 1838# pylint: disable=g-doc-args 1839@tf_export(v1=["cond"]) 1840@deprecation.deprecated_args( 1841 None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.", 1842 "fn1", "fn2") 1843def cond(pred, 1844 true_fn=None, 1845 false_fn=None, 1846 strict=False, 1847 name=None, 1848 fn1=None, 1849 fn2=None): 1850 """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. 1851 1852 `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and 1853 `false_fn` must have the same non-zero number and type of outputs. 1854 1855 **WARNING**: Any Tensors or Operations created outside of `true_fn` and 1856 `false_fn` will be executed regardless of which branch is selected at runtime. 
1857 1858 Although this behavior is consistent with the dataflow model of TensorFlow, 1859 it has frequently surprised users who expected a lazier semantics. 1860 Consider the following simple program: 1861 1862 ```python 1863 z = tf.multiply(a, b) 1864 result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)) 1865 ``` 1866 1867 If `x < y`, the `tf.add` operation will be executed and `tf.square` 1868 operation will not be executed. Since `z` is needed for at least one 1869 branch of the `cond`, the `tf.multiply` operation is always executed, 1870 unconditionally. 1871 1872 Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the 1873 call to `cond`, and not at all during `Session.run()`). `cond` 1874 stitches together the graph fragments created during the `true_fn` and 1875 `false_fn` calls with some additional graph nodes to ensure that the right 1876 branch gets executed depending on the value of `pred`. 1877 1878 `tf.cond` supports nested structures as implemented in 1879 `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the 1880 same (possibly nested) value structure of lists, tuples, and/or named tuples. 1881 Singleton lists and tuples form the only exceptions to this: when returned by 1882 `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. 1883 This behavior is disabled by passing `strict=True`. 1884 1885 Args: 1886 pred: A scalar determining whether to return the result of `true_fn` or 1887 `false_fn`. 1888 true_fn: The callable to be performed if pred is true. 1889 false_fn: The callable to be performed if pred is false. 1890 strict: A boolean that enables/disables 'strict' mode; see above. 1891 name: Optional name prefix for the returned tensors. 1892 1893 Returns: 1894 Tensors returned by the call to either `true_fn` or `false_fn`. If the 1895 callables return a singleton list, the element is extracted from the list. 1896 1897 Raises: 1898 TypeError: if `true_fn` or `false_fn` is not callable. 1899 ValueError: if `true_fn` and `false_fn` do not return the same number of 1900 tensors, or return tensors of different types. 1901 1902 Example: 1903 1904 ```python 1905 x = tf.constant(2) 1906 y = tf.constant(5) 1907 def f1(): return tf.multiply(x, 17) 1908 def f2(): return tf.add(y, 23) 1909 r = tf.cond(tf.less(x, y), f1, f2) 1910 # r is set to f1(). 1911 # Operations in f2 (e.g., tf.add) are not executed. 1912 ``` 1913 1914 """ 1915 # Always enable control flow v2 if building a function, regardless of toggle. 1916 if (util.EnableControlFlowV2(ops.get_default_graph()) and 1917 not context.executing_eagerly()): 1918 return cond_v2.cond_v2(pred, true_fn, false_fn, name) 1919 1920 # We needed to make true_fn/false_fn keyword arguments for 1921 # backwards-compatibility. This check exists so that we can convert back to 1922 # having them be positional arguments. 1923 # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after 1924 # `fn1` and `fn2` are deleted. 
1925 if fn1 is not None: 1926 if true_fn is not None: 1927 raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.") 1928 true_fn = fn1 1929 elif true_fn is None: 1930 raise TypeError("cond(): true_fn argument required") 1931 if fn2 is not None: 1932 if false_fn is not None: 1933 raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.") 1934 false_fn = fn2 1935 elif false_fn is None: 1936 raise TypeError("cond(): false_fn argument required") 1937 1938 if not callable(true_fn): 1939 raise TypeError("true_fn must be callable.") 1940 if not callable(false_fn): 1941 raise TypeError("false_fn must be callable.") 1942 1943 with ops.name_scope(name, "cond", [pred]): 1944 if context.executing_eagerly(): 1945 if pred: 1946 return _UnpackIfSingleton(true_fn()) 1947 return _UnpackIfSingleton(false_fn()) 1948 1949 # Add the Switch to the graph. 1950 if isinstance(pred, bool): 1951 raise TypeError("pred must not be a Python bool") 1952 p_2, p_1 = switch(pred, pred) 1953 pivot_1 = array_ops.identity(p_1, name="switch_t") 1954 pivot_2 = array_ops.identity(p_2, name="switch_f") 1955 pred = array_ops.identity(pred, name="pred_id") 1956 # Disable the fetching of tensors that are only on one branch of cond. 1957 for tensor in [p_1, p_2, pivot_1, pivot_2, pred]: 1958 tensor.op.graph.prevent_fetching(tensor.op) 1959 1960 # Build the graph for the true branch in a new context. 1961 context_t = CondContext(pred, pivot_1, branch=1) 1962 try: 1963 context_t.Enter() 1964 orig_res_t, res_t = context_t.BuildCondBranch(true_fn) 1965 if orig_res_t is None: 1966 raise ValueError("true_fn must have a return value.") 1967 context_t.ExitResult(res_t) 1968 finally: 1969 context_t.Exit() 1970 1971 # Build the graph for the false branch in a new context. 1972 context_f = CondContext(pred, pivot_2, branch=0) 1973 try: 1974 context_f.Enter() 1975 orig_res_f, res_f = context_f.BuildCondBranch(false_fn) 1976 if orig_res_f is None: 1977 raise ValueError("false_fn must have a return value.") 1978 context_f.ExitResult(res_f) 1979 finally: 1980 context_f.Exit() 1981 1982 if not strict: 1983 orig_res_t = _UnpackIfSingleton(orig_res_t) 1984 orig_res_f = _UnpackIfSingleton(orig_res_f) 1985 1986 # Check that the return values of the two branches have the same structure. 1987 try: 1988 nest.assert_same_structure(orig_res_t, orig_res_f, 1989 expand_composites=True) 1990 except TypeError as e: 1991 raise TypeError( 1992 "Incompatible return types of true_fn and false_fn: {}".format(e)) 1993 except ValueError as e: 1994 raise ValueError( 1995 "Incompatible return values of true_fn and false_fn: {}".format(e)) 1996 1997 # Add the final merge to the graph. 
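    # Each Merge below receives the corresponding Switch outputs of the two
    # branches; at runtime only the taken branch produces a tensor, and Merge
    # forwards that tensor as the output of the cond.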
1998 if not res_t:
1999 raise ValueError("true_fn and false_fn must return at least one result.")
2000
2001 res_t_flat = nest.flatten(res_t, expand_composites=True)
2002 res_f_flat = nest.flatten(res_f, expand_composites=True)
2003
2004 for i, (x, y) in enumerate(zip(res_t_flat, res_f_flat)):
2005 assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
2006 if x.dtype.base_dtype != y.dtype.base_dtype:
2007 _cast_indexed_slice_indices(res_t, res_t_flat, res_f_flat)
2008 if res_t_flat[i].dtype.base_dtype != res_f_flat[i].dtype.base_dtype:
2009 raise ValueError(
2010 "Outputs of true_fn and false_fn must have the same type: "
2011 "%s, %s" % (x.dtype.name, y.dtype.name))
2012
2013 merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
2014 merges = _convert_flows_to_tensorarrays(
2015 nest.flatten(orig_res_t, expand_composites=True), merges)
2016
2017 # Only add non-nested conds to the collection. Any nested control flow will
2018 # be encapsulated in the root context.
2019 assert context_t.outer_context == context_f.outer_context
2020 if context_t.outer_context is None:
2021 ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
2022 ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
2023
2024 merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges,
2025 expand_composites=True)
2026
2027 # Singleton lists and tuples are automatically unpacked if strict == False.
2028 if not strict:
2029 merges = _UnpackIfSingleton(merges)
2030 return merges
2031
2032
2033def _cast_indexed_slice_indices(structure, flat_a, flat_b):
2034 """Cast IndexedSlices.indices from int32 to int64 where necessary.
2035
2036 For each `IndexedSlices` in the nested structure `structure`, find its
2037 indices `Tensor` in the corresponding flattened lists `flat_a` and `flat_b`
2038 (where composites have been expanded); and if those indices tensors have
2039 different dtypes (i.e., if one is int64 but the other is int32), then cast
2040 them to both be int64.
2041
2042 Args:
2043 structure: The nested structure that was flattened.
2044 flat_a: A flattened list of `Tensors` whose structure matches
2045 `structure`. Will be modified in place to cast `IndexedSlices`
2046 indices tensors to int64, where necessary.
2047 flat_b: A flattened list of `Tensors` whose structure matches
2048 `structure`. Will be modified in place to cast `IndexedSlices`
2049 indices tensors to int64, where necessary.
2050 """
2051 # Find the locations (in flat_a and flat_b) of the IndexedSlices'
2052 # indices tensors.
2053 indexed_slice_indices = []
2054 current_index = 0
2055 for item in nest.flatten(structure, expand_composites=False):
2056 if isinstance(item, ops.IndexedSlices):
2057 # indices is the second component of the composite tensor.
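      # When flattened with expand_composites=True, an IndexedSlices expands to
      # (values, indices, dense_shape), so its indices tensor sits at offset 1.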
2058 indexed_slice_indices.append(current_index + 1) 2059 if nest.is_sequence_or_composite(item): 2060 current_index += len(nest.flatten(item, expand_composites=True)) 2061 else: 2062 current_index += 1 2063 assert current_index == len(flat_a) 2064 2065 for index in indexed_slice_indices: 2066 assert flat_a[index].dtype in (dtypes.int32, dtypes.int64) 2067 assert flat_b[index].dtype in (dtypes.int32, dtypes.int64) 2068 if flat_a[index].dtype != flat_b[index].dtype: 2069 if flat_b[index].dtype == dtypes.int32: 2070 flat_b[index] = math_ops.cast(flat_b[index], dtypes.int64) 2071 else: 2072 flat_a[index] = math_ops.cast(flat_a[index], dtypes.int64) 2073 2074 2075# pylint: enable=g-doc-args 2076# pylint: enable=redefined-outer-name 2077 2078 2079@tf_export("cond", v1=[]) 2080def cond_for_tf_v2(pred, 2081 true_fn=None, 2082 false_fn=None, 2083 name=None): 2084 """Return `true_fn()` if the predicate `pred` is true else `false_fn()`. 2085 2086 `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and 2087 `false_fn` must have the same non-zero number and type of outputs. 2088 2089 **WARNING**: Any Tensors or Operations created outside of `true_fn` and 2090 `false_fn` will be executed regardless of which branch is selected at runtime. 2091 2092 Although this behavior is consistent with the dataflow model of TensorFlow, 2093 it has frequently surprised users who expected a lazier semantics. 2094 Consider the following simple program: 2095 2096 ```python 2097 z = tf.multiply(a, b) 2098 result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y)) 2099 ``` 2100 2101 If `x < y`, the `tf.add` operation will be executed and `tf.square` 2102 operation will not be executed. Since `z` is needed for at least one 2103 branch of the `cond`, the `tf.multiply` operation is always executed, 2104 unconditionally. 2105 2106 Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the 2107 call to `cond`, and not at all during `Session.run()`). `cond` 2108 stitches together the graph fragments created during the `true_fn` and 2109 `false_fn` calls with some additional graph nodes to ensure that the right 2110 branch gets executed depending on the value of `pred`. 2111 2112 `tf.cond` supports nested structures as implemented in 2113 `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the 2114 same (possibly nested) value structure of lists, tuples, and/or named tuples. 2115 Singleton lists and tuples form the only exceptions to this: when returned by 2116 `true_fn` and/or `false_fn`, they are implicitly unpacked to single values. 2117 2118 Args: 2119 pred: A scalar determining whether to return the result of `true_fn` or 2120 `false_fn`. 2121 true_fn: The callable to be performed if pred is true. 2122 false_fn: The callable to be performed if pred is false. 2123 name: Optional name prefix for the returned tensors. 2124 2125 Returns: 2126 Tensors returned by the call to either `true_fn` or `false_fn`. If the 2127 callables return a singleton list, the element is extracted from the list. 2128 2129 Raises: 2130 TypeError: if `true_fn` or `false_fn` is not callable. 2131 ValueError: if `true_fn` and `false_fn` do not return the same number of 2132 tensors, or return tensors of different types. 2133 2134 Example: 2135 2136 ```python 2137 x = tf.constant(2) 2138 y = tf.constant(5) 2139 def f1(): return tf.multiply(x, 17) 2140 def f2(): return tf.add(y, 23) 2141 r = tf.cond(tf.less(x, y), f1, f2) 2142 # r is set to f1(). 
2143 # Operations in f2 (e.g., tf.add) are not executed.
2144 ```
2145
2146 """
2147 return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
2148
2149
2150def _resource_safe_shape(t):
2151 """Returns the shape of t or the variable it points to."""
2152 if t.dtype == dtypes.resource:
2153 while t.op.inputs:
2154 t = t.op.inputs[0]
2155 return tensor_shape.TensorShape(t.op.get_attr("shape"))
2156 return array_ops.shape_internal(t, optimize=False)
2157
2158
2159# TODO(yuanbyu): Consider having a unified notion of context for
2160# not only conditionals and loops but also control dependency and
2161# subgraphs.
2162class WhileContext(ControlFlowContext):
2163 """The context for the loop construct."""
2164
2165 def __init__(self,
2166 maximum_iterations=None,
2167 parallel_iterations=10,
2168 back_prop=True,
2169 swap_memory=False,
2170 name="while_context",
2171 grad_state=None,
2172 context_def=None,
2173 import_scope=None):
2174 """Creates a `WhileContext`.
2175
2176 Args:
2177 maximum_iterations: Optional upper bound on number of loop iterations.
2178 parallel_iterations: The number of iterations allowed to run in parallel.
2179 back_prop: Whether backprop is enabled for this while loop.
2180 swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2181 name: Optional name prefix for the returned tensors.
2182 grad_state: The gradient loop state.
2183 context_def: Optional `WhileContextDef` protocol buffer to initialize the
2184 `WhileContext` python object from.
2185 import_scope: Optional `string`. Name scope to add. Only used when
2186 initializing from protocol buffer.
2187 """
2188 if context_def:
2189 self._init_from_proto(context_def, import_scope=import_scope)
2190 else:
2191 ControlFlowContext.__init__(self)
2192 self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
2193 swap_memory, name)
2194 # The gradient loop state.
2195 self._grad_state = grad_state
2196
2197 def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
2198 swap_memory, name):
2199 """Creates a new `WhileContext` from arguments.
2200
2201 Args:
2202 maximum_iterations: Optional upper bound on number of loop iterations.
2203 parallel_iterations: The number of iterations allowed to run in parallel.
2204 back_prop: Whether backprop is enabled for this while loop.
2205 swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2206 name: Optional name prefix for the returned tensors.
2207
2208 Raises:
2209 ValueError: If `parallel_iterations` has an invalid value.
2210 """
2211 if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
2212 raise ValueError("`parallel_iterations` must be a positive integer: "
2213 "%s" % parallel_iterations)
2214 self._name = ops.get_default_graph().unique_name(name)
2215 self._maximum_iterations = maximum_iterations
2216 self._parallel_iterations = parallel_iterations
2217 self._back_prop = back_prop
2218 self._swap_memory = swap_memory
2219 # We use this node to control constants created by the pred lambda.
2220 self._pivot_for_pred = None
2221 # We use this node to control constants created by the body lambda.
2222 self._pivot_for_body = None
2223 # The boolean tensor for loop termination condition. Used in code
2224 # generation for gradient computation.
2225 self._pivot = None
2226 # The list of exit tensors for loop variables.
2227 self._loop_exits = []
2228 # The list of enter tensors for loop variables.
2229 self._loop_enters = [] 2230 self._graph = ops.get_default_graph() 2231 2232 def _init_from_proto(self, context_def, import_scope=None): 2233 """Creates a new `WhileContext` from protocol buffer. 2234 2235 Args: 2236 context_def: `WhileContextDef` protocol buffer. 2237 import_scope: Optional `string`. Name scope to add. 2238 """ 2239 assert isinstance(context_def, control_flow_pb2.WhileContextDef) 2240 # Create from context_def. 2241 g = ops.get_default_graph() 2242 self._name = ops.prepend_name_scope(context_def.context_name, import_scope) 2243 if context_def.maximum_iterations_name: 2244 self._maximum_iterations = g.as_graph_element( 2245 ops.prepend_name_scope(context_def.maximum_iterations_name, 2246 import_scope)) 2247 else: 2248 self._maximum_iterations = None 2249 self._parallel_iterations = context_def.parallel_iterations 2250 self._back_prop = context_def.back_prop 2251 self._swap_memory = context_def.swap_memory 2252 self._pivot_for_pred = g.as_graph_element( 2253 ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope)) 2254 # We use this node to control constants created by the body lambda. 2255 self._pivot_for_body = g.as_graph_element( 2256 ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope)) 2257 # The boolean tensor for loop termination condition. Used in code 2258 # generation for gradient computation. 2259 self._pivot = g.as_graph_element( 2260 ops.prepend_name_scope(context_def.pivot_name, import_scope)) 2261 # The list of exit tensors for loop variables. 2262 self._loop_exits = [ 2263 g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope)) 2264 for exit_name in context_def.loop_exit_names 2265 ] 2266 # The list of enter tensors for loop variables. 2267 self._loop_enters = [ 2268 g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope)) 2269 for enter_name in context_def.loop_enter_names 2270 ] 2271 super(WhileContext, self).__init__( 2272 values_def=context_def.values_def, import_scope=import_scope) 2273 2274 # import_scope causes self.name to be different from the original serialized 2275 # context's name. Rewrite "frame_name" attrs with the new name. 
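    # Enter ops record the name of the frame (loop) they belong to in their
    # "frame_name" attr, so loop-entry ops must be pointed at the renamed frame.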
2276 if import_scope: 2277 for tensor_name in self._values: 2278 op = g.as_graph_element(tensor_name).op 2279 if util.IsLoopEnter(op): 2280 # pylint: disable=protected-access 2281 op._set_attr("frame_name", 2282 attr_value_pb2.AttrValue(s=compat.as_bytes(self.name))) 2283 # pylint: enable=protected-access 2284 self._graph = ops.get_default_graph() 2285 2286 @property 2287 def maximum_iterations(self): 2288 """The maximum number of iterations that will be executed.""" 2289 return self._maximum_iterations 2290 2291 @property 2292 def parallel_iterations(self): 2293 """The number of iterations allowed to run in parallel.""" 2294 return self._parallel_iterations 2295 2296 @property 2297 def back_prop(self): 2298 """True iff backprop is enabled for this while loop.""" 2299 return self._back_prop 2300 2301 @property 2302 def swap_memory(self): 2303 """True iff GPU-CPU memory swap is enabled for this while loop.""" 2304 return self._swap_memory 2305 2306 @property 2307 def pivot(self): 2308 """The boolean tensor representing the loop termination condition.""" 2309 return self._pivot 2310 2311 @property 2312 def loop_enters(self): 2313 """The list of enter tensors for loop variables.""" 2314 return self._loop_enters 2315 2316 @property 2317 def loop_exits(self): 2318 """The list of exit tensors for loop variables.""" 2319 return self._loop_exits 2320 2321 @property 2322 def grad_state(self): 2323 """The gradient loop state.""" 2324 return self._grad_state 2325 2326 def to_proto(self, export_scope=None): 2327 """Converts a `WhileContext` to a `WhileContextDef` protocol buffer. 2328 2329 Args: 2330 export_scope: Optional `string`. Name scope to remove. 2331 2332 Returns: 2333 A `WhileContextDef` protocol buffer. 2334 """ 2335 if (export_scope is None or self.name.startswith(export_scope)): 2336 context_def = control_flow_pb2.WhileContextDef() 2337 context_def.context_name = ops.strip_name_scope(self.name, export_scope) 2338 context_def.parallel_iterations = self._parallel_iterations 2339 if self._maximum_iterations is not None: 2340 context_def.maximum_iterations_name = ops.strip_name_scope( 2341 self._maximum_iterations.name, export_scope) 2342 context_def.back_prop = self._back_prop 2343 context_def.swap_memory = self._swap_memory 2344 context_def.pivot_for_pred_name = ops.strip_name_scope( 2345 self._pivot_for_pred.name, export_scope) 2346 context_def.pivot_for_body_name = ops.strip_name_scope( 2347 self._pivot_for_body.name, export_scope) 2348 context_def.pivot_name = ops.strip_name_scope(self._pivot.name, 2349 export_scope) 2350 context_def.loop_exit_names.extend([ 2351 ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits 2352 ]) 2353 context_def.loop_enter_names.extend([ 2354 ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters 2355 ]) 2356 context_def.values_def.MergeFrom( 2357 super(WhileContext, self)._to_values_def(export_scope=export_scope)) 2358 for nested in self._nested_contexts: 2359 nested_def = context_def.nested_contexts.add() 2360 nested.to_control_flow_context_def(nested_def) 2361 2362 return context_def 2363 else: 2364 return None 2365 2366 def to_control_flow_context_def(self, context_def, export_scope=None): 2367 context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope)) 2368 2369 @staticmethod 2370 def from_proto(context_def, import_scope=None): 2371 """Returns a `WhileContext` object created from `context_def`. 2372 2373 Args: 2374 context_def: A `WhileContextDef` protocol buffer. 2375 import_scope: Optional `string`. 
Name scope to add. 2376 2377 Returns: 2378 A `WhileContext` Python object. 2379 """ 2380 ret = WhileContext(context_def=context_def, import_scope=import_scope) 2381 ret.Enter() 2382 for nested_def in context_def.nested_contexts: 2383 from_control_flow_context_def(nested_def, import_scope=import_scope) 2384 ret.Exit() 2385 return ret 2386 2387 def GetWhileContext(self): 2388 return self 2389 2390 def GetControlPivot(self): 2391 if self._pivot_for_body is not None: 2392 return self._pivot_for_body 2393 return self._pivot_for_pred 2394 2395 def AddValue(self, val): 2396 """Add `val` to the current context and its outer context recursively.""" 2397 result = val 2398 new_value = val.name not in self._values 2399 # Don't treat ops in this context as new values. Usually all known values 2400 # are in self._values, except when we're importing a while loop inside this 2401 # WhileContext. Since there's a cycle in this case, `val` may be part of the 2402 # imported while loop but not yet processed by this context and added to 2403 # self._values in _AddOpInternal. We only want to process external input 2404 # tensors to the while loop here. 2405 new_value &= val.op._control_flow_context is not self # pylint: disable=protected-access 2406 if new_value: 2407 self._values.add(val.name) 2408 2409 # If we are in a grad context and val is from its forward context, 2410 # use GetRealValue(), which adds the logic to save the history of 2411 # val in forward. 2412 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 2413 if grad_ctxt: 2414 grad_ctxt = grad_ctxt.GetWhileContext() 2415 if grad_ctxt.grad_state: 2416 forward_ctxt = _GetWhileContext(val.op) 2417 if util.IsLoopExit(val.op): 2418 forward_ctxt = forward_ctxt.outer_context 2419 if forward_ctxt: 2420 forward_ctxt = forward_ctxt.GetWhileContext() 2421 if forward_ctxt == grad_ctxt.grad_state.forward_context: 2422 real_val = grad_ctxt.grad_state.GetRealValue(val) 2423 self._external_values[val.name] = real_val 2424 return real_val 2425 2426 if self._outer_context is not None: 2427 result = self._outer_context.AddValue(val) 2428 # Create an Enter to make `result` known to this loop context. 2429 with ops.control_dependencies(None): 2430 enter = _Enter( 2431 result, 2432 self._name, 2433 is_constant=True, 2434 parallel_iterations=self._parallel_iterations) 2435 enter.graph.prevent_feeding(enter) 2436 if self._outer_context: 2437 self._outer_context.AddInnerOp(enter.op) 2438 # Fix the control inputs and control flow context of these enter ops. 2439 self._FixControlInputsAndContext([enter]) 2440 2441 # Add `enter` in this context. 2442 self._values.add(enter.name) 2443 self._external_values[val.name] = enter 2444 result = enter 2445 else: 2446 actual_val = self._external_values.get(val.name) 2447 if actual_val is not None: 2448 result = actual_val 2449 return result 2450 2451 def AddOp(self, op): 2452 """Add `op` to the current context.""" 2453 # For a reduction op, if op is in a grad context and its input is from 2454 # its forward context, moving op to the forward context means we would 2455 # store the tensor after the reduction as opposed to the tensor before 2456 # reduction, and therefore could significantly reduce memory consumption. 2457 # For now, we do this only for a few ops. 
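    # Note: Shape, Size and Rank produce small outputs, so keeping their results
    # (instead of their potentially large inputs) is cheap.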
2458 if op.type in {"Shape", "Size", "Rank"}: 2459 grad_ctxt = ops.get_default_graph()._get_control_flow_context() 2460 if grad_ctxt: 2461 grad_ctxt = grad_ctxt.GetWhileContext() 2462 if grad_ctxt.grad_state: 2463 op_input_forward_ctxt = _GetWhileContext(op.inputs[0].op) 2464 if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context: 2465 op_input_ctxt = op.inputs[0].op._get_control_flow_context() 2466 op._set_control_flow_context(op_input_ctxt) 2467 op_input_ctxt._AddOpInternal(op) 2468 return 2469 self._AddOpInternal(op) 2470 2471 def _AddOpInternal(self, op): 2472 """Add `op` to the current context. 2473 2474 We move any external control dependencies of the op to the loop pivot, to 2475 ensure they get executed. 2476 """ 2477 if not op.inputs: 2478 # Remove any external control dependency on this op 2479 control_inputs, external_inputs = self._RemoveExternalControlEdges(op) 2480 # Add a control edge from the control pivot to this op. 2481 if not control_inputs: 2482 # pylint: disable=protected-access 2483 op._add_control_input(self.GetControlPivot().op) 2484 # pylint: enable=protected-access 2485 for x in op.outputs: 2486 self._values.add(x.name) 2487 else: 2488 for index in range(len(op.inputs)): 2489 x = op.inputs[index] 2490 real_x = self.AddValue(x) 2491 if real_x != x: 2492 op._update_input(index, real_x) # pylint: disable=protected-access 2493 # Remove any external control dependency on this op. 2494 _, external_inputs = self._RemoveExternalControlEdges(op) 2495 # Add a control dependency to prevent loop invariants from 2496 # enabling ops that should not be executed. 2497 self._MaybeAddControlDependency(op) 2498 for x in op.outputs: 2499 self._values.add(x.name) 2500 if external_inputs: 2501 # Use an identity to pull control inputs as data inputs. Note that we 2502 # ignore ops which don't have outputs. TODO(apassos): fix that 2503 with ops.control_dependencies(None): 2504 self.Enter() 2505 external_inputs = [ 2506 array_ops.identity(x.outputs[0]).op 2507 for x in external_inputs 2508 if x.outputs 2509 ] 2510 self.Exit() 2511 op._add_control_inputs(external_inputs) # pylint: disable=protected-access 2512 if self._outer_context or not util.IsLoopExit(op): 2513 op.graph.prevent_fetching(op) 2514 for x in op.outputs: 2515 op.graph.prevent_feeding(x) 2516 2517 if self._outer_context: 2518 self._outer_context.AddInnerOp(op) 2519 2520 def _MaybeAddControlDependency(self, op): 2521 """Add a control input to the op if it only depends on loop invariants.""" 2522 2523 def _IsOpFree(op): 2524 """Determines if `op` needs a control dependency.""" 2525 if op.control_inputs: 2526 return False 2527 # pylint: disable=protected-access 2528 if op.graph._is_function(op.type) or op.type == "SymbolicGradient": 2529 return True 2530 # pylint: enable=protected-access 2531 for x in op.inputs: 2532 if not util.IsLoopConstantEnter(x.op): 2533 return False 2534 return True 2535 2536 if _IsOpFree(op): 2537 # pylint: disable=protected-access 2538 op._add_control_input(self.GetControlPivot().op) 2539 # pylint: enable=protected-access 2540 2541 def AddForwardLoopCounter(self, outer_grad_state): 2542 """Adds a loop that counts the number of iterations. 2543 2544 This is added to the forward loop at the time when we start to 2545 create the loop for backprop gradient computation. Called in 2546 the outer context of this forward context. 
2547 2548 The pseudocode is: 2549 `n = 0; while (_pivot) { n++; }` 2550 2551 Note that a control dependency is added to `n` to ensure the correct 2552 execution order of stack push ops. 2553 2554 Args: 2555 outer_grad_state: The outer grad state. None if not nested. 2556 2557 Returns: 2558 The number of iterations taken by the forward loop and the loop index. 2559 """ 2560 n = constant_op.constant(0, name="f_count") 2561 if outer_grad_state is not None: 2562 # Force the stack pushes of i-th execution of an inner loop to be ordered 2563 # before the pushes of (i+1)-th execution of the same inner loop. 2564 outer_add_op = outer_grad_state.forward_index.op.inputs[0].op 2565 n.op._add_control_input(outer_add_op) # pylint: disable=protected-access 2566 2567 self.Enter() 2568 self.AddName(n.name) 2569 enter_n = _Enter( 2570 n, 2571 self._name, 2572 is_constant=False, 2573 parallel_iterations=self._parallel_iterations, 2574 name="f_count") 2575 self.loop_enters.append(enter_n) 2576 2577 merge_n = merge([enter_n, enter_n])[0] 2578 switch_n = switch(merge_n, self._pivot) 2579 2580 index = math_ops.add(switch_n[1], 1) 2581 next_n = _NextIteration(index) 2582 merge_n.op._update_input(1, next_n) 2583 2584 total_iterations = exit(switch_n[0], name="f_count") 2585 self.loop_exits.append(total_iterations) 2586 self.ExitResult([total_iterations]) 2587 self.Exit() 2588 return total_iterations, next_n 2589 2590 def AddBackpropLoopCounter(self, count, outer_grad_state): 2591 """Add the backprop loop that controls the iterations. 2592 2593 This is added to the backprop loop. It is used to control the loop 2594 termination of the backprop loop. Called in the outer context of 2595 this grad context. 2596 2597 The pseudocode is: 2598 `n = count; while (n >= 1) { n--; }` 2599 2600 Note that a control dependency is added to `final_zero` to ensure the 2601 correct execution order of stack pop ops. 2602 2603 Args: 2604 count: The number of iterations for backprop. 2605 outer_grad_state: The outer grad state. None if not nested. 2606 2607 Returns: 2608 The loop index. 2609 """ 2610 in_separate_functions = count.graph is not ops.get_default_graph() 2611 if in_separate_functions: 2612 # Brings the count into this graph 2613 count = array_ops.identity(count) 2614 else: 2615 # TODO(apassos) XLA expects this constant to be created outside the loop, 2616 # so doing that for now. 2617 one = constant_op.constant(1, name="b_count") 2618 2619 self.Enter() 2620 self.AddName(count.name) 2621 enter_count = _Enter( 2622 count, 2623 self._name, 2624 is_constant=False, 2625 parallel_iterations=self._parallel_iterations, 2626 name="b_count") 2627 self.loop_enters.append(enter_count) 2628 2629 merge_count = merge([enter_count, enter_count])[0] 2630 self._pivot_for_pred = merge_count 2631 2632 if in_separate_functions: 2633 one = constant_op.constant(1, name="b_count") 2634 pred = math_ops.greater_equal(merge_count, one) 2635 self._pivot = loop_cond(pred, name="b_count") 2636 switch_count = switch(merge_count, self._pivot) 2637 2638 index = math_ops.subtract(switch_count[1], one) 2639 self._pivot_for_body = index 2640 next_count = _NextIteration(index) 2641 merge_count.op._update_input(1, next_count) 2642 2643 final_zero = exit(switch_count[0], name="b_count") 2644 self.loop_exits.append(final_zero) 2645 if outer_grad_state is not None: 2646 # Force the stack pops of i-th execution of an inner loop to be ordered 2647 # before the pops of (i+1)-th execution of the same inner loop. 
2648 # pylint: disable=protected-access 2649 outer_grad_state.grad_sync._add_control_input(final_zero.op) 2650 # pylint: enable=protected-access 2651 2652 self.ExitResult([final_zero]) 2653 self.Exit() 2654 return next_count 2655 2656 def AddBackpropAccumulator(self, op, grad): 2657 """Add an accumulation loop for every loop invariant. 2658 2659 This is added to the backprop loop. It is used to accumulate partial 2660 gradients within each loop iteration. Called when in the gradient while 2661 context. 2662 2663 The pseudocode is: 2664 ``` 2665 acc = 0.0; 2666 while (_pivot) { 2667 acc += grad; 2668 } 2669 ``` 2670 2671 Args: 2672 op: The Enter op for a loop invariant. 2673 grad: The partial gradient of an iteration for a loop invariant. 2674 2675 Returns: 2676 The gradient for a loop invariant. 2677 """ 2678 self.Exit() 2679 # Create a zeros tensor with the right shape for acc. If we don't 2680 # know the full shape statically, we will have to get the shape 2681 # dynamically from the forward inference. Getting the shape right 2682 # for the zeros is only needed for the base case when the loop exits 2683 # without running any iterations. 2684 shape = grad.get_shape() 2685 if shape.is_fully_defined(): 2686 if self.outer_context: 2687 self.outer_context.Enter() 2688 acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc") 2689 if self.outer_context: 2690 self.outer_context.Exit() 2691 else: 2692 value = op.inputs[0] 2693 if (isinstance(self.outer_context, WhileContext) and 2694 self.outer_context.grad_state is not None): 2695 # We are in a nested while loop. 2696 forward_ctxt = self.grad_state.forward_context 2697 forward_ctxt.outer_context.Enter() 2698 zeros_shape = array_ops.shape_internal(value, optimize=False) 2699 forward_ctxt.outer_context.Exit() 2700 outer_grad_state = self.grad_state.outer_grad_state 2701 history_zeros_shape = outer_grad_state.AddForwardAccumulator( 2702 zeros_shape) 2703 self.outer_context.Enter() 2704 real_shape = outer_grad_state.AddBackpropAccumulatedValue( 2705 history_zeros_shape, zeros_shape) 2706 acc = array_ops.zeros(real_shape, grad.dtype) 2707 self.outer_context.Exit() 2708 else: 2709 if self.outer_context: 2710 self.outer_context.Enter() 2711 zeros_shape = array_ops.shape_internal(value, optimize=False) 2712 acc = array_ops.zeros(zeros_shape, grad.dtype) 2713 if self.outer_context: 2714 self.outer_context.Exit() 2715 2716 self.Enter() 2717 self.AddName(acc.name) 2718 enter_acc = _Enter( 2719 acc, 2720 self._name, 2721 is_constant=False, 2722 parallel_iterations=self._parallel_iterations, 2723 name="b_acc") 2724 self.loop_enters.append(enter_acc) 2725 2726 merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0] 2727 switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot) 2728 2729 add_acc = math_ops.add(switch_acc_true, grad) 2730 next_acc = _NextIteration(add_acc) 2731 merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access 2732 2733 result_acc = exit(switch_acc_false, name="b_acc") 2734 self.loop_exits.append(result_acc) 2735 self.ExitResult([result_acc]) 2736 return result_acc 2737 2738 def AddBackpropIndexedSlicesAccumulator(self, op, grad): 2739 """This is used for accumulating gradients that are IndexedSlices. 2740 2741 This is essentially the equivalent of AddBackpropAccumulator but optimized 2742 for things like updating embeddings from within a while loop. 2743 2744 Args: 2745 op: The Enter op for a loop invariant. 2746 grad: The partial gradients represented as an IndexedSlices. 
2747 2748 Returns: 2749 The accumulated IndexedSlices gradient of the loop invariant. 2750 """ 2751 values = grad.values 2752 indices = grad.indices 2753 dense_shape = grad.dense_shape 2754 2755 self.Exit() 2756 if self.outer_context: 2757 self.outer_context.Enter() 2758 if values.get_shape().is_fully_defined(): 2759 values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] + 2760 values.get_shape().dims[1:]) 2761 if self.outer_context: 2762 self.outer_context.Enter() 2763 values_acc = constant_op.constant( 2764 0, values.dtype, shape=values_shape, name="b_acc") 2765 if self.outer_context: 2766 self.outer_context.Exit() 2767 else: 2768 values_shape = _resource_safe_shape(op.inputs[0])[1:] 2769 values_shape = array_ops.concat([[1], values_shape], 0) 2770 values_acc = array_ops.zeros(values_shape, dtype=values.dtype) 2771 indices_acc = constant_op.constant([0], indices.dtype) 2772 shape_acc = None 2773 if dense_shape is not None: 2774 if dense_shape.get_shape().is_fully_defined(): 2775 if self.outer_context: 2776 self.outer_context.Enter() 2777 shape_acc = constant_op.constant( 2778 0, dense_shape.dtype, shape=dense_shape.get_shape()) 2779 if self.outer_context: 2780 self.outer_context.Exit() 2781 else: 2782 shape_acc = array_ops.zeros_like( 2783 array_ops.shape_internal( 2784 op.inputs[0], optimize=False, out_type=dense_shape.dtype), 2785 optimize=False) 2786 2787 if self.outer_context: 2788 self.outer_context.Exit() 2789 2790 self.Enter() 2791 self.AddName(values_acc.name) 2792 self.AddName(indices_acc.name) 2793 init_acc = [indices_acc, values_acc] 2794 if shape_acc is not None: 2795 self.AddName(shape_acc.name) 2796 init_acc.append(shape_acc) 2797 2798 # Set use_input_shape=False since the accumulator tensors will grow in 2799 # size. If use_input_shape=True, the _update_input call below will result in 2800 # incompatible shapes. 2801 enter_acc = [ 2802 _Enter( 2803 x, 2804 self._name, 2805 is_constant=False, 2806 parallel_iterations=self._parallel_iterations, 2807 use_input_shape=False, 2808 name="b_acc") for x in init_acc 2809 ] 2810 # Manually set appropriate partial shapes. 2811 enter_acc[0].set_shape([None]) 2812 if values_acc.shape.dims is not None: 2813 enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:]) 2814 self.loop_enters.extend(enter_acc) 2815 2816 merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc] 2817 switch_acc = [switch(x, self._pivot) for x in merge_acc] 2818 2819 # The actual accumulation. 
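    # switch_acc[i][1] is the Switch output taken while the loop keeps running;
    # each iteration concatenates the new indices and values onto the
    # accumulated components.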
2820 acc_indexed_slices = [ 2821 array_ops.concat([xa[1], xv], 0) 2822 for xa, xv in zip(switch_acc[:2], [indices, values]) 2823 ] 2824 if shape_acc is not None: 2825 # For the shape we just keep the maximum 2826 acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1])) 2827 2828 next_acc = [_NextIteration(x) for x in acc_indexed_slices] 2829 for xm, xn in zip(merge_acc, next_acc): 2830 xm.op._update_input(1, xn) # pylint: disable=protected-access 2831 2832 exit_acc = [exit(x[0], name="b_acc") for x in switch_acc] 2833 self.loop_exits.extend(exit_acc) 2834 2835 self.ExitResult(exit_acc) 2836 return ops.IndexedSlices( 2837 indices=exit_acc[0], 2838 values=exit_acc[1], 2839 dense_shape=exit_acc[2] if shape_acc is not None else None) 2840 2841 def _InitializeValues(self, values): 2842 """Makes the values known to this context.""" 2843 self._values = set() 2844 for x in values: 2845 if isinstance(x, ops.Tensor): 2846 self._values.add(x.name) 2847 else: 2848 raise TypeError("Type %s not supported" % type(x)) 2849 2850 def _BuildLoop(self, pred, body, original_loop_vars, loop_vars, 2851 shape_invariants): 2852 """Core: Add the loop termination condition and body to the graph.""" 2853 flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True) 2854 2855 # Let the context know the loop variables so the loop variables 2856 # would be added in the outer contexts properly. 2857 self._InitializeValues(loop_vars) 2858 real_vars = loop_vars 2859 if self._outer_context: 2860 real_vars = [self._outer_context.AddValue(x) for x in loop_vars] 2861 with ops.control_dependencies(None): 2862 enter_vars = [ 2863 _Enter( 2864 x, 2865 self._name, 2866 is_constant=False, 2867 parallel_iterations=self._parallel_iterations, 2868 use_input_shape=(shape_invariants is None)) for x in real_vars 2869 ] 2870 for x in enter_vars: 2871 x.graph.prevent_feeding(x) 2872 if self._outer_context: 2873 self._outer_context.AddInnerOp(x.op) 2874 2875 # Finds the closest enclosing non-None control pivot. 2876 outer_context = self._outer_context 2877 control_pivot = None 2878 while outer_context is not None and control_pivot is None: 2879 control_pivot = outer_context.GetControlPivot() 2880 # pylint: disable=protected-access 2881 outer_context = outer_context._outer_context 2882 # pylint: enable=protected-access 2883 2884 if control_pivot is not None: 2885 for var in enter_vars: 2886 if util.IsLoopConstantEnter(var.op.inputs[0].op): 2887 # pylint: disable=protected-access 2888 var.op._add_control_input(control_pivot.op) 2889 # pylint: enable=protected-access 2890 _SetShapeInvariants(real_vars, enter_vars, shape_invariants) 2891 2892 # Fix the control inputs and control flow context of these enter ops. 2893 self._FixControlInputsAndContext(enter_vars) 2894 self._InitializeValues(enter_vars) 2895 self._loop_enters = enter_vars 2896 2897 merge_vars = [merge([x, x])[0] for x in enter_vars] 2898 self._pivot_for_pred = merge_vars[0] 2899 2900 # Build the graph for pred. 2901 merge_vars_with_tensor_arrays = ( 2902 _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars)) 2903 packed_vars = nest.pack_sequence_as( 2904 structure=original_loop_vars, 2905 flat_sequence=merge_vars_with_tensor_arrays, 2906 expand_composites=True) 2907 c = ops.convert_to_tensor(pred(*packed_vars)) 2908 self._pivot = loop_cond(c, name="LoopCond") 2909 switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars] 2910 2911 # Build the graph for body. 
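    # Each switch_vars[i] is a (loop-exit, loop-body) pair of Switch outputs;
    # index 1 is the value forwarded into the body while the condition holds.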
2912 vars_for_body = [_Identity(x[1]) for x in switch_vars] 2913 self._pivot_for_body = vars_for_body[0] 2914 # Convert TensorArray flow variables inside the context back into 2915 # their associated TensorArrays for calling the body. 2916 vars_for_body_with_tensor_arrays = ( 2917 _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body)) 2918 packed_vars_for_body = nest.pack_sequence_as( 2919 structure=original_loop_vars, 2920 flat_sequence=vars_for_body_with_tensor_arrays, 2921 expand_composites=True) 2922 pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2923 body_result = body(*packed_vars_for_body) 2924 post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2925 if not nest.is_sequence_or_composite(body_result): 2926 body_result = [body_result] 2927 if len(post_summaries) > len(pre_summaries): 2928 new_summaries = post_summaries[len(pre_summaries):] 2929 summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access 2930 summary_ref[:] = pre_summaries 2931 with ops.control_dependencies(new_summaries): 2932 2933 def map_fn(x): 2934 # TODO(apassos) figure out how to trigger with tensor arrays as well 2935 if isinstance(x, tensor_array_ops.TensorArray): 2936 return x 2937 return array_ops.identity(x) 2938 2939 body_result = nest.map_structure(map_fn, body_result, 2940 expand_composites=True) 2941 2942 # Compare the structure types of input and output of body. 2943 # For backwards compatibility, the first layer is forced to a list 2944 # during this comparison, because inputs are typically lists and 2945 # outputs of the body are typically tuples. 2946 nest.assert_same_structure(list(packed_vars_for_body), list(body_result), 2947 expand_composites=True) 2948 2949 # Store body_result to keep track of TensorArrays returned by body 2950 original_body_result = body_result 2951 # Convert TensorArrays returned by body into their flow variables 2952 result = nest.map_structure( 2953 _convert_tensorarray_to_flow, 2954 nest.flatten(body_result, expand_composites=True), 2955 expand_composites=True) 2956 result = ops.convert_n_to_tensor_or_composite(result) 2957 2958 # Add NextIteration and the back edges to complete the loop. 2959 if len(merge_vars) != len(result): 2960 raise ValueError("Number of inputs and outputs of body must match " 2961 "loop_vars: %d, %d" % (len(merge_vars), len(result))) 2962 next_vars = [] 2963 for m, v in zip(merge_vars, result): 2964 next_vars.append(_AddNextAndBackEdge(m, v)) 2965 2966 # Add the exit ops. 2967 exit_vars = [exit(x[0]) for x in switch_vars] 2968 self._loop_exits = exit_vars 2969 2970 # Exit the loop. 
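    # Make the Exit tensors known to the outer context so that ops outside the
    # loop can consume the loop results.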
2971 self.ExitResult(exit_vars) 2972 2973 return original_body_result, exit_vars 2974 2975 def BuildLoop(self, pred, body, loop_vars, shape_invariants, 2976 return_same_structure): 2977 """Add the loop termination condition and body to the graph.""" 2978 2979 # Keep original_loop_vars to identify which are TensorArrays 2980 original_loop_vars = loop_vars 2981 # Convert TensorArrays to their flow variables 2982 loop_vars = nest.map_structure( 2983 _convert_tensorarray_to_flow, 2984 nest.flatten(loop_vars, expand_composites=False), 2985 expand_composites=True) 2986 loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars) 2987 if shape_invariants is None: 2988 shape_invariants = nest.map_structure( 2989 _get_shape_invariant, loop_vars, expand_composites=False) 2990 loop_vars = nest.flatten(loop_vars, expand_composites=True) 2991 try: 2992 self.Enter() 2993 # _BuildLoop calls _update_input in several places. _mutation_lock() 2994 # ensures a Session.run call cannot occur between creating and mutating 2995 # new ops. 2996 with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access 2997 original_body_result, exit_vars = self._BuildLoop( 2998 pred, body, original_loop_vars, loop_vars, shape_invariants) 2999 finally: 3000 self.Exit() 3001 3002 flat_result = nest.flatten(original_body_result, expand_composites=True) 3003 # Convert TensorArray flow variables outside the context back into 3004 # their associated TensorArrays for returning to caller. 3005 exit_vars_with_tensor_arrays = ( 3006 _convert_flows_to_tensorarrays(flat_result, exit_vars)) 3007 packed_exit_vars = nest.pack_sequence_as( 3008 structure=original_body_result, 3009 flat_sequence=exit_vars_with_tensor_arrays, 3010 expand_composites=True) 3011 3012 if return_same_structure: 3013 return packed_exit_vars 3014 else: 3015 return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars 3016 3017 def _FixControlInputsAndContext(self, enters): 3018 graph = ops.get_default_graph() 3019 # pylint: disable=protected-access 3020 for e in enters: 3021 if isinstance(e, ops.Tensor): 3022 xs = [e] 3023 else: 3024 raise TypeError("Type %s not supported" % type(e)) 3025 for x in xs: 3026 inp_op = x.op.inputs[0].op 3027 control_inputs = graph._control_dependencies_for_inputs([inp_op]) 3028 outer_control_inputs = [ 3029 op for op in control_inputs if self._IsInOuterContext(op) 3030 ] 3031 x.op._set_control_flow_context(self) 3032 x.op._add_control_inputs(outer_control_inputs) 3033 graph._record_op_seen_by_control_dependencies(x.op) 3034 # pylint: enable=protected-access 3035 3036 def IsWhileContext(self): 3037 return True 3038 3039 3040# pylint: disable=redefined-outer-name 3041@tf_export("while_loop", v1=[]) 3042def while_loop_v2(cond, 3043 body, 3044 loop_vars, 3045 shape_invariants=None, 3046 parallel_iterations=10, 3047 back_prop=True, 3048 swap_memory=False, 3049 maximum_iterations=None, 3050 name=None): 3051 """Repeat `body` while the condition `cond` is true. 3052 3053 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 3054 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 3055 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 3056 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 3057 `cond` and `body`. `cond` and `body` both take as many arguments as there are 3058 `loop_vars`. 3059 3060 In addition to regular Tensors or IndexedSlices, the body may accept and 3061 return TensorArray objects. 
The flows of the TensorArray objects will 3062 be appropriately forwarded between loops and during gradient calculations. 3063 3064 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 3065 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 3066 stitches together the graph fragments created during the `cond` and `body` 3067 calls with some additional graph nodes to create the graph flow that 3068 repeats `body` until `cond` returns false. 3069 3070 For correctness, `tf.while_loop()` strictly enforces shape invariants for 3071 the loop variables. A shape invariant is a (possibly partial) shape that 3072 is unchanged across the iterations of the loop. An error will be raised 3073 if the shape of a loop variable after an iteration is determined to be more 3074 general than or incompatible with its shape invariant. For example, a shape 3075 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 3076 compatible with [11, 17]. By default (if the argument `shape_invariants` is 3077 not specified), it is assumed that the initial shape of each tensor in 3078 `loop_vars` is the same in every iteration. The `shape_invariants` argument 3079 allows the caller to specify a less specific shape invariant for each loop 3080 variable, which is needed if the shape varies between iterations. The 3081 `tf.Tensor.set_shape` 3082 function may also be used in the `body` function to indicate that 3083 the output loop variable has a particular shape. The shape invariant for 3084 SparseTensor and IndexedSlices are treated specially as follows: 3085 3086 a) If a loop variable is a SparseTensor, the shape invariant must be 3087 TensorShape([r]) where r is the rank of the dense tensor represented 3088 by the sparse tensor. It means the shapes of the three tensors of the 3089 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 3090 is the shape of the SparseTensor.dense_shape property. It must be the shape of 3091 a vector. 3092 3093 b) If a loop variable is an IndexedSlices, the shape invariant must be 3094 a shape invariant of the values tensor of the IndexedSlices. It means 3095 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 3096 [shape.ndims]). 3097 3098 `while_loop` implements non-strict semantics, enabling multiple iterations 3099 to run in parallel. The maximum number of parallel iterations can be 3100 controlled by `parallel_iterations`, which gives users some control over 3101 memory consumption and execution order. For correct programs, `while_loop` 3102 should return the same result for any parallel_iterations > 0. 3103 3104 For training, TensorFlow stores the tensors that are produced in the 3105 forward inference and are needed in back propagation. These tensors are a 3106 main source of memory consumption and often cause OOM errors when training 3107 on GPUs. When the flag swap_memory is true, we swap out these tensors from 3108 GPU to CPU. This for example allows us to train RNN models with very long 3109 sequences and large batches. 3110 3111 Args: 3112 cond: A callable that represents the termination condition of the loop. 3113 body: A callable that represents the loop body. 3114 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 3115 `Tensor`, and `TensorArray` objects. 3116 shape_invariants: The shape invariants for the loop variables. 3117 parallel_iterations: The number of iterations allowed to run in parallel. It 3118 must be a positive integer. 
3119 back_prop: Whether backprop is enabled for this while loop. 3120 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 3121 maximum_iterations: Optional maximum number of iterations of the while loop 3122 to run. If provided, the `cond` output is AND-ed with an additional 3123 condition ensuring the number of iterations executed is no greater than 3124 `maximum_iterations`. 3125 name: Optional name prefix for the returned tensors. 3126 3127 Returns: 3128 The output tensors for the loop variables after the loop. The return value 3129 has the same structure as `loop_vars`. 3130 3131 Raises: 3132 TypeError: if `cond` or `body` is not callable. 3133 ValueError: if `loop_vars` is empty. 3134 3135 Example: 3136 3137 ```python 3138 i = tf.constant(0) 3139 c = lambda i: tf.less(i, 10) 3140 b = lambda i: tf.add(i, 1) 3141 r = tf.while_loop(c, b, [i]) 3142 ``` 3143 3144 Example with nesting and a namedtuple: 3145 3146 ```python 3147 import collections 3148 Pair = collections.namedtuple('Pair', 'j, k') 3149 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 3150 c = lambda i, p: i < 10 3151 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 3152 ijk_final = tf.while_loop(c, b, ijk_0) 3153 ``` 3154 3155 Example using shape_invariants: 3156 3157 ```python 3158 i0 = tf.constant(0) 3159 m0 = tf.ones([2, 2]) 3160 c = lambda i, m: i < 10 3161 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 3162 tf.while_loop( 3163 c, b, loop_vars=[i0, m0], 3164 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 3165 ``` 3166 3167 Example which demonstrates non-strict semantics: In the following 3168 example, the final value of the counter `i` does not depend on `x`. So 3169 the `while_loop` can increment the counter parallel to updates of `x`. 3170 However, because the loop counter at one loop iteration depends 3171 on the value at the previous iteration, the loop counter itself cannot 3172 be incremented in parallel. Hence if we just want the final value of the 3173 counter (which we print on the line `print(sess.run(i))`), then 3174 `x` will never be incremented, but the counter will be updated on a 3175 single thread. Conversely, if we want the value of the output (which we 3176 print on the line `print(sess.run(out).shape)`), then the counter may be 3177 incremented on its own thread, while `x` can be incremented in 3178 parallel on a separate thread. In the extreme case, it is conceivable 3179 that the thread incrementing the counter runs until completion before 3180 `x` is incremented even a single time. The only thing that can never 3181 happen is that the thread updating `x` can never get ahead of the 3182 counter thread because the thread incrementing `x` depends on the value 3183 of the counter. 3184 3185 ```python 3186 import tensorflow as tf 3187 3188 n = 10000 3189 x = tf.constant(list(range(n))) 3190 c = lambda i, x: i < n 3191 b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:")) 3192 i, out = tf.while_loop(c, b, (0, x)) 3193 with tf.Session() as sess: 3194 print(sess.run(i)) # prints [0] ... [9999] 3195 3196 # The following line may increment the counter and x in parallel. 3197 # The counter thread may get ahead of the other thread, but not the 3198 # other way around. 
So you may see things like 3199 # [9996] x:[9987] 3200 # meaning that the counter thread is on iteration 9996, 3201 # while the other thread is on iteration 9987 3202 print(sess.run(out).shape) 3203 ``` 3204 3205 """ 3206 return while_loop( 3207 cond=cond, 3208 body=body, 3209 loop_vars=loop_vars, 3210 shape_invariants=shape_invariants, 3211 parallel_iterations=parallel_iterations, 3212 back_prop=back_prop, 3213 swap_memory=swap_memory, 3214 name=name, 3215 maximum_iterations=maximum_iterations, 3216 return_same_structure=True) 3217 3218 3219# pylint: disable=redefined-outer-name 3220@tf_export(v1=["while_loop"]) 3221def while_loop(cond, 3222 body, 3223 loop_vars, 3224 shape_invariants=None, 3225 parallel_iterations=10, 3226 back_prop=True, 3227 swap_memory=False, 3228 name=None, 3229 maximum_iterations=None, 3230 return_same_structure=False): 3231 """Repeat `body` while the condition `cond` is true. 3232 3233 `cond` is a callable returning a boolean scalar tensor. `body` is a callable 3234 returning a (possibly nested) tuple, namedtuple or list of tensors of the same 3235 arity (length and structure) and types as `loop_vars`. `loop_vars` is a 3236 (possibly nested) tuple, namedtuple or list of tensors that is passed to both 3237 `cond` and `body`. `cond` and `body` both take as many arguments as there are 3238 `loop_vars`. 3239 3240 In addition to regular Tensors or IndexedSlices, the body may accept and 3241 return TensorArray objects. The flows of the TensorArray objects will 3242 be appropriately forwarded between loops and during gradient calculations. 3243 3244 Note that `while_loop` calls `cond` and `body` *exactly once* (inside the 3245 call to `while_loop`, and not at all during `Session.run()`). `while_loop` 3246 stitches together the graph fragments created during the `cond` and `body` 3247 calls with some additional graph nodes to create the graph flow that 3248 repeats `body` until `cond` returns false. 3249 3250 For correctness, `tf.while_loop()` strictly enforces shape invariants for 3251 the loop variables. A shape invariant is a (possibly partial) shape that 3252 is unchanged across the iterations of the loop. An error will be raised 3253 if the shape of a loop variable after an iteration is determined to be more 3254 general than or incompatible with its shape invariant. For example, a shape 3255 of [11, None] is more general than a shape of [11, 17], and [11, 21] is not 3256 compatible with [11, 17]. By default (if the argument `shape_invariants` is 3257 not specified), it is assumed that the initial shape of each tensor in 3258 `loop_vars` is the same in every iteration. The `shape_invariants` argument 3259 allows the caller to specify a less specific shape invariant for each loop 3260 variable, which is needed if the shape varies between iterations. The 3261 `tf.Tensor.set_shape` 3262 function may also be used in the `body` function to indicate that 3263 the output loop variable has a particular shape. The shape invariant for 3264 SparseTensor and IndexedSlices are treated specially as follows: 3265 3266 a) If a loop variable is a SparseTensor, the shape invariant must be 3267 TensorShape([r]) where r is the rank of the dense tensor represented 3268 by the sparse tensor. It means the shapes of the three tensors of the 3269 SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here 3270 is the shape of the SparseTensor.dense_shape property. It must be the shape of 3271 a vector. 
3272 3273 b) If a loop variable is an IndexedSlices, the shape invariant must be 3274 a shape invariant of the values tensor of the IndexedSlices. It means 3275 the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]], 3276 [shape.ndims]). 3277 3278 `while_loop` implements non-strict semantics, enabling multiple iterations 3279 to run in parallel. The maximum number of parallel iterations can be 3280 controlled by `parallel_iterations`, which gives users some control over 3281 memory consumption and execution order. For correct programs, `while_loop` 3282 should return the same result for any parallel_iterations > 0. 3283 3284 For training, TensorFlow stores the tensors that are produced in the 3285 forward inference and are needed in back propagation. These tensors are a 3286 main source of memory consumption and often cause OOM errors when training 3287 on GPUs. When the flag swap_memory is true, we swap out these tensors from 3288 GPU to CPU. This for example allows us to train RNN models with very long 3289 sequences and large batches. 3290 3291 Args: 3292 cond: A callable that represents the termination condition of the loop. 3293 body: A callable that represents the loop body. 3294 loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array, 3295 `Tensor`, and `TensorArray` objects. 3296 shape_invariants: The shape invariants for the loop variables. 3297 parallel_iterations: The number of iterations allowed to run in parallel. It 3298 must be a positive integer. 3299 back_prop: Whether backprop is enabled for this while loop. 3300 swap_memory: Whether GPU-CPU memory swap is enabled for this loop. 3301 name: Optional name prefix for the returned tensors. 3302 maximum_iterations: Optional maximum number of iterations of the while loop 3303 to run. If provided, the `cond` output is AND-ed with an additional 3304 condition ensuring the number of iterations executed is no greater than 3305 `maximum_iterations`. 3306 return_same_structure: If True, output has same structure as `loop_vars`. If 3307 eager execution is enabled, this is ignored (and always treated as True). 3308 3309 Returns: 3310 The output tensors for the loop variables after the loop. 3311 If `return_same_structure` is True, the return value has the same 3312 structure as `loop_vars`. 3313 If `return_same_structure` is False, the return value is a Tensor, 3314 TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list 3315 otherwise. 3316 3317 Raises: 3318 TypeError: if `cond` or `body` is not callable. 3319 ValueError: if `loop_vars` is empty. 
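
  The `maximum_iterations` argument bounds the number of iterations
  independently of `cond`; a minimal sketch (the loop below stops after 5
  iterations even though `cond` alone would allow 10):

  ```python
  i = tf.constant(0)
  r = tf.while_loop(lambda i: i < 10, lambda i: i + 1, [i],
                    maximum_iterations=5)
  # After running the loop, `r` evaluates to 5.
  ```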
3320 3321 Example: 3322 3323 ```python 3324 i = tf.constant(0) 3325 c = lambda i: tf.less(i, 10) 3326 b = lambda i: tf.add(i, 1) 3327 r = tf.while_loop(c, b, [i]) 3328 ``` 3329 3330 Example with nesting and a namedtuple: 3331 3332 ```python 3333 import collections 3334 Pair = collections.namedtuple('Pair', 'j, k') 3335 ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2))) 3336 c = lambda i, p: i < 10 3337 b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k))) 3338 ijk_final = tf.while_loop(c, b, ijk_0) 3339 ``` 3340 3341 Example using shape_invariants: 3342 3343 ```python 3344 i0 = tf.constant(0) 3345 m0 = tf.ones([2, 2]) 3346 c = lambda i, m: i < 10 3347 b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] 3348 tf.while_loop( 3349 c, b, loop_vars=[i0, m0], 3350 shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) 3351 ``` 3352 3353 Example which demonstrates non-strict semantics: In the following 3354 example, the final value of the counter `i` does not depend on `x`. So 3355 the `while_loop` can increment the counter parallel to updates of `x`. 3356 However, because the loop counter at one loop iteration depends 3357 on the value at the previous iteration, the loop counter itself cannot 3358 be incremented in parallel. Hence if we just want the final value of the 3359 counter (which we print on the line `print(sess.run(i))`), then 3360 `x` will never be incremented, but the counter will be updated on a 3361 single thread. Conversely, if we want the value of the output (which we 3362 print on the line `print(sess.run(out).shape)`), then the counter may be 3363 incremented on its own thread, while `x` can be incremented in 3364 parallel on a separate thread. In the extreme case, it is conceivable 3365 that the thread incrementing the counter runs until completion before 3366 `x` is incremented even a single time. The only thing that can never 3367 happen is that the thread updating `x` can never get ahead of the 3368 counter thread because the thread incrementing `x` depends on the value 3369 of the counter. 3370 3371 ```python 3372 import tensorflow as tf 3373 3374 n = 10000 3375 x = tf.constant(list(range(n))) 3376 c = lambda i, x: i < n 3377 b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:")) 3378 i, out = tf.while_loop(c, b, (0, x)) 3379 with tf.Session() as sess: 3380 print(sess.run(i)) # prints [0] ... [9999] 3381 3382 # The following line may increment the counter and x in parallel. 3383 # The counter thread may get ahead of the other thread, but not the 3384 # other way around. So you may see things like 3385 # [9996] x:[9987] 3386 # meaning that the counter thread is on iteration 9996, 3387 # while the other thread is on iteration 9987 3388 print(sess.run(out).shape) 3389 ``` 3390 3391 """ 3392 # Always enable control flow v2 if building a function, regardless of toggle. 
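  # (For example, while building a `tf.function`/`defun` body the default
  # graph is a function graph, so the check below dispatches to `while_v2`
  # regardless of the v2 toggle; in an ordinary graph the toggle decides.)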
3393 if (util.EnableControlFlowV2(ops.get_default_graph()) and 3394 not context.executing_eagerly()): 3395 return while_v2.while_loop( 3396 cond, 3397 body, 3398 loop_vars, 3399 shape_invariants=shape_invariants, 3400 parallel_iterations=parallel_iterations, 3401 maximum_iterations=maximum_iterations, 3402 name=name, 3403 return_same_structure=return_same_structure) 3404 3405 with ops.name_scope(name, "while", loop_vars): 3406 if not loop_vars: 3407 raise ValueError("No loop variables provided") 3408 if not callable(cond): 3409 raise TypeError("cond must be callable.") 3410 if not callable(body): 3411 raise TypeError("body must be callable.") 3412 if parallel_iterations < 1: 3413 raise TypeError("parallel_iterations must be a positive integer.") 3414 3415 if maximum_iterations is not None: 3416 maximum_iterations = ops.convert_to_tensor( 3417 maximum_iterations, name="maximum_iterations") 3418 if maximum_iterations.shape.ndims != 0: 3419 raise ValueError("maximum_iterations must be a scalar, saw shape: %s" % 3420 maximum_iterations.shape) 3421 3422 counter = constant_op.constant( 3423 0, dtype=maximum_iterations.dtype, name="iteration_counter") 3424 orig_cond = cond 3425 orig_body = body 3426 if len(loop_vars) == 1: 3427 loop_vars = (counter, loop_vars[0]) 3428 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 3429 math_ops.logical_and(i < maximum_iterations, orig_cond(lv))) 3430 body = lambda i, lv: (i + 1, orig_body(lv)) 3431 else: 3432 loop_vars = (counter, loop_vars) 3433 cond = lambda i, lv: ( # pylint: disable=g-long-lambda 3434 math_ops.logical_and(i < maximum_iterations, orig_cond(*lv))) 3435 body = lambda i, lv: (i + 1, orig_body(*lv)) 3436 3437 if context.executing_eagerly(): 3438 try_to_pack = len(loop_vars) == 1 3439 packed = False # whether the body result was packed into a 1-item tuple 3440 3441 while cond(*loop_vars): 3442 loop_vars = body(*loop_vars) 3443 if try_to_pack and not isinstance(loop_vars, (list, _basetuple)): 3444 packed = True 3445 loop_vars = (loop_vars,) 3446 3447 def convert(x): 3448 if isinstance(x, tensor_array_ops.TensorArray): 3449 return x 3450 return ops.convert_to_tensor(x) 3451 loop_vars = nest.map_structure(convert, loop_vars) 3452 if maximum_iterations is not None: 3453 return loop_vars[1] 3454 else: 3455 return loop_vars[0] if packed else loop_vars 3456 3457 if shape_invariants is not None: 3458 if maximum_iterations is not None: 3459 shape_invariants = (tensor_shape.TensorShape([]), shape_invariants) 3460 3461 nest.assert_same_structure(loop_vars, shape_invariants, 3462 expand_composites=False) 3463 shape_invariants = nest.map_structure( 3464 _get_shape_invariant, loop_vars, shape_invariants, 3465 expand_composites=False) 3466 3467 loop_context = WhileContext( 3468 maximum_iterations=maximum_iterations, 3469 parallel_iterations=parallel_iterations, 3470 back_prop=back_prop, 3471 swap_memory=swap_memory) 3472 # Only add non-nested loops to the collection. Any nested control flow will 3473 # be encapsulated in the root context. 3474 if loop_context.outer_context is None: 3475 ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) 3476 result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, 3477 return_same_structure) 3478 if maximum_iterations is not None: 3479 return result[1] 3480 else: 3481 return result 3482 3483 3484# pylint: enable=redefined-outer-name 3485 3486 3487def _AsTensorList(x, p): 3488 """Return x as a list of Tensors or IndexedSlices. 

  For entries of `x` that are Operations, this returns an Identity of `p`
  with a dependency on the operation.

  Args:
    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
    p: A Tensor to return for entries in `x` that are Operations.

  Returns:
    A list of Tensors or IndexedSlices.
  """
  if not isinstance(x, (list, _basetuple)):
    x = [x]

  l = []
  for v in x:
    if isinstance(v, ops.Operation):
      v = with_dependencies([v], p)
    v = ops.convert_to_tensor_or_composite(v)
    if isinstance(v, ops.Tensor):
      l.append(array_ops.identity(v))
    else:
      l.append(
          ops.IndexedSlices(
              array_ops.identity(v.values), array_ops.identity(v.indices)))
  return l


def _CheckResults(a, b):
  assert len(a) == len(b), (
      "Values returned by a() and b() must have the same length.")
  for x, y in zip(a, b):
    assert x.dtype == y.dtype, (
        "Values returned by a() [%s] and b() [%s] must have "
        "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))


def with_dependencies(dependencies, output_tensor, name=None):
  """Produces the content of `output_tensor` only after `dependencies`.

  In some cases, a user may want the output of an operation to be
  consumed externally only after some other dependencies have run first.
  This function returns `output_tensor`, but only after all operations in
  `dependencies` have run. Note that this only constrains the returned value:
  the ops that compute `output_tensor` itself are not guaranteed to run after
  `dependencies`.

  See also `tf.tuple` and `tf.group`.

  Args:
    dependencies: Iterable of operations to run before this op finishes.
    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
    name: (Optional) A name for this operation.

  Returns:
    Same as `output_tensor`.

  Raises:
    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
  """
  if context.executing_eagerly():
    return output_tensor
  with ops.name_scope(name, "control_dependency",
                      list(dependencies) + [output_tensor]) as name:
    with ops.colocate_with(output_tensor):
      with ops.control_dependencies(dependencies):
        output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
        if isinstance(output_tensor, ops.Tensor):
          return _Identity(output_tensor, name=name)
        else:
          return ops.IndexedSlices(
              _Identity(output_tensor.values, name=name), output_tensor.indices,
              output_tensor.dense_shape)


def _GroupControlDeps(dev, deps, name=None):
  with ops.control_dependencies(deps):
    if dev is None:
      return no_op(name=name)
    else:
      with ops.device(dev):
        return no_op(name=name)


# TODO(touts): Accept "inputs" as a list.
@tf_export("group")
def group(*inputs, **kwargs):
  """Create an op that groups multiple operations.

  When this op finishes, all ops in `inputs` have finished. This op has no
  output.

  See also `tf.tuple` and `tf.control_dependencies`.

  Args:
    *inputs: Zero or more tensors to group.
    name: A name for this operation (optional).

  Returns:
    An Operation that executes all its inputs.

  Raises:
    ValueError: If an unknown keyword argument is provided.
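
  For example, a minimal sketch (TF1 graph mode; `v1`, `v2` and `sess` are
  assumed to be existing variables and an active `tf.Session`):

  ```python
  update_v1 = tf.assign(v1, v1 + 1.0)
  update_v2 = tf.assign(v2, v2 + 1.0)
  update_all = tf.group(update_v1, update_v2)
  # Running `update_all` runs both assignments; the grouped op has no output.
  sess.run(update_all)
  ```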
3593 """ 3594 if context.executing_eagerly(): 3595 return None 3596 name = kwargs.pop("name", None) 3597 if kwargs: 3598 raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys())) 3599 with ops.name_scope(name, "group_deps", inputs) as name: 3600 # Grouping no inputs means do nothing 3601 if not inputs: 3602 return no_op(name=name) 3603 3604 # Sorts *inputs according to their devices. 3605 ops_on_device = {} # device -> operations specified on the device. 3606 for inp in nest.flatten(inputs, expand_composites=True): 3607 if not hasattr(inp, "device"): 3608 raise TypeError("Expected tf.group() expected Tensor arguments not " 3609 "'%s' with type '%s'" % (inp, type(inp))) 3610 dev = inp.device 3611 if dev in ops_on_device: 3612 ops_on_device[dev].append(inp) 3613 else: 3614 ops_on_device[dev] = [inp] 3615 if len(ops_on_device) == 1: 3616 # 1-level tree. The root node is the returned NoOp node. 3617 (dev, deps), = ops_on_device.items() 3618 return _GroupControlDeps(dev, deps, name=name) 3619 3620 # 2-level tree. The root node is the returned NoOp node. 3621 # deps contains 1 NoOp node for each device. 3622 deps = [] 3623 3624 def device_key(dev): 3625 """A sort key that allows None to be compared to strings.""" 3626 return "" if dev is None else dev 3627 3628 for dev in sorted(ops_on_device, key=device_key): 3629 deps.append(_GroupControlDeps(dev, ops_on_device[dev])) 3630 3631 with ops.control_dependencies(deps): 3632 return no_op(name=name) 3633 3634 3635@tf_export("tuple", v1=[]) 3636def tuple_v2(tensors, control_inputs=None, name=None): 3637 """Group tensors together. 3638 3639 This creates a tuple of tensors with the same values as the `tensors` 3640 argument, except that the value of each tensor is only returned after the 3641 values of all tensors have been computed. 3642 3643 `control_inputs` contains additional ops that have to finish before this op 3644 finishes, but whose outputs are not returned. 3645 3646 This can be used as a "join" mechanism for parallel computations: all the 3647 argument tensors can be computed in parallel, but the values of any tensor 3648 returned by `tuple` are only available after all the parallel computations 3649 are done. 3650 3651 See also `tf.group` and 3652 `tf.control_dependencies`. 3653 3654 Args: 3655 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. 3656 control_inputs: List of additional ops to finish before returning. 3657 name: (optional) A name to use as a `name_scope` for the operation. 3658 3659 Returns: 3660 Same as `tensors`. 3661 3662 Raises: 3663 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 3664 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 3665 objects. 3666 3667 """ 3668 return tuple(tensors=tensors, name=name, control_inputs=control_inputs) # pylint: disable=redefined-builtin 3669 3670 3671@tf_export(v1=["tuple"]) 3672def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined-builtin 3673 """Group tensors together. 3674 3675 This creates a tuple of tensors with the same values as the `tensors` 3676 argument, except that the value of each tensor is only returned after the 3677 values of all tensors have been computed. 3678 3679 `control_inputs` contains additional ops that have to finish before this op 3680 finishes, but whose outputs are not returned. 
3681 3682 This can be used as a "join" mechanism for parallel computations: all the 3683 argument tensors can be computed in parallel, but the values of any tensor 3684 returned by `tuple` are only available after all the parallel computations 3685 are done. 3686 3687 See also `tf.group` and 3688 `tf.control_dependencies`. 3689 3690 Args: 3691 tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`. 3692 name: (optional) A name to use as a `name_scope` for the operation. 3693 control_inputs: List of additional ops to finish before returning. 3694 3695 Returns: 3696 Same as `tensors`. 3697 3698 Raises: 3699 ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`. 3700 TypeError: If `control_inputs` is not a list of `Operation` or `Tensor` 3701 objects. 3702 3703 """ 3704 if context.executing_eagerly(): 3705 return tensors 3706 with ops.name_scope(name, "tuple", tensors) as name: 3707 tensors = [ 3708 t if (isinstance(t, ops.Operation) or tensor_util.is_tensor(t) or 3709 t is None) else ops.convert_to_tensor(t) for t in tensors 3710 ] 3711 gating_ops = [ 3712 t if isinstance(t, ops.Operation) else t.op 3713 for t in tensors 3714 if t is not None 3715 ] 3716 if control_inputs: 3717 for c in control_inputs: 3718 if isinstance(c, ops.Tensor): 3719 c = c.op 3720 elif not isinstance(c, ops.Operation): 3721 raise TypeError("Control input must be Operation or Tensor: %s" % c) 3722 gating_ops.append(c) 3723 # Note that in order to ensure ordering in the pbtxt, we must take care to 3724 # ensure the order here. 3725 gating_ops = sorted(set(gating_ops), key=lambda op: op._id) # Uniquify ops. 3726 if not gating_ops: 3727 raise ValueError("Must have at least one Tensor: %s" % tensors) 3728 gate = group(*gating_ops) 3729 tpl = [] 3730 for t in tensors: 3731 if tensor_util.is_tensor(t): 3732 tpl.append(with_dependencies([gate], t)) 3733 elif isinstance(t, ops.Operation): 3734 with ops.control_dependencies([gate]): 3735 tpl.append(group(t)) 3736 else: 3737 tpl.append(None) 3738 return tpl 3739 3740 3741def _assert_at_most_n_true(predicates, n, msg): 3742 """Returns an Assert op that checks that at most n predicates are True. 3743 3744 Args: 3745 predicates: list of bool scalar tensors. 3746 n: maximum number of true predicates allowed. 3747 msg: Error message. 3748 """ 3749 preds_c = array_ops.stack(predicates, name="preds_c") 3750 num_true_conditions = math_ops.reduce_sum( 3751 math_ops.cast(preds_c, dtypes.int32), name="num_true_conds") 3752 condition = math_ops.less_equal(num_true_conditions, 3753 constant_op.constant(n, name="n_true_conds")) 3754 preds_names = ", ".join(getattr(p, "name", "?") for p in predicates) 3755 error_msg = [ 3756 "%s: more than %d conditions (%s) evaluated as True:" % 3757 (msg, n, preds_names), preds_c 3758 ] 3759 return Assert(condition, data=error_msg, summarize=len(predicates)) 3760 3761 3762def _case_create_default_action(predicates, actions): 3763 """Creates default action for a list of actions and their predicates. 3764 3765 It uses the input actions to select an arbitrary as default and makes sure 3766 that corresponding predicates have valid values. 3767 3768 Args: 3769 predicates: a list of bool scalar tensors 3770 actions: a list of callable objects which return tensors. 
3771 3772 Returns: 3773 a callable 3774 """ 3775 k = len(predicates) - 1 # could pick any 3776 predicate, action = predicates[k], actions[k] 3777 other_predicates, other_actions = predicates[:k], actions[:k] 3778 3779 def default_action(): 3780 others_msg = ("Implementation error: " 3781 "selected default action #%d was called, but some of other " 3782 "predicates are True: " % k) 3783 default_msg = ("Input error: " 3784 "None of conditions evaluated as True:", 3785 array_ops.stack(predicates, name="preds_c")) 3786 with ops.control_dependencies([ 3787 _assert_at_most_n_true(other_predicates, n=0, msg=others_msg), 3788 Assert(predicate, data=default_msg) 3789 ]): 3790 return action() 3791 3792 return default_action, other_predicates, other_actions 3793 3794 3795def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name, 3796 allow_python_preds): 3797 """Verifies input arguments for the case function. 3798 3799 Args: 3800 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a 3801 callable which returns a list of tensors. 3802 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3803 name: A name for the case operation. 3804 allow_python_preds: if true, pred_fn_pairs may contain Python bools in 3805 addition to boolean Tensors 3806 3807 Raises: 3808 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3809 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3810 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3811 callable. 3812 3813 Returns: 3814 a tuple <list of scalar bool tensors, list of callables>. 3815 """ 3816 if not isinstance(pred_fn_pairs, (list, _basetuple, dict)): 3817 raise TypeError("fns must be a list, tuple, or dict") 3818 3819 if isinstance(pred_fn_pairs, collections.OrderedDict): 3820 pred_fn_pairs = pred_fn_pairs.items() 3821 elif isinstance(pred_fn_pairs, dict): 3822 if context.executing_eagerly(): 3823 # No name to sort on in eager mode. Use dictionary traversal order, 3824 # which is nondeterministic in versions of Python < 3.6 3825 if not exclusive: 3826 raise ValueError("Unordered dictionaries are not supported for the " 3827 "`pred_fn_pairs` argument when `exclusive=False` and " 3828 "eager mode is enabled.") 3829 pred_fn_pairs = list(pred_fn_pairs.items()) 3830 else: 3831 pred_fn_pairs = sorted( 3832 pred_fn_pairs.items(), key=lambda item: item[0].name) 3833 if not exclusive: 3834 logging.warn( 3835 "%s: An unordered dictionary of predicate/fn pairs was " 3836 "provided, but exclusive=False. The order of conditional " 3837 "tests is deterministic but not guaranteed.", name) 3838 for pred_fn_pair in pred_fn_pairs: 3839 if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2: 3840 raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple") 3841 pred, fn = pred_fn_pair 3842 3843 if isinstance(pred, ops.Tensor): 3844 if pred.dtype != dtypes.bool: 3845 raise TypeError("pred must be Tensor of type bool: %s" % pred.name) 3846 elif not allow_python_preds: 3847 raise TypeError("pred must be a Tensor, got: %s" % pred) 3848 elif not isinstance(pred, bool): 3849 raise TypeError("pred must be a Tensor or bool, got: %s" % pred) 3850 3851 if not callable(fn): 3852 raise TypeError("fn for pred %s must be callable." 
% pred.name) 3853 3854 predicates, actions = zip(*pred_fn_pairs) 3855 return predicates, actions 3856 3857 3858def _case_helper(cond_fn, 3859 pred_fn_pairs, 3860 default, 3861 exclusive, 3862 name, 3863 allow_python_preds=False, 3864 **cond_kwargs): 3865 """Implementation of case that allows for different cond functions. 3866 3867 Args: 3868 cond_fn: method that has signature and semantics of `cond` above. 3869 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a 3870 callable which returns a list of tensors. 3871 default: Optional callable that returns a list of tensors. 3872 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3873 name: A name for this operation (optional). 3874 allow_python_preds: if true, pred_fn_pairs may contain Python bools in 3875 addition to boolean Tensors 3876 **cond_kwargs: keyword arguments that will be passed to `cond_fn`. 3877 3878 Returns: 3879 The tensors returned by the first pair whose predicate evaluated to True, or 3880 those returned by `default` if none does. 3881 3882 Raises: 3883 TypeError: If `pred_fn_pairs` is not a list/dictionary. 3884 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 3885 TypeError: If `fns[i]` is not callable for any i, or `default` is not 3886 callable. 3887 """ 3888 predicates, actions = _case_verify_and_canonicalize_args( 3889 pred_fn_pairs, exclusive, name, allow_python_preds) 3890 with ops.name_scope(name, "case", [predicates]): 3891 if default is None: 3892 default, predicates, actions = _case_create_default_action( 3893 predicates, actions) 3894 fn = default 3895 # To eval conditions in direct order we create nested conditions in reverse: 3896 # cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...)) 3897 for predicate, action in reversed(list(zip(predicates, actions))): 3898 fn = functools.partial( 3899 cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs) 3900 if exclusive: 3901 with ops.control_dependencies([ 3902 _assert_at_most_n_true( 3903 predicates, n=1, msg="Input error: exclusive=True") 3904 ]): 3905 return fn() 3906 else: 3907 return fn() 3908 3909 3910@tf_export("case") 3911def case(pred_fn_pairs, 3912 default=None, 3913 exclusive=False, 3914 strict=False, 3915 name="case"): 3916 """Create a case operation. 3917 3918 The `pred_fn_pairs` parameter is a dict or list of pairs of size N. 3919 Each pair contains a boolean scalar tensor and a python callable that 3920 creates the tensors to be returned if the boolean evaluates to True. 3921 `default` is a callable generating a list of tensors. All the callables 3922 in `pred_fn_pairs` as well as `default` (if provided) should return the same 3923 number and types of tensors. 3924 3925 If `exclusive==True`, all predicates are evaluated, and an exception is 3926 thrown if more than one of the predicates evaluates to `True`. 3927 If `exclusive==False`, execution stops at the first predicate which 3928 evaluates to True, and the tensors generated by the corresponding function 3929 are returned immediately. If none of the predicates evaluate to True, this 3930 operation returns the tensors generated by `default`. 3931 3932 `tf.case` supports nested structures as implemented in 3933 `tf.contrib.framework.nest`. All of the callables must return the same 3934 (possibly nested) value structure of lists, tuples, and/or named tuples. 3935 Singleton lists and tuples form the only exceptions to this: when returned by 3936 a callable, they are implicitly unpacked to single values. 
This 3937 behavior is disabled by passing `strict=True`. 3938 3939 If an unordered dictionary is used for `pred_fn_pairs`, the order of the 3940 conditional tests is not guaranteed. However, the order is guaranteed to be 3941 deterministic, so that variables created in conditional branches are created 3942 in fixed order across runs. 3943 3944 @compatibility{eager} 3945 Unordered dictionaries are not supported in eager mode when `exclusive=False`. 3946 Use a list of tuples instead. 3947 @end_compatibility 3948 3949 3950 **Example 1:** 3951 3952 Pseudocode: 3953 3954 ``` 3955 if (x < y) return 17; 3956 else return 23; 3957 ``` 3958 3959 Expressions: 3960 3961 ```python 3962 f1 = lambda: tf.constant(17) 3963 f2 = lambda: tf.constant(23) 3964 r = tf.case([(tf.less(x, y), f1)], default=f2) 3965 ``` 3966 3967 **Example 2:** 3968 3969 Pseudocode: 3970 3971 ``` 3972 if (x < y && x > z) raise OpError("Only one predicate may evaluate to True"); 3973 if (x < y) return 17; 3974 else if (x > z) return 23; 3975 else return -1; 3976 ``` 3977 3978 Expressions: 3979 3980 ```python 3981 def f1(): return tf.constant(17) 3982 def f2(): return tf.constant(23) 3983 def f3(): return tf.constant(-1) 3984 r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2}, 3985 default=f3, exclusive=True) 3986 ``` 3987 3988 Args: 3989 pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a 3990 callable which returns a list of tensors. 3991 default: Optional callable that returns a list of tensors. 3992 exclusive: True iff at most one predicate is allowed to evaluate to `True`. 3993 strict: A boolean that enables/disables 'strict' mode; see above. 3994 name: A name for this operation (optional). 3995 3996 Returns: 3997 The tensors returned by the first pair whose predicate evaluated to True, or 3998 those returned by `default` if none does. 3999 4000 Raises: 4001 TypeError: If `pred_fn_pairs` is not a list/dictionary. 4002 TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples. 4003 TypeError: If `fns[i]` is not callable for any i, or `default` is not 4004 callable. 4005 """ 4006 return _case_helper( 4007 cond, 4008 pred_fn_pairs, 4009 default, 4010 exclusive, 4011 name, 4012 allow_python_preds=False, 4013 strict=strict) 4014 4015 4016class XLAControlFlowContext(ControlFlowContext): 4017 """Base class for XLA and TPU control flow contexts.""" 4018 4019 def __init__(self): 4020 super(XLAControlFlowContext, self).__init__() 4021 self._name = "XLAControlFlowContext" 4022 4023 def to_control_flow_context_def(self, context_def, export_scope=None): 4024 # pylint: disable=useless-super-delegation 4025 # NOTE(slebedev): the method is required by `ControlFlowContext`. 4026 super(XLAControlFlowContext, self).to_control_flow_context_def( 4027 context_def, export_scope) 4028 4029 def IsXLAContext(self): 4030 return True 4031 4032 def AddOp(self, _): 4033 pass 4034 4035 def AddValue(self, x): 4036 return x 4037 4038 4039def from_control_flow_context_def(context_def, import_scope=None): 4040 """Deserializes `context_def` into the appropriate ControlFlowContext. 4041 4042 Args: 4043 context_def: ControlFlowContextDef proto 4044 import_scope: Optional `string`. Name scope to add. 
4045 4046 Returns: 4047 A ControlFlowContext subclass 4048 """ 4049 if context_def.HasField("cond_ctxt"): 4050 return CondContext.from_proto( 4051 context_def.cond_ctxt, import_scope=import_scope) 4052 if context_def.HasField("while_ctxt"): 4053 return WhileContext.from_proto( 4054 context_def.while_ctxt, import_scope=import_scope) 4055 raise NotImplementedError("Unknown ControlFlowContextDef field: %s" % 4056 context_def.WhichOneof("ctxt")) 4057 4058 4059ops.register_proto_function( 4060 ops.GraphKeys.COND_CONTEXT, 4061 proto_type=control_flow_pb2.CondContextDef, 4062 to_proto=CondContext.to_proto, 4063 from_proto=CondContext.from_proto) 4064 4065ops.register_proto_function( 4066 ops.GraphKeys.WHILE_CONTEXT, 4067 proto_type=control_flow_pb2.WhileContextDef, 4068 to_proto=WhileContext.to_proto, 4069 from_proto=WhileContext.from_proto) 4070
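

# For example (a sketch only; paths are illustrative): because the proto
# functions above are registered for the COND_CONTEXT and WHILE_CONTEXT
# collections, v1 control flow contexts survive a MetaGraphDef round trip,
# e.g.:
#
#   tf.train.export_meta_graph("/tmp/model.meta")
#   ...
#   tf.train.import_meta_graph("/tmp/model.meta")
#
# during which the registered `from_proto` functions rebuild the CondContext
# and WhileContext objects stored in those collections.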