1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility functions for training.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python.eager import context 21from tensorflow.python.framework import dtypes 22from tensorflow.python.framework import graph_io 23from tensorflow.python.framework import ops 24from tensorflow.python.ops import init_ops 25from tensorflow.python.ops import resource_variable_ops 26from tensorflow.python.ops import state_ops 27from tensorflow.python.ops import variable_scope 28from tensorflow.python.ops import variables 29from tensorflow.python.platform import tf_logging as logging 30from tensorflow.python.util.tf_export import tf_export 31 32# Picked a long key value to minimize the chance of collision with user defined 33# collection keys. 34GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache' 35 36# TODO(drpng): remove this after legacy uses are resolved. 37write_graph = graph_io.write_graph 38 39 40@tf_export(v1=['train.global_step']) 41def global_step(sess, global_step_tensor): 42 """Small helper to get the global step. 43 44 ```python 45 # Create a variable to hold the global_step. 46 global_step_tensor = tf.Variable(10, trainable=False, name='global_step') 47 # Create a session. 48 sess = tf.compat.v1.Session() 49 # Initialize the variable 50 sess.run(global_step_tensor.initializer) 51 # Get the variable value. 52 print('global_step: %s' % tf.compat.v1.train.global_step(sess, 53 global_step_tensor)) 54 55 global_step: 10 56 ``` 57 58 Args: 59 sess: A TensorFlow `Session` object. 60 global_step_tensor: `Tensor` or the `name` of the operation that contains 61 the global step. 62 63 Returns: 64 The global step value. 65 """ 66 if context.executing_eagerly(): 67 return int(global_step_tensor.numpy()) 68 return int(sess.run(global_step_tensor)) 69 70 71@tf_export(v1=['train.get_global_step']) 72def get_global_step(graph=None): 73 """Get the global step tensor. 74 75 The global step tensor must be an integer variable. We first try to find it 76 in the collection `GLOBAL_STEP`, or by name `global_step:0`. 77 78 Args: 79 graph: The graph to find the global step in. If missing, use default graph. 80 81 Returns: 82 The global step variable, or `None` if none was found. 83 84 Raises: 85 TypeError: If the global step tensor has a non-integer type, or if it is not 86 a `Variable`. 87 88 @compatibility(TF2) 89 With the deprecation of global graphs, TF no longer tracks variables in 90 collections. In other words, there are no global variables in TF2. Thus, the 91 global step functions have been removed (`get_or_create_global_step`, 92 `create_global_step`, `get_global_step`) . You have two options for migrating: 93 94 1. Create a Keras optimizer, which generates an `iterations` variable. This 95 variable is automatically incremented when calling `apply_gradients`. 96 2. Manually create and increment a `tf.Variable`. 97 98 Below is an example of migrating away from using a global step to using a 99 Keras optimizer: 100 101 Define a dummy model and loss: 102 103 >>> def compute_loss(x): 104 ... v = tf.Variable(3.0) 105 ... y = x * v 106 ... loss = x * 5 - x * v 107 ... return loss, [v] 108 109 Before migrating: 110 111 >>> g = tf.Graph() 112 >>> with g.as_default(): 113 ... x = tf.compat.v1.placeholder(tf.float32, []) 114 ... loss, var_list = compute_loss(x) 115 ... global_step = tf.compat.v1.train.get_or_create_global_step() 116 ... global_init = tf.compat.v1.global_variables_initializer() 117 ... optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1) 118 ... train_op = optimizer.minimize(loss, global_step, var_list) 119 >>> sess = tf.compat.v1.Session(graph=g) 120 >>> sess.run(global_init) 121 >>> print("before training:", sess.run(global_step)) 122 before training: 0 123 >>> sess.run(train_op, feed_dict={x: 3}) 124 >>> print("after training:", sess.run(global_step)) 125 after training: 1 126 127 Using `get_global_step`: 128 129 >>> with g.as_default(): 130 ... print(sess.run(tf.compat.v1.train.get_global_step())) 131 1 132 133 Migrating to a Keras optimizer: 134 135 >>> optimizer = tf.keras.optimizers.SGD(.01) 136 >>> print("before training:", optimizer.iterations.numpy()) 137 before training: 0 138 >>> with tf.GradientTape() as tape: 139 ... loss, var_list = compute_loss(3) 140 ... grads = tape.gradient(loss, var_list) 141 ... optimizer.apply_gradients(zip(grads, var_list)) 142 >>> print("after training:", optimizer.iterations.numpy()) 143 after training: 1 144 145 @end_compatibility 146 """ 147 graph = graph or ops.get_default_graph() 148 global_step_tensor = None 149 global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP) 150 if len(global_step_tensors) == 1: 151 global_step_tensor = global_step_tensors[0] 152 elif not global_step_tensors: 153 try: 154 global_step_tensor = graph.get_tensor_by_name('global_step:0') 155 except KeyError: 156 return None 157 else: 158 logging.error('Multiple tensors in global_step collection.') 159 return None 160 161 assert_global_step(global_step_tensor) 162 return global_step_tensor 163 164 165@tf_export(v1=['train.create_global_step']) 166def create_global_step(graph=None): 167 """Create global step tensor in graph. 168 169 Args: 170 graph: The graph in which to create the global step tensor. If missing, use 171 default graph. 172 173 Returns: 174 Global step tensor. 175 176 Raises: 177 ValueError: if global step tensor is already defined. 178 179 @compatibility(TF2) 180 With the deprecation of global graphs, TF no longer tracks variables in 181 collections. In other words, there are no global variables in TF2. Thus, the 182 global step functions have been removed (`get_or_create_global_step`, 183 `create_global_step`, `get_global_step`) . You have two options for migrating: 184 185 1. Create a Keras optimizer, which generates an `iterations` variable. This 186 variable is automatically incremented when calling `apply_gradients`. 187 2. Manually create and increment a `tf.Variable`. 188 189 Below is an example of migrating away from using a global step to using a 190 Keras optimizer: 191 192 Define a dummy model and loss: 193 194 >>> def compute_loss(x): 195 ... v = tf.Variable(3.0) 196 ... y = x * v 197 ... loss = x * 5 - x * v 198 ... return loss, [v] 199 200 Before migrating: 201 202 >>> g = tf.Graph() 203 >>> with g.as_default(): 204 ... x = tf.compat.v1.placeholder(tf.float32, []) 205 ... loss, var_list = compute_loss(x) 206 ... global_step = tf.compat.v1.train.create_global_step() 207 ... global_init = tf.compat.v1.global_variables_initializer() 208 ... optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1) 209 ... train_op = optimizer.minimize(loss, global_step, var_list) 210 >>> sess = tf.compat.v1.Session(graph=g) 211 >>> sess.run(global_init) 212 >>> print("before training:", sess.run(global_step)) 213 before training: 0 214 >>> sess.run(train_op, feed_dict={x: 3}) 215 >>> print("after training:", sess.run(global_step)) 216 after training: 1 217 218 Migrating to a Keras optimizer: 219 220 >>> optimizer = tf.keras.optimizers.SGD(.01) 221 >>> print("before training:", optimizer.iterations.numpy()) 222 before training: 0 223 >>> with tf.GradientTape() as tape: 224 ... loss, var_list = compute_loss(3) 225 ... grads = tape.gradient(loss, var_list) 226 ... optimizer.apply_gradients(zip(grads, var_list)) 227 >>> print("after training:", optimizer.iterations.numpy()) 228 after training: 1 229 230 @end_compatibility 231 """ 232 graph = graph or ops.get_default_graph() 233 if get_global_step(graph) is not None: 234 raise ValueError('"global_step" already exists.') 235 if context.executing_eagerly(): 236 with ops.device('cpu:0'): 237 return variable_scope.get_variable( 238 ops.GraphKeys.GLOBAL_STEP, 239 shape=[], 240 dtype=dtypes.int64, 241 initializer=init_ops.zeros_initializer(), 242 trainable=False, 243 aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA, 244 collections=[ 245 ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP 246 ]) 247 # Create in proper graph and base name_scope. 248 with graph.as_default() as g, g.name_scope(None): 249 return variable_scope.get_variable( 250 ops.GraphKeys.GLOBAL_STEP, 251 shape=[], 252 dtype=dtypes.int64, 253 initializer=init_ops.zeros_initializer(), 254 trainable=False, 255 aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA, 256 collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) 257 258 259@tf_export(v1=['train.get_or_create_global_step']) 260def get_or_create_global_step(graph=None): 261 """Returns and create (if necessary) the global step tensor. 262 263 Args: 264 graph: The graph in which to create the global step tensor. If missing, use 265 default graph. 266 267 Returns: 268 The global step tensor. 269 270 @compatibility(TF2) 271 With the deprecation of global graphs, TF no longer tracks variables in 272 collections. In other words, there are no global variables in TF2. Thus, the 273 global step functions have been removed (`get_or_create_global_step`, 274 `create_global_step`, `get_global_step`) . You have two options for migrating: 275 276 1. Create a Keras optimizer, which generates an `iterations` variable. This 277 variable is automatically incremented when calling `apply_gradients`. 278 2. Manually create and increment a `tf.Variable`. 279 280 Below is an example of migrating away from using a global step to using a 281 Keras optimizer: 282 283 Define a dummy model and loss: 284 285 >>> def compute_loss(x): 286 ... v = tf.Variable(3.0) 287 ... y = x * v 288 ... loss = x * 5 - x * v 289 ... return loss, [v] 290 291 Before migrating: 292 293 >>> g = tf.Graph() 294 >>> with g.as_default(): 295 ... x = tf.compat.v1.placeholder(tf.float32, []) 296 ... loss, var_list = compute_loss(x) 297 ... global_step = tf.compat.v1.train.get_or_create_global_step() 298 ... global_init = tf.compat.v1.global_variables_initializer() 299 ... optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1) 300 ... train_op = optimizer.minimize(loss, global_step, var_list) 301 >>> sess = tf.compat.v1.Session(graph=g) 302 >>> sess.run(global_init) 303 >>> print("before training:", sess.run(global_step)) 304 before training: 0 305 >>> sess.run(train_op, feed_dict={x: 3}) 306 >>> print("after training:", sess.run(global_step)) 307 after training: 1 308 309 Migrating to a Keras optimizer: 310 311 >>> optimizer = tf.keras.optimizers.SGD(.01) 312 >>> print("before training:", optimizer.iterations.numpy()) 313 before training: 0 314 >>> with tf.GradientTape() as tape: 315 ... loss, var_list = compute_loss(3) 316 ... grads = tape.gradient(loss, var_list) 317 ... optimizer.apply_gradients(zip(grads, var_list)) 318 >>> print("after training:", optimizer.iterations.numpy()) 319 after training: 1 320 321 @end_compatibility 322 """ 323 graph = graph or ops.get_default_graph() 324 global_step_tensor = get_global_step(graph) 325 if global_step_tensor is None: 326 global_step_tensor = create_global_step(graph) 327 return global_step_tensor 328 329 330@tf_export(v1=['train.assert_global_step']) 331def assert_global_step(global_step_tensor): 332 """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`. 333 334 Args: 335 global_step_tensor: `Tensor` to test. 336 """ 337 if not (isinstance(global_step_tensor, variables.Variable) or 338 isinstance(global_step_tensor, ops.Tensor) or 339 resource_variable_ops.is_resource_variable(global_step_tensor)): 340 raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' % 341 global_step_tensor) 342 343 if not global_step_tensor.dtype.base_dtype.is_integer: 344 raise TypeError('Existing "global_step" does not have integer type: %s' % 345 global_step_tensor.dtype) 346 347 if (global_step_tensor.get_shape().ndims != 0 and 348 global_step_tensor.get_shape().is_fully_defined()): 349 raise TypeError('Existing "global_step" is not scalar: %s' % 350 global_step_tensor.get_shape()) 351 352 353def _get_global_step_read(graph=None): 354 """Gets global step read tensor in graph. 355 356 Args: 357 graph: The graph in which to create the global step read tensor. If missing, 358 use default graph. 359 360 Returns: 361 Global step read tensor. 362 363 Raises: 364 RuntimeError: if multiple items found in collection GLOBAL_STEP_READ_KEY. 365 """ 366 graph = graph or ops.get_default_graph() 367 global_step_read_tensors = graph.get_collection(GLOBAL_STEP_READ_KEY) 368 if len(global_step_read_tensors) > 1: 369 raise RuntimeError('There are multiple items in collection {}. ' 370 'There should be only one.'.format(GLOBAL_STEP_READ_KEY)) 371 372 if len(global_step_read_tensors) == 1: 373 return global_step_read_tensors[0] 374 return None 375 376 377def _get_or_create_global_step_read(graph=None): 378 """Gets or creates global step read tensor in graph. 379 380 Args: 381 graph: The graph in which to create the global step read tensor. If missing, 382 use default graph. 383 384 Returns: 385 Global step read tensor if there is global_step_tensor else return None. 386 """ 387 graph = graph or ops.get_default_graph() 388 global_step_read_tensor = _get_global_step_read(graph) 389 if global_step_read_tensor is not None: 390 return global_step_read_tensor 391 global_step_tensor = get_global_step(graph) 392 if global_step_tensor is None: 393 return None 394 # add 'zero' so that it will create a copy of variable as Tensor. 395 with graph.as_default() as g, g.name_scope(None): 396 with g.name_scope(global_step_tensor.op.name + '/'): 397 # using initialized_value to ensure that global_step is initialized before 398 # this run. This is needed for example Estimator makes all model_fn build 399 # under global_step_read_tensor dependency. 400 global_step_value = global_step_tensor.initialized_value() if isinstance( 401 global_step_tensor, variables.Variable) else global_step_tensor 402 global_step_read_tensor = global_step_value + 0 403 ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor) 404 return _get_global_step_read(graph) 405 406 407def _increment_global_step(increment, graph=None): 408 graph = graph or ops.get_default_graph() 409 global_step_tensor = get_global_step(graph) 410 if global_step_tensor is None: 411 raise ValueError( 412 'Global step tensor should be created by ' 413 'tf.train.get_or_create_global_step before calling increment.') 414 global_step_read_tensor = _get_or_create_global_step_read(graph) 415 with graph.as_default() as g, g.name_scope(None): 416 with g.name_scope(global_step_tensor.op.name + '/'): 417 with ops.control_dependencies([global_step_read_tensor]): 418 return state_ops.assign_add(global_step_tensor, increment) 419