1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Gradient tape utilites.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22 23from tensorflow.python import pywrap_tensorflow 24from tensorflow.python.util.lazy_loader import LazyLoader 25 26# There is a circular dependency between this, ops.py, and 27# distribution_strategy_context. 28# TODO(b/117329403): Remove this circular dependency. 29distribution_strategy_context = LazyLoader( 30 "distribution_strategy_context", globals(), 31 "tensorflow.python.distribute." 32 "distribution_strategy_context") 33 34 35class Tape(object): 36 """Represents a gradient propagation trace.""" 37 38 def __init__(self, tape): 39 self._tape = tape 40 41 def watched_variables(self): 42 return pywrap_tensorflow.TFE_Py_TapeWatchedVariables(self._tape) 43 44 45def push_new_tape(persistent=False, watch_accessed_variables=True): 46 """Pushes a new tape onto the tape stack.""" 47 tape = pywrap_tensorflow.TFE_Py_TapeSetNew(persistent, 48 watch_accessed_variables) 49 return Tape(tape) 50 51 52def push_tape(tape): 53 """Pushes an existing tape onto the tape stack.""" 54 pywrap_tensorflow.TFE_Py_TapeSetAdd(tape._tape) # pylint: disable=protected-access 55 56 57def watch(tape, tensor): 58 """Marks this tensor to be watched by the given tape.""" 59 pywrap_tensorflow.TFE_Py_TapeWatch(tape._tape, tensor) # pylint: disable=protected-access 60 61 62def watch_variable(tape, variable): 63 """Marks this variable to be watched by the given tape.""" 64 strategy, context = ( 65 distribution_strategy_context.get_strategy_and_replica_context()) 66 if context: 67 variables = [strategy.extended.value_container(variable)] 68 else: 69 variables = strategy.unwrap(variable) 70 for var in variables: 71 pywrap_tensorflow.TFE_Py_TapeWatchVariable(tape._tape, var) # pylint: disable=protected-access 72 73 74def variable_accessed(variable): 75 """Notifies all tapes in the stack that a variable has been accessed. 76 77 Args: 78 variable: variable to be watched. 79 """ 80 strategy, context = ( 81 distribution_strategy_context.get_strategy_and_replica_context()) 82 if context: 83 variables = [strategy.extended.value_container(variable)] 84 else: 85 variables = strategy.unwrap(variable) 86 for var in variables: 87 pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var) 88 89 90def variables_accessed(variables): 91 """Notifies all tapes in the stack that variables have been accessed. 92 93 Only trainable variables are marked as accessed. 94 95 Args: 96 variables: iterable of variables to mark as accessed. 97 """ 98 strategy, context = ( 99 distribution_strategy_context.get_strategy_and_replica_context()) 100 accessed = [] 101 if context: 102 accessed = [strategy.extended.value_container(variable) 103 for variable in variables if variable.trainable] 104 else: 105 for variable in variables: 106 if variable.trainable: 107 accessed.extend(strategy.unwrap(variable)) 108 109 for var in accessed: 110 pywrap_tensorflow.TFE_Py_TapeVariableAccessed(var) 111 112 113def pop_tape(tape): 114 """Pops the top tape in the stack, if any.""" 115 pywrap_tensorflow.TFE_Py_TapeSetRemove(tape._tape) # pylint: disable=protected-access 116 117 118@contextlib.contextmanager 119def stop_recording(): 120 try: 121 pywrap_tensorflow.TFE_Py_TapeSetStopOnThread() 122 yield 123 finally: 124 pywrap_tensorflow.TFE_Py_TapeSetRestartOnThread() 125 126 127def should_record(tensors): 128 """Returns true if any tape in the stack watches any of these tensors.""" 129 return pywrap_tensorflow.TFE_Py_TapeSetShouldRecord(tensors) 130 131 132def record_operation(op_type, output_tensors, input_tensors, backward_function): 133 """Records the operation on all tapes in the stack.""" 134 pywrap_tensorflow.TFE_Py_TapeSetRecordOperation( 135 op_type, output_tensors, input_tensors, backward_function) 136 137 138def delete_trace(tensor_id): 139 """Deletes traces for this Tensor from all tapes in the stack.""" 140 pywrap_tensorflow.TFE_Py_TapeSetDeleteTrace(tensor_id) 141 142 143def could_possibly_record(): 144 """Returns True if any tape is active.""" 145 return not pywrap_tensorflow.TFE_Py_TapeSetIsEmpty() 146