# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions for training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export

# Picked a long key value to minimize the chance of collision with user defined
# collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'


# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph


@tf_export(v1=['train.global_step'])
def global_step(sess, global_step_tensor):
  """Small helper to get the global step.

  ```python
  # Create a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Create a session.
  sess = tf.Session()
  # Initialize the variable
  sess.run(global_step_tensor.initializer)
  # Get the variable value.
  print('global_step: %s' % tf.train.global_step(sess, global_step_tensor))

  global_step: 10
  ```

  Args:
    sess: A TensorFlow `Session` object. Ignored when executing eagerly.
    global_step_tensor: `Tensor` or the `name` of the operation that contains
      the global step. When executing eagerly this must be an object exposing
      `.numpy()`; a name string is not accepted in that mode.

  Returns:
    The global step value, converted to a Python `int`.
  """
  if context.executing_eagerly():
    # There is no session to run in eager mode; read the value directly.
    return int(global_step_tensor.numpy())
  return int(sess.run(global_step_tensor))


@tf_export(v1=['train.get_global_step'])
def get_global_step(graph=None):
  """Get the global step tensor.

  The global step tensor must be an integer `Variable` or `Tensor`. We first
  try to find it in the collection `GLOBAL_STEP`, or by name `global_step:0`
  when that collection is empty.

  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step variable, or `None` if none was found. `None` is also
    returned (after logging an error) when the `GLOBAL_STEP` collection holds
    more than one item.

  Raises:
    TypeError: If the global step tensor has a non-integer type, or if it is
      not a `Variable` or `Tensor` (see `assert_global_step`).
  """
  graph = graph or ops.get_default_graph()
  global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(global_step_tensors) > 1:
    logging.error('Multiple tensors in global_step collection.')
    return None

  if global_step_tensors:
    global_step_tensor = global_step_tensors[0]
  else:
    # Nothing registered in the collection: fall back to a lookup by name.
    try:
      global_step_tensor = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None

  assert_global_step(global_step_tensor)
  return global_step_tensor


def _create_global_step_variable():
  """Creates the global step variable in the current context.

  Shared by the eager and graph branches of `create_global_step` so both build
  an identically configured variable: a non-trainable int64 scalar named
  `ops.GraphKeys.GLOBAL_STEP`, zero-initialized, aggregated with
  `ONLY_FIRST_REPLICA`, and added to the `GLOBAL_VARIABLES` and `GLOBAL_STEP`
  collections.

  Returns:
    The variable returned by `variable_scope.get_variable`.
  """
  return variable_scope.get_variable(
      ops.GraphKeys.GLOBAL_STEP,
      shape=[],
      dtype=dtypes.int64,
      initializer=init_ops.zeros_initializer(),
      trainable=False,
      aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
      collections=[ops.GraphKeys.GLOBAL_VARIABLES,
                   ops.GraphKeys.GLOBAL_STEP])


@tf_export(v1=['train.create_global_step'])
def create_global_step(graph=None):
  """Create global step tensor in graph.

  Args:
    graph: The graph in which to create the global step tensor. If missing,
      use default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: if global step tensor is already defined.
    TypeError: propagated from `get_global_step` if an existing global step
      candidate has an invalid type.
  """
  graph = graph or ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  if context.executing_eagerly():
    # In eager mode the variable is created directly (without entering
    # `graph`), placed on the CPU.
    with ops.device('cpu:0'):
      return _create_global_step_variable()
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return _create_global_step_variable()


@tf_export(v1=['train.get_or_create_global_step'])
def get_or_create_global_step(graph=None):
  """Returns and create (if necessary) the global step tensor.

  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    The global step tensor.
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    global_step_tensor = create_global_step(graph)
  return global_step_tensor


@tf_export(v1=['train.assert_global_step'])
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  Args:
    global_step_tensor: `Tensor` to test.

  Raises:
    TypeError: If `global_step_tensor` is not a `Variable`, `Tensor` or
      resource variable; if its dtype is not an integer type; or if its shape
      is fully defined but not scalar. Shapes that are not fully defined
      (e.g. unknown rank) are not rejected.
  """
  if not (isinstance(global_step_tensor, (variables.Variable, ops.Tensor)) or
          resource_variable_ops.is_resource_variable(global_step_tensor)):
    raise TypeError(
        'Existing "global_step" must be a Variable or Tensor: %s.' %
        global_step_tensor)

  if not global_step_tensor.dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    global_step_tensor.dtype)

  shape = global_step_tensor.get_shape()
  # Only a fully known, non-scalar shape is treated as an error.
  if shape.ndims != 0 and shape.is_fully_defined():
    raise TypeError('Existing "global_step" is not scalar: %s' % shape)


def _get_global_step_read(graph=None):
  """Gets global step read tensor in graph.

  Args:
    graph: The graph in which to create the global step read tensor. If missing,
      use default graph.

  Returns:
    Global step read tensor, or `None` if the `GLOBAL_STEP_READ_KEY`
    collection is empty.

  Raises:
    RuntimeError: if multiple items found in collection GLOBAL_STEP_READ_KEY.
  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensors = graph.get_collection(GLOBAL_STEP_READ_KEY)
  if len(global_step_read_tensors) > 1:
    raise RuntimeError('There are multiple items in collection {}. '
                       'There should be only one.'.format(GLOBAL_STEP_READ_KEY))

  if len(global_step_read_tensors) == 1:
    return global_step_read_tensors[0]
  return None


def _get_or_create_global_step_read(graph=None):
  """Gets or creates global step read tensor in graph.

  The created tensor is cached in the `GLOBAL_STEP_READ_KEY` collection so
  later calls return the same tensor.

  Args:
    graph: The graph in which to create the global step read tensor. If missing,
      use default graph.

  Returns:
    Global step read tensor if there is global_step_tensor else return None.
  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensor = _get_global_step_read(graph)
  if global_step_read_tensor is not None:
    return global_step_read_tensor
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    return None
  # add 'zero' so that it will create a copy of variable as Tensor.
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      # using initialized_value to ensure that global_step is initialized before
      # this run. This is needed for example Estimator makes all model_fn build
      # under global_step_read_tensor dependency.
      if isinstance(global_step_tensor, variables.Variable):
        global_step_value = global_step_tensor.initialized_value()
      else:
        global_step_value = global_step_tensor
      global_step_read_tensor = global_step_value + 0
      ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor)
  return _get_global_step_read(graph)


def _increment_global_step(increment, graph=None):
  """Builds an op that adds `increment` to the global step.

  The `assign_add` is created under a control dependency on the cached global
  step read tensor (created if needed), so that read runs before the update.

  Args:
    increment: Amount added to the global step via `state_ops.assign_add`.
    graph: The graph containing the global step. If missing, use default graph.

  Returns:
    The result of `state_ops.assign_add(global_step_tensor, increment)`.

  Raises:
    ValueError: If no global step tensor exists in `graph`.
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    raise ValueError(
        'Global step tensor should be created by '
        'tf.train.get_or_create_global_step before calling increment.')
  global_step_read_tensor = _get_or_create_global_step_read(graph)
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      with ops.control_dependencies([global_step_read_tensor]):
        return state_ops.assign_add(global_step_tensor, increment)