1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for managing state of v1 control flow for computing gradients.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.framework import constant_op 22from tensorflow.python.framework import dtypes 23from tensorflow.python.framework import ops 24from tensorflow.python.framework import tensor_util 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import control_flow_ops 27from tensorflow.python.ops import control_flow_util as util 28from tensorflow.python.ops import control_flow_v2_func_graphs 29from tensorflow.python.ops import default_gradient 30from tensorflow.python.ops import gen_data_flow_ops 31from tensorflow.python.ops import gen_resource_variable_ops 32from tensorflow.python.ops import resource_variable_ops 33 34# pylint: disable=protected-access 35 36 37def _GetMaxSizeFromNestedMaximumIterations(value, while_ctxt): 38 """Calculate a max_size for use by stack ops inside an XLA while_loop. 39 40 Args: 41 value: The value inside the while_loop forward context. Used for printing 42 error messages. 43 while_ctxt: The forward context inside which value resides. This does not 44 always match the value's immediate context, as `value` may be inside e.g. 45 a cond context inside the while_loop. 46 47 Returns: 48 A tensor containing the `max_size` to feed to a Stack initializer. 49 50 Raises: 51 ValueError: If `value` is nested inside a `while_loop` that either 52 lacks a `maximum_iterations` parameter, or the `maximum_iterations` 53 parameter: 54 55 - is inside a `while_loop` that is a parent of the calling context, and 56 - cannot be evaluated at graph build time to a constant. 57 """ 58 value_name = value.name 59 # curr_ctxt is the context that tf.gradients was called in. 60 curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access 61 62 curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else "" 63 max_size = constant_op.constant(1) 64 65 # Loop through all containing while contexts between value and the 66 # current context, multiplying together each context's 67 # max_iterations to get the maximum stack size. 68 while while_ctxt not in (None, curr_ctxt): 69 max_iter = while_ctxt.maximum_iterations 70 if max_iter is None: 71 raise ValueError( 72 "Cannot create a gradient accumulator for tensor '%s' inside " 73 "XLA while_loop because maximum_iterations was not passed to " 74 "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name)) 75 76 # pylint: disable=protected-access 77 max_iter_ctxt = max_iter.op._get_control_flow_context() 78 # pylint: enable=protected-access 79 80 # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use. 81 if util.IsContainingContext(curr_ctxt, max_iter_ctxt): 82 max_size *= max_iter 83 else: 84 # We cannot use max_iter because it's defined in a nested while 85 # or cond context, so will fail if we try to use it as input to 86 # any ops in curr_ctxt (e.g. max_size or the final accumulator 87 # stack). Attempt to get a constant value out to use instead. 88 const_max_iter = tensor_util.constant_value(max_iter) 89 if const_max_iter is None: 90 raise ValueError( 91 "Cannot create a gradient accumulator for tensor '%s' inside XLA " 92 "while_loop. maximum_iterations tensor '%s' for while_loop context " 93 "'%s' must be statically known (e.g. a constant value or known " 94 "shape dimension), or be defined at or outside the while loop " 95 "context '%s' (currently defined in '%s')." % 96 (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name, 97 max_iter_ctxt.name)) 98 max_size *= const_max_iter 99 100 # Find the next outer WhileContext (or stop if we reach the 101 # tf.gradient's context). 102 while_ctxt = util.GetContainingWhileContext( 103 while_ctxt.outer_context, stop_ctxt=curr_ctxt) 104 105 return max_size 106 107 108class _GradLoopState(object): 109 """The state used for constructing the gradient graph for a while loop. 110 111 We create a _GradLoopState for each while loop in forward and its 112 corresponding while loop in backprop. This gives us access to both 113 the forward and the backprop WhileContexts. 114 115 During the construction of gradient graph, any time when we detect 116 a forward value that is needed for backprop, we create a history 117 accumulator and add it to `history_map`. Any time when we backprop 118 a loop switch op (in _SwitchGrad), we add the grad merge op in 119 `switch_map`. 120 """ 121 122 def __init__(self, forward_ctxt, outer_grad_state): 123 # The grad loop state for the outer while loop. 124 self._outer_grad_state = None 125 126 # The while loop context for forward. 127 self._forward_context = None 128 129 # The loop counter added by AddForwardLoopCounter. It is the value 130 # of the loop counter for the next iteration. 131 self._forward_index = None 132 133 # A sync op for forward. 134 self._forward_sync = None 135 136 # The while loop context for backprop. 137 self._grad_context = None 138 139 # The loop counter added by AddBackpropLoopCounter. It is the value 140 # of the loop counter for the current iteration. 141 self._grad_index = None 142 143 # A sync op for backprop. 144 self._grad_sync = None 145 146 # Information needed by backprop. 147 self._history_map = {} 148 self._switch_map = {} 149 self._unused_exits = [] 150 self._deferred_exits = [] 151 self._forward_loop_exits = list(forward_ctxt.loop_exits) 152 self._pending_exits_count = len(forward_ctxt.loop_exits) 153 154 self._outer_grad_state = outer_grad_state 155 if outer_grad_state: 156 outer_forward_ctxt = outer_grad_state.forward_context 157 else: 158 if not hasattr(forward_ctxt, "outer_context"): 159 raise ValueError("Failed to call gradients on a while loop without" 160 "properly serializing graph via MetaGraphDef") 161 outer_forward_ctxt = forward_ctxt.outer_context 162 163 # Add the forward loop counter. 164 with forward_ctxt._graph.as_default(): # pylint: disable=protected-access 165 if outer_forward_ctxt: 166 outer_forward_ctxt.Enter() 167 cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state) 168 if outer_forward_ctxt: 169 outer_forward_ctxt.Exit() 170 self._forward_context = forward_ctxt 171 self._forward_index = forward_index 172 173 # Add the backprop WhileContext, and the backprop loop counter. 174 if outer_grad_state: 175 # This is a nested loop. Remember the iteration counts for each 176 # execution of this inner loop. 177 outer_forward_ctxt.AddName(cnt.name) 178 history_cnt = outer_grad_state.AddForwardAccumulator(cnt) 179 180 outer_grad_ctxt = outer_grad_state.grad_context 181 outer_grad_ctxt.Enter() 182 self._grad_context = control_flow_ops.WhileContext( 183 maximum_iterations=forward_ctxt.maximum_iterations, 184 parallel_iterations=forward_ctxt.parallel_iterations, 185 back_prop=forward_ctxt.back_prop, 186 swap_memory=forward_ctxt.swap_memory, 187 name=forward_ctxt.name, 188 grad_state=self) 189 real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt) 190 self._grad_index = self._grad_context.AddBackpropLoopCounter( 191 real_cnt, outer_grad_state) 192 outer_grad_ctxt.Exit() 193 else: 194 if outer_forward_ctxt: 195 outer_forward_ctxt.Enter() 196 self._grad_context = control_flow_ops.WhileContext( 197 maximum_iterations=forward_ctxt.maximum_iterations, 198 parallel_iterations=forward_ctxt.parallel_iterations, 199 back_prop=forward_ctxt.back_prop, 200 swap_memory=forward_ctxt.swap_memory, 201 name=forward_ctxt.name, 202 grad_state=self) 203 self._grad_index = self._grad_context.AddBackpropLoopCounter( 204 cnt, outer_grad_state) 205 if outer_forward_ctxt: 206 outer_forward_ctxt.Exit() 207 208 @property 209 def outer_grad_state(self): 210 """The grad loop state for outer loop.""" 211 return self._outer_grad_state 212 213 @property 214 def forward_context(self): 215 """The while loop context for forward.""" 216 return self._forward_context 217 218 @property 219 def forward_index(self): 220 """The loop index of forward loop.""" 221 return self._forward_index 222 223 @property 224 def forward_sync(self): 225 """A control trigger node for synchronization in the forward loop. 226 227 One main use is to keep the push ops of a stack executed in the 228 iteration order. 229 """ 230 if self._forward_sync is None: 231 with ops.control_dependencies(None): 232 self._forward_sync = control_flow_ops.control_trigger(name="f_sync") 233 self._forward_sync._set_control_flow_context(self._forward_context) 234 self._forward_index.op._add_control_input(self._forward_sync) 235 return self._forward_sync 236 237 @property 238 def grad_context(self): 239 """The corresponding WhileContext for gradient.""" 240 return self._grad_context 241 242 @property 243 def grad_index(self): 244 """The loop index of backprop loop.""" 245 return self._grad_index 246 247 @property 248 def grad_sync(self): 249 """A control trigger node for synchronization in the grad loop. 250 251 One main use is to keep the pop ops of a stack executed in the 252 iteration order. 253 """ 254 if self._grad_sync is None: 255 with ops.control_dependencies(None): 256 self._grad_sync = control_flow_ops.control_trigger(name="b_sync") 257 self._grad_sync._set_control_flow_context(self._grad_context) 258 self._grad_index.op._add_control_input(self._grad_sync) 259 if self._grad_context.outer_context: 260 self._grad_context.outer_context.AddInnerOp(self._grad_sync) 261 return self._grad_sync 262 263 @property 264 def history_map(self): 265 """The map that records all the tensors needed for backprop.""" 266 return self._history_map 267 268 @property 269 def switch_map(self): 270 """The map that records all the Switch ops for the while loop.""" 271 return self._switch_map 272 273 @property 274 def unused_exits(self): 275 """The list of "unused" exits.""" 276 return self._unused_exits 277 278 @property 279 def deferred_exits(self): 280 """The list of "deferred" exits.""" 281 return self._deferred_exits 282 283 @property 284 def forward_loop_exits(self): 285 """The list of exits of the forward loop.""" 286 return self._forward_loop_exits 287 288 @property 289 def pending_exits_count(self): 290 """The number of exits we expect to see but haven't.""" 291 return self._pending_exits_count 292 293 @pending_exits_count.setter 294 def pending_exits_count(self, cnt): 295 """Set the pending count to cnt.""" 296 self._pending_exits_count = cnt 297 298 def AddForwardAccumulator(self, value, dead_branch=False): 299 """Add an accumulator for each forward tensor that is needed in backprop. 300 301 This is added to the forward loop at the first time when a tensor 302 in the forward loop is used by backprop gradient computation loop. 303 We create an accumulator that accumulates the value of tensor at each 304 iteration. Called in the control flow context where gradients() is called. 305 306 The pseudocode is: 307 ``` 308 acc = stack(); 309 while (_pivot) { 310 acc = stack_push(acc, value); 311 } 312 ``` 313 314 We make sure that the stack push op in one iteration is executed before 315 next iteration. This is achieved by adding a control edge from 316 `forward_index.op.inputs[0].op` to the push op, and another control 317 edge from the push op to either `forward_index.op` or `forward_sync`. 318 319 Args: 320 value: The source tensor in forward that is to be accumulated. 321 dead_branch: True iff the tensor is on a dead branch of a cond. 322 323 Returns: 324 The stack that contains the accumulated history of the tensor. 325 326 Raises: 327 TypeError: For internal errors involving the value condition context. 328 ValueError: If `value` is inside a XLA scope and a valid max size 329 for the stack can't be found. 330 """ 331 # curr_ctxt is the context that tf.gradients was called in. 332 with self._forward_index.graph.as_default(): 333 curr_ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access 334 with ops.control_dependencies(None): 335 if curr_ctxt: 336 curr_ctxt.Enter() 337 with ops.colocate_with(value): 338 # We only need to pass maximum_iterations to the stack if 339 # we're inside an XLA context. 340 if not util.IsInXLAContext(value.op): 341 max_size = constant_op.constant(-1, dtypes.int32) 342 else: 343 max_size = _GetMaxSizeFromNestedMaximumIterations( 344 value, self.forward_context) 345 acc = gen_data_flow_ops.stack_v2( 346 max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc") 347 if curr_ctxt: 348 curr_ctxt.Exit() 349 350 # Make acc available in the forward context. 351 enter_acc = self.forward_context.AddValue(acc) 352 353 # Add the stack_push op in the context of value.op. 354 swap_enabled = self.forward_context.swap_memory 355 value_ctxt = util.GetOutputContext(value.op) 356 if value_ctxt == self.forward_context: 357 # value is not nested in the forward context. 358 self.forward_context.Enter() 359 push = gen_data_flow_ops.stack_push_v2( 360 enter_acc, value, swap_memory=swap_enabled) 361 self.forward_context.Exit() 362 # Protect stack push and order it before forward_index. 363 self.forward_index.op._add_control_input(push.op) 364 else: 365 # value is in a cond context within the forward context. 366 if not isinstance(value_ctxt, control_flow_ops.CondContext): 367 raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt) 368 if dead_branch: 369 # The special case for creating a zero tensor for a dead 370 # branch of a switch. See _ControlFlowState.ZerosLikeV1WhileLoop(). 371 value_ctxt.outer_context.Enter() 372 push = gen_data_flow_ops.stack_push_v2( 373 enter_acc, value, swap_memory=swap_enabled) 374 value_ctxt.outer_context.Exit() 375 push.op._set_control_flow_context(value_ctxt) 376 else: 377 value_ctxt.Enter() 378 push = gen_data_flow_ops.stack_push_v2( 379 enter_acc, value, swap_memory=swap_enabled) 380 value_ctxt.Exit() 381 # Protect stack push and order it before forward_sync. 382 self.forward_sync._add_control_input(push.op) 383 # Order stack push after the successor of forward_index 384 add_op = self.forward_index.op.inputs[0].op 385 push.op._add_control_input(add_op) 386 return acc 387 388 def AddBackpropAccumulatedValue(self, history_value, value, 389 dead_branch=False): 390 """Add the getter for an accumulated value in the grad context. 391 392 This is added to the backprop loop. Called in the grad context to 393 get the value of an accumulated value. The stack pop op must be guarded 394 by the pred of the controlling cond. 395 396 Args: 397 history_value: The history (a stack) of a value. 398 value: The value that is pushed onto the stack. 399 dead_branch: True iff the tensor is on a dead branch of a cond. 400 401 Returns: 402 The current value (the top of the stack). 403 """ 404 history_ctxt = history_value.op._get_control_flow_context() 405 # Find the cond context that controls history_value if any. 406 cond_ctxt = None 407 value_ctxt = value.op._get_control_flow_context() 408 while value_ctxt and value_ctxt != history_ctxt: 409 if isinstance(value_ctxt, control_flow_ops.CondContext): 410 cond_ctxt = value_ctxt 411 break 412 value_ctxt = value_ctxt.outer_context 413 with ops.control_dependencies(None): 414 self.grad_context.Enter() 415 if cond_ctxt: 416 # Guard stack pop with a switch if it is controlled by a cond. 417 grad_state = self 418 pred = None 419 while pred is None and grad_state: 420 pred = grad_state.history_map.get(cond_ctxt.pred.name) 421 grad_state = grad_state.outer_grad_state 422 if pred is None: 423 pred = cond_ctxt.pred 424 branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch 425 history_value = control_flow_ops._SwitchRefOrTensor( 426 history_value, pred)[branch] 427 pop = gen_data_flow_ops.stack_pop_v2(history_value, 428 value.dtype.base_dtype) 429 pop.set_shape(value.get_shape()) 430 self.grad_context.Exit() 431 parallel_iterations = self.grad_context.parallel_iterations 432 if parallel_iterations > 1: 433 # All pops are ordered after pivot_for_body and before grad_sync. 434 self.grad_sync._add_control_input(pop.op) 435 return pop 436 437 def GetRealValue(self, value): 438 """Get the real value of `value`. 439 440 If backprop "uses" a value produced by forward inference, an accumulator 441 is added in the forward loop to accumulate its values. We use the 442 accumulated value. This method must be called in the grad loop context. 443 `value` must be in forward and needed for backprop. 444 445 Args: 446 value: A tensor to be captured. 447 448 Returns: 449 The same tensor obtained from the saved history. 450 """ 451 assert value.op.type not in ["Variable", "VariableV2"] 452 real_value = self._history_map.get(value.name) 453 if real_value is None: 454 cur_value = value 455 cur_grad_state = self 456 while True: 457 enter_op = util.GetLoopConstantEnter(cur_value) 458 if enter_op: 459 # Special case: cur_value comes from a constant Enter node. 460 cur_value = enter_op.inputs[0] 461 cur_grad_state = cur_grad_state.outer_grad_state 462 if cur_grad_state is None: 463 # We are now outside all nested loops for this gradient(), 464 # so `value` is a loop invariant and there is no need to 465 # save the history of value. Just make cur_value to enter 466 # the right control flow context. 467 real_value = self._grad_context.AddValue(cur_value) 468 break 469 elif constant_op.is_constant(cur_value): 470 # If the value to be forwarded is a constant, clone the constant in 471 # the gradient loop rather than using a stack. 472 # TODO(phawkins): consider hoisting the constant out of the loop 473 # instead. 474 real_value = constant_op.constant( 475 tensor_util.constant_value(cur_value), dtype=cur_value.dtype) 476 break 477 else: 478 # Record the history of this value in forward_ctxt. 479 self._grad_context.Exit() 480 history_value = cur_grad_state.AddForwardAccumulator(cur_value) 481 self._grad_context.Enter() 482 break 483 484 if real_value is None: 485 # Add the stack pop op in the grad context. 486 real_value = cur_grad_state.AddBackpropAccumulatedValue( 487 history_value, cur_value) 488 if cur_grad_state != self: 489 real_value = self._grad_context.AddValue(real_value) 490 self._history_map[value.name] = real_value 491 return real_value 492 493 494class _ControlFlowState(object): 495 """Maintain the mapping from the loops to their grad states.""" 496 497 def __init__(self): 498 self._map = {} # maps forward loop context to _GradLoopState 499 500 def GetGradState(self, op, before): 501 """Return the grad state for this op if it's in a forward loop context.""" 502 if before and util.IsLoopExit(op): 503 forward_ctxt = op._get_control_flow_context() # pylint: disable=protected-access 504 forward_ctxt = forward_ctxt.outer_context 505 if forward_ctxt: 506 forward_ctxt = forward_ctxt.GetWhileContext() 507 else: 508 forward_ctxt = util.GetWhileContext(op) 509 if forward_ctxt: 510 return self._map.get(forward_ctxt) 511 return None 512 513 def ProcessUnusedLoopExits(self, pending_count, to_ops_set): 514 """Process all the "unused" loop exits. 515 516 The "unused" exits of the loops are added to `unused_exits`. An exit is 517 unused if its pending_count is 0. If there is an exit with real gradient, 518 all these deferred exits will enter the backprop loop with zero gradient. 519 Otherwise, they will enter the backprop loop with None. As an example, 520 people often write: 521 522 ```python 523 v1, _ = tf.while_loop(p, b, [x1, x2]) 524 result = gradients(v1, x1) 525 ``` 526 527 The exit node for x2 is not included by the betweenness analysis. But we 528 need to backprop x2 if x2 is involved in computing v1. 529 530 Args: 531 pending_count: The number of backprop inputs for every op. 532 to_ops_set: The set of ops for ys in gradients(ys, xs) 533 534 Returns: 535 The set of unused loop exits that we know at this point we need 536 to backprop. 537 """ 538 loop_exits = [] 539 for grad_state in self._map.values(): 540 for y in grad_state.forward_loop_exits: 541 if pending_count[y.op] == 0: 542 grad_state.pending_exits_count -= 1 543 if y.op not in to_ops_set: 544 grad_state.unused_exits.append(y) 545 if grad_state.pending_exits_count == 0: 546 loop_exits.extend(grad_state.unused_exits) 547 # Need to include Enters in backprop for higher-order gradients. 548 for y in grad_state.forward_context.loop_enters: 549 if pending_count[y.op] == 0: 550 pending_count[y.op] = 1 551 return loop_exits 552 553 def EnterGradWhileContext(self, op, before): 554 """Enter the WhileContext for gradient computation.""" 555 grad_state = self.GetGradState(op, before) 556 if grad_state: 557 grad_state.grad_context.Enter() 558 559 def ExitGradWhileContext(self, op, before): 560 """Exit the WhileContext for gradient computation.""" 561 grad_state = self.GetGradState(op, before) 562 if grad_state: 563 grad_state.grad_context.Exit() 564 565 def AddWhileContext(self, op, between_op_list, between_ops): 566 """Add the grad state for the while loop that op belongs to. 567 568 Note that op is an Exit, and this method must be called in 569 the control flow context where gradients() is called. 570 571 Note that this method modifies `between_op_list` and `between_ops`. 572 """ 573 forward_ctxt = util.GetWhileContext(op) 574 grad_state = self._map.get(forward_ctxt) 575 if grad_state is None: 576 # This is a new while loop so create a grad state for it. 577 outer_forward_ctxt = forward_ctxt.outer_context 578 if outer_forward_ctxt: 579 outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() 580 outer_grad_state = None 581 if outer_forward_ctxt: 582 outer_grad_state = self._map.get(outer_forward_ctxt) 583 grad_state = _GradLoopState(forward_ctxt, outer_grad_state) 584 self._map[forward_ctxt] = grad_state 585 586 # We need to include all exits of a loop for backprop. 587 for loop_exit in grad_state.forward_loop_exits: 588 if loop_exit.op not in between_ops: 589 between_ops.add(loop_exit.op) 590 between_op_list.append(loop_exit.op) 591 592 def ZerosLikeForExit(self, val): 593 """Create zeros_like gradient for a loop exit. 594 595 If the result of a loop variable is not used but is involved in 596 computing the result of some needed loop variable, we create a 597 zero-valued tensor that is fed as gradient for the Exit node of that 598 loop variable. Note that val.op is an Exit, and this method must be 599 called in the control flow context where gradients() is called. 600 601 Args: 602 val: The output tensor of an Exit op. 603 604 Returns: 605 A zero tensor of the same shape of val. 606 """ 607 val_shape = val.get_shape() 608 forward_ctxt = val.op._get_control_flow_context() 609 outer_forward_ctxt = forward_ctxt.outer_context 610 if outer_forward_ctxt: 611 outer_forward_ctxt = outer_forward_ctxt.GetWhileContext() 612 outer_grad_state = None 613 if outer_forward_ctxt: 614 outer_grad_state = self._map.get(outer_forward_ctxt) 615 if outer_grad_state: 616 # This is a nested loop. 617 if val_shape.is_fully_defined(): 618 # If the shape is known statically, just create a zero tensor 619 # with the right shape in the right context. 620 outer_grad_state.grad_context.Enter() 621 result = array_ops.zeros(val_shape.dims, val.dtype) 622 outer_grad_state.grad_context.Exit() 623 else: 624 # Only the shape of value is needed for backprop. 625 forward_ctxt.outer_context.Enter() 626 shape = array_ops.shape_internal(val, optimize=False) 627 forward_ctxt.outer_context.Exit() 628 # Save the shape to a stack. 629 history_shape = outer_grad_state.AddForwardAccumulator(shape) 630 # Get the shape back from the stack. 631 outer_grad_ctxt = outer_grad_state.grad_context 632 outer_grad_ctxt.Enter() 633 real_shape = outer_grad_state.AddBackpropAccumulatedValue( 634 history_shape, shape) 635 result = array_ops.zeros(real_shape, val.dtype) 636 outer_grad_ctxt.Exit() 637 else: 638 # This is not a nested loop. 639 if val_shape.is_fully_defined(): 640 # If the shape is known statically, just create a zero tensor 641 # with the right shape. 642 result = array_ops.zeros(val_shape.dims, val.dtype) 643 else: 644 result = array_ops.zeros_like(val, optimize=False) 645 return result 646 647 def ZerosLikeV1WhileLoop(self, op, index): 648 """Create zeros_like for the specified output of an op. 649 650 If op is in a while loop that is part of gradients(), this method 651 must be called in its grad loop context. 652 653 Args: 654 op: A tensorflow operation. 655 index: the index for a specific output of the op. 656 657 Returns: 658 A zero tensor of the same shape of op.outputs[index]. 659 """ 660 if util.IsLoopSwitch(op): 661 return None 662 if op.graph.building_function: 663 # The optimization here is tricky to apply to functions 664 return array_ops.zeros_like(op.outputs[index]) 665 dead_branch = util.IsSwitch(op) 666 forward_ctxt = util.GetWhileContext(op) 667 grad_state = self._map.get(forward_ctxt) 668 if grad_state is None: 669 # op is not in a while loop that is part of gradients(). 670 return ZerosLike(op, index) 671 op_ctxt = op._get_control_flow_context() 672 val = ops.convert_to_tensor(op.outputs[index], name="tensor") 673 shape = val.get_shape() 674 if shape.is_fully_defined(): 675 # If the shape is known statically, just create a zero tensor with 676 # the right shape in the grad loop context. 677 if val.dtype == dtypes.resource: 678 result = array_ops.zeros( 679 resource_variable_ops.variable_shape(val), 680 dtype=default_gradient.get_zeros_dtype(val)) 681 else: 682 result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype) 683 if dead_branch: 684 # op is a cond switch. Guard the zero tensor with a switch. 685 pred = grad_state.history_map.get(op_ctxt.pred.name) 686 branch = op_ctxt.branch 687 result = control_flow_ops._SwitchRefOrTensor(result, pred)[1 - branch] 688 else: 689 # Unknown shape so keep a history of the shape at runtime. 690 if dead_branch: 691 # Need to add a special switch to guard the value. 692 pred = op_ctxt.pred 693 branch = op_ctxt.branch 694 op_ctxt.outer_context.Enter() 695 val = control_flow_ops._SwitchRefOrTensor(op.inputs[0], 696 pred)[1 - branch] 697 zeros_shape = array_ops.shape_internal(val, optimize=False) 698 op_ctxt.outer_context.Exit() 699 val.op._set_control_flow_context(op_ctxt) 700 zeros_shape.op._set_control_flow_context(op_ctxt) 701 else: 702 op_ctxt.Enter() 703 zeros_shape = array_ops.shape_internal(val, optimize=False) 704 op_ctxt.Exit() 705 706 # Add forward accumulator for shape. 707 grad_state.grad_context.Exit() 708 history_zeros_shape = grad_state.AddForwardAccumulator( 709 zeros_shape, dead_branch=dead_branch) 710 grad_state.grad_context.Enter() 711 712 # Create a zero tensor with the right shape. 713 shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape, 714 zeros_shape, dead_branch) 715 result = array_ops.zeros(shape, val.dtype) 716 return result 717 718 def PostProcessing(self): 719 """Perform postprocessing at the end of gradients(). 720 721 We have created the gradient graph at this point. So this function 722 can be used to perform any postprocessing on the gradient graph. 723 We currently perform the following postprocessing: 724 1. Patch the gradient graph if the output of a loop variable 725 doesn't depend on its input. 726 """ 727 for _, grad_state in self._map.items(): 728 for _, b_merge in grad_state.switch_map.items(): 729 if b_merge.op.inputs[0] == b_merge.op.inputs[1]: 730 # The value of this loop variable at iteration i+1 doesn't 731 # depend on its value at iteration i. So use zeros as the 732 # gradients for all iterations > 0. 733 dtype = b_merge.op.inputs[0].dtype 734 shape = b_merge.op.inputs[0].get_shape() 735 # pylint: disable=protected-access 736 if shape.is_fully_defined(): 737 grad_state.grad_context.Enter() 738 # Create a zeros and use it for iterations > 0. 739 grad_val = constant_op.constant(0, dtype=dtype, shape=shape) 740 next_grad_val = control_flow_ops._NextIteration(grad_val) 741 grad_state.grad_context.Exit() 742 else: 743 # Create a zeros in the outer grad context. 744 outer_grad_ctxt = grad_state.grad_context.outer_context 745 if outer_grad_ctxt: 746 outer_grad_ctxt.Enter() 747 enter_grad_op = b_merge.op.inputs[0].op 748 enter_grad = enter_grad_op.inputs[0] 749 grad_shape = array_ops.shape_internal(enter_grad, optimize=False) 750 grad_val = array_ops.zeros(grad_shape) 751 if outer_grad_ctxt: 752 outer_grad_ctxt.Exit() 753 # Use the zeros for iterations > 0. 754 grad_state.grad_context.Enter() 755 next_grad_val = control_flow_ops._NextIteration(grad_val) 756 grad_state.grad_context.Exit() 757 b_merge.op._update_input(1, next_grad_val) 758 # pylint: enable=protected-access 759 760 761def MaybeCreateControlFlowState(between_op_list, between_ops, 762 colocate_gradients_with_ops): 763 """Create the state for all the while loops involved in one gradients(). 764 765 We create a _ControlFlowState when there are while loops involved in 766 gradients(). In gradients(), control flow logic is only invoked when 767 the _ControlFlowState is not None. 768 769 Note that this method modifies `between_op_list` and `between_ops`. 770 """ 771 loop_state = None 772 for op in between_op_list: 773 if util.IsLoopExit(op): 774 if loop_state is None: 775 loop_state = _ControlFlowState() 776 if colocate_gradients_with_ops: 777 with ops.colocate_with(op): 778 loop_state.AddWhileContext(op, between_op_list, between_ops) 779 else: 780 loop_state.AddWhileContext(op, between_op_list, between_ops) 781 return loop_state 782 783 784def _ZerosLikeV1(op, index): 785 """Branch of ZerosLike for TF1.""" 786 val = op.outputs[index] 787 op_ctxt = op._get_control_flow_context() # pylint: disable=protected-access 788 if op_ctxt: 789 # We are in a cond context. Use a switch to create zeros only when needed. 790 pred = op_ctxt.pred 791 branch = op_ctxt.branch 792 switch_val = control_flow_ops.switch(op.inputs[0], pred)[1 - branch] 793 # A op is created along the branch taken as control dependencies are on 794 # the whole op and not on the tensor output. 795 pivot = array_ops.identity(switch_val) 796 if val.dtype == dtypes.resource: 797 with ops.control_dependencies([pivot]): 798 return array_ops.zeros( 799 gen_resource_variable_ops.variable_shape(switch_val), 800 dtype=default_gradient.get_zeros_dtype(val)) 801 zeros_shape = array_ops.shape_internal(switch_val, optimize=False) 802 # Ensure ops created within array_ops.zeros are dominated by switch in 803 # cond context. 804 with ops.control_dependencies([pivot]): 805 return array_ops.zeros(zeros_shape, dtype=val.dtype) 806 else: 807 return array_ops.zeros_like(val, optimize=False) 808 809 810def _ZerosLikeV2(op, index): 811 """Branch of ZerosLike for TF2.""" 812 val = op.outputs[index] 813 if val.dtype == dtypes.resource: 814 return array_ops.zeros( 815 gen_resource_variable_ops.variable_shape(val), 816 dtype=default_gradient.get_zeros_dtype(val)) 817 if (isinstance(val.op.graph, control_flow_v2_func_graphs.WhileBodyFuncGraph) 818 and val.dtype != dtypes.variant): 819 # In while_v2 we do not want to add a `ZerosLike` op because that will 820 # trigger accumulation of `val`. Normally `ZerosLike` is preferred because 821 # it helps avoid creating extra nodes(possibly Consts) for the shape. 822 # For variants, we must use ZerosLike. 823 if val.shape.is_fully_defined(): 824 return constant_op.constant(0, shape=val.shape.dims, dtype=val.dtype) 825 else: 826 # Note: Even though we add `Shape` in the default graph, while_v2 is smart 827 # enough to place it in the forward graph i.e. `val.graph`. 828 zeros_shape = array_ops.shape_internal(val, optimize=False) 829 return array_ops.zeros(zeros_shape, val.dtype) 830 else: 831 return array_ops.zeros_like(val, optimize=False) 832 833 834def ZerosLike(op, index): 835 """Create zeros_like for the specified output of an op.""" 836 if not util.IsSwitch(op): 837 return _ZerosLikeV2(op, index) 838 else: 839 return _ZerosLikeV1(op, index) 840