# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

# Picked a long key value to minimize the chance of collision with user defined
# collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'

# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph


@tf_export(v1=['train.global_step'])
def global_step(sess, global_step_tensor):
  """Small helper to get the global step.

  ```python
  # Create a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Create a session.
  sess = tf.compat.v1.Session()
  # Initialize the variable
  sess.run(global_step_tensor.initializer)
  # Get the variable value.
  print('global_step: %s' % tf.compat.v1.train.global_step(sess,
                                                           global_step_tensor))

  global_step: 10
  ```

  Args:
    sess: A TensorFlow `Session` object. Ignored when executing eagerly.
    global_step_tensor: `Tensor` or the `name` of the operation that contains
      the global step. When executing eagerly this must be an object with a
      `numpy()` method; a string name is only usable in graph mode, where it is
      passed to `sess.run`.

  Returns:
    The global step value, as a Python `int`.
  """
  if context.executing_eagerly():
    # There is no session in eager mode: read the value from the tensor itself.
    return int(global_step_tensor.numpy())
  return int(sess.run(global_step_tensor))


@tf_export(v1=['train.get_global_step'])
def get_global_step(graph=None):
  """Get the global step tensor.

  The global step tensor must be an integer variable. We first try to find it
  in the collection `GLOBAL_STEP`, or by name `global_step:0`.

  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step variable, or `None` if none was found. `None` is also
    returned (after logging an error) when the `GLOBAL_STEP` collection holds
    more than one item.

  Raises:
    TypeError: If the global step tensor has a non-integer type, or if it is not
      a `Variable`.
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = None
  global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(global_step_tensors) == 1:
    global_step_tensor = global_step_tensors[0]
  elif not global_step_tensors:
    # Fall back to a lookup by the conventional tensor name.
    try:
      global_step_tensor = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None
  else:
    logging.error('Multiple tensors in global_step collection.')
    return None

  assert_global_step(global_step_tensor)
  return global_step_tensor


def _create_global_step_variable():
  """Creates the zero-initialized int64 scalar global step in the current scope."""
  return variable_scope.get_variable(
      ops.GraphKeys.GLOBAL_STEP,
      shape=[],
      dtype=dtypes.int64,
      initializer=init_ops.zeros_initializer(),
      trainable=False,
      aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
      collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])


@tf_export(v1=['train.create_global_step'])
def create_global_step(graph=None):
  """Create global step tensor in graph.

  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph. When executing eagerly, `graph` is only used for the
      "already exists" check; the variable itself is created on `cpu:0`.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.
  """
  graph = graph or ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  if context.executing_eagerly():
    with ops.device('cpu:0'):
      return _create_global_step_variable()
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return _create_global_step_variable()


@tf_export(v1=['train.get_or_create_global_step'])
def get_or_create_global_step(graph=None):
  """Returns and create (if necessary) the global step tensor.

  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    The global step tensor.
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    global_step_tensor = create_global_step(graph)
  return global_step_tensor


@tf_export(v1=['train.assert_global_step'])
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  Args:
    global_step_tensor: `Tensor` to test.

  Raises:
    TypeError: If `global_step_tensor` is not a `Variable`, `Tensor` or
      resource variable; if its dtype is not an integer type; or if its rank is
      known and is not zero.
  """
  if not (isinstance(global_step_tensor, (variables.Variable, ops.Tensor)) or
          resource_variable_ops.is_resource_variable(global_step_tensor)):
    raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
                    global_step_tensor)

  if not global_step_tensor.dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    global_step_tensor.dtype)

  shape = global_step_tensor.get_shape()
  # An unknown rank (ndims is None) may still be scalar at run time, so it is
  # accepted. A known non-zero rank can never be scalar, even when some
  # dimensions are unknown (e.g. shape [None]), so it is rejected.
  if shape.ndims is not None and shape.ndims != 0:
    raise TypeError('Existing "global_step" is not scalar: %s' % shape)


def _get_global_step_read(graph=None):
  """Gets global step read tensor in graph.

  Args:
    graph: The graph in which to create the global step read tensor. If missing,
      use default graph.

  Returns:
    Global step read tensor, or `None` if the collection is empty.

  Raises:
    RuntimeError: if multiple items found in collection GLOBAL_STEP_READ_KEY.
  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensors = graph.get_collection(GLOBAL_STEP_READ_KEY)
  if len(global_step_read_tensors) > 1:
    raise RuntimeError('There are multiple items in collection {}. '
                       'There should be only one.'.format(GLOBAL_STEP_READ_KEY))

  if len(global_step_read_tensors) == 1:
    return global_step_read_tensors[0]
  return None


def _get_or_create_global_step_read(graph=None):
  """Gets or creates global step read tensor in graph.

  The read tensor is cached in the `GLOBAL_STEP_READ_KEY` collection, so later
  calls return the same tensor.

  Args:
    graph: The graph in which to create the global step read tensor. If missing,
      use default graph.

  Returns:
    Global step read tensor if there is global_step_tensor else return None.
  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensor = _get_global_step_read(graph)
  if global_step_read_tensor is not None:
    return global_step_read_tensor
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    return None
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      # Use initialized_value() so that the global step is initialized before
      # this read runs. This is needed, for example, because Estimator builds
      # the whole model_fn under a dependency on global_step_read_tensor.
      if isinstance(global_step_tensor, variables.Variable):
        global_step_value = global_step_tensor.initialized_value()
      else:
        global_step_value = global_step_tensor
      # Adding zero produces a separate Tensor that copies the variable value.
      global_step_read_tensor = global_step_value + 0
      ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor)
  # Re-read through the collection so the single-item invariant is validated.
  return _get_global_step_read(graph)


def _increment_global_step(increment, graph=None):
  """Builds an op that adds `increment` to the global step.

  The returned op has a control dependency on the global step read tensor
  (created if necessary), so that read runs before the increment.

  Args:
    increment: Amount to add to the global step; passed as the value argument
      of `state_ops.assign_add`.
    graph: The graph containing the global step. If missing, use default graph.

  Returns:
    The result of `state_ops.assign_add` applied to the global step.

  Raises:
    ValueError: if no global step tensor exists in `graph`.
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    raise ValueError(
        'Global step tensor should be created by '
        'tf.train.get_or_create_global_step before calling increment.')
  global_step_read_tensor = _get_or_create_global_step_read(graph)
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      with ops.control_dependencies([global_step_read_tensor]):
        return state_ops.assign_add(global_step_tensor, increment)