# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient tape utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib

from tensorflow.python import pywrap_tfe
from tensorflow.python.util.lazy_loader import LazyLoader

# There is a circular dependency between this, ops.py, and
# distribution_strategy_context.
# TODO(b/117329403): Remove this circular dependency.
distribution_strategy_context = LazyLoader(
    "distribution_strategy_context", globals(),
    "tensorflow.python.distribute."
    "distribution_strategy_context")


class Tape(object):
  """Represents a gradient propagation trace."""

  __slots__ = ["_tape"]

  def __init__(self, tape):
    self._tape = tape

  def watched_variables(self):
    return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)


def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Pushes a new tape onto the tape stack."""
  tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
  return Tape(tape)


def push_tape(tape):
  """Pushes an existing tape onto the tape stack."""
  pywrap_tfe.TFE_Py_TapeSetAdd(tape._tape)  # pylint: disable=protected-access


def watch(tape, tensor):
  """Marks this tensor to be watched by the given tape."""
  pywrap_tfe.TFE_Py_TapeWatch(tape._tape, tensor)  # pylint: disable=protected-access
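

# A minimal sketch of how the primitives above compose (illustrative only:
# `tf.GradientTape.__enter__`/`__exit__` drive this same push/watch/pop
# sequence internally; `fn` and `tensor` are hypothetical arguments, with
# `tensor` assumed to be an eager Tensor):
def _example_trace(fn, tensor):
  """Runs `fn(tensor)` while a fresh tape watches `tensor`."""
  tape = push_new_tape(persistent=False, watch_accessed_variables=False)
  try:
    watch(tape, tensor)  # Ops consuming `tensor` are now recorded.
    return fn(tensor), tape
  finally:
    pop_tape(tape)  # Always unwind the tape stack, even if `fn` raises.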


class VariableWatcher(object):
  """A scope that tracks all trainable variable accesses within it.

  This explicitly ignores variables that are not marked as trainable.

  Sample usage:

  var = tf.Variable(0.0)
  with VariableWatcher() as variable_watcher:
    var.assign_add(1.0)

  assert variable_watcher.watched_variables() == (var,)
  """

  __slots__ = ["_variable_watcher"]

  def __init__(self):
    self._variable_watcher = None

  def __enter__(self):
    self._variable_watcher = pywrap_tfe.TFE_Py_VariableWatcherNew()
    return self

  def __exit__(self, typ, value, traceback):
    pywrap_tfe.TFE_Py_VariableWatcherRemove(self._variable_watcher)

  def watched_variables(self):
    """Returns a tuple of variables accessed under this scope."""
    return pywrap_tfe.TFE_Py_VariableWatcherWatchedVariables(
        self._variable_watcher)


def watch_variable(tape, variable):
  """Marks this variable to be watched by the given tape."""
  strategy, context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if context:
    variables = [strategy.extended.value_container(variable)]
  else:
    variables = strategy.experimental_local_results(variable)
  for var in variables:
    pywrap_tfe.TFE_Py_TapeWatchVariable(tape._tape, var)  # pylint: disable=protected-access
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)


def variable_accessed(variable):
  """Notifies all tapes in the stack that a variable has been accessed.

  Args:
    variable: variable to be watched.
  """
  strategy, context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  if context:
    variables = [strategy.extended.value_container(variable)]
  else:
    variables = strategy.experimental_local_results(variable)
  for var in variables:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)


def variables_accessed(variables):
  """Notifies all tapes in the stack that variables have been accessed.

  Only trainable variables are marked as accessed.

  Args:
    variables: iterable of variables to mark as accessed.
  """
  strategy, context = (
      distribution_strategy_context.get_strategy_and_replica_context())
  accessed = []
  if context:
    accessed = [strategy.extended.value_container(variable)
                for variable in variables if variable.trainable]
  else:
    for variable in variables:
      if variable.trainable:
        accessed.extend(strategy.experimental_local_results(variable))

  for var in accessed:
    pywrap_tfe.TFE_Py_TapeVariableAccessed(var)
    pywrap_tfe.TFE_Py_VariableWatcherVariableAccessed(var)


def pop_tape(tape):
  """Pops the given tape from the stack."""
  pywrap_tfe.TFE_Py_TapeSetRemove(tape._tape)  # pylint: disable=protected-access


@contextlib.contextmanager
def stop_recording():
  """Stops all gradient recording (backprop and forwardprop)."""
  is_stopped = pywrap_tfe.TFE_Py_TapeSetIsStopped()
  try:
    if not is_stopped:
      pywrap_tfe.TFE_Py_TapeSetStopOnThread()
    yield
  finally:
    if not is_stopped:
      pywrap_tfe.TFE_Py_TapeSetRestartOnThread()
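

# A hedged sketch of pausing recording (illustrative only; this is the same
# mechanism `tf.GradientTape.stop_recording` relies on, and `compute_fn` is
# a hypothetical caller-supplied callable):
def _example_pause_recording(compute_fn):
  """Runs `compute_fn` without recording its ops for differentiation."""
  with stop_recording():
    # Ops executed here leave no trace on active tapes or forward
    # accumulators, so backprop treats their results as constants.
    return compute_fn()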


def should_record_backprop(tensors):
  """Returns True if any tape in the stack watches any of these tensors.

  Only takes GradientTapes into account, not forward accumulators.

  Args:
    tensors: Tensors to check, typically inputs to an operation.

  Returns:
    Boolean, whether any tape watches any of `tensors`.
  """
  return pywrap_tfe.TFE_Py_TapeSetShouldRecordBackprop(tensors)


def record_operation(op_type, output_tensors, input_tensors, backward_function,
                     forward_function=None):
  """Records the operation on all tapes in the stack."""
  pywrap_tfe.TFE_Py_TapeSetRecordOperation(op_type, output_tensors,
                                           input_tensors, backward_function,
                                           forward_function)


def record_operation_backprop_only(op_type, output_tensors, input_tensors,
                                   backward_function):
  """Records the operation on all backward tapes in the stack."""
  pywrap_tfe.TFE_Py_TapeSetRecordOperationBackprop(op_type, output_tensors,
                                                   input_tensors,
                                                   backward_function)


def record_operation_forwardprop_only(op_type, output_tensors, input_tensors,
                                      backward_function,
                                      forwardprop_output_indices):
  """Records the operation on all forward accumulators in the stack.

  Args:
    op_type: a string for the operation type, used in the backprop code.
    output_tensors: a list of Python Tensor objects output by the operation.
    input_tensors: a list of input Tensors to the recorded operation.
    backward_function: the function which, given the gradients of the output
      tensors, produces the gradients of the input tensors. This function is
      automatically transposed to produce output gradients given input
      gradients.
    forwardprop_output_indices: indicates any output_tensors which contain
      JVPs. Typically these will have come from TFE_Py_PackForwardGradients.
      May be None or an empty sequence if there are no JVP outputs from the
      operation.
  """
  pywrap_tfe.TFE_Py_TapeSetRecordOperationForwardprop(
      op_type, output_tensors, input_tensors, backward_function,
      forwardprop_output_indices)


def delete_trace(tensor_id):
  """Deletes traces for this Tensor from all tapes in the stack."""
  pywrap_tfe.TFE_Py_TapeSetDeleteTrace(tensor_id)


def could_possibly_record():
  """Returns True if any tape is active."""
  return not pywrap_tfe.TFE_Py_TapeSetIsEmpty()
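

# A minimal sketch of how an eager op wrapper might use the recording entry
# points above (illustrative only; real eager ops typically record through
# internal fast paths, `"ExampleOp"` is a hypothetical op type, and
# `backward_function` is a hypothetical callable mapping output gradients to
# input gradients):
def _example_record_op(op_outputs, op_inputs, backward_function):
  """Records a hypothetical op on active tapes only when needed."""
  if could_possibly_record() and should_record_backprop(op_inputs):
    record_operation("ExampleOp", op_outputs, op_inputs, backward_function)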