# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utility functions for training."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.eager import context
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import graph_io
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import init_ops
25from tensorflow.python.ops import resource_variable_ops
26from tensorflow.python.ops import state_ops
27from tensorflow.python.ops import variable_scope
28from tensorflow.python.ops import variables
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.util.tf_export import tf_export
31
# Picked a long key value to minimize the chance of collision with user-defined
# collection keys.
GLOBAL_STEP_READ_KEY = 'global_step_read_op_cache'

# TODO(drpng): remove this after legacy uses are resolved.
write_graph = graph_io.write_graph


@tf_export(v1=['train.global_step'])
def global_step(sess, global_step_tensor):
  """Small helper to get the global step.

  ```python
  # Create a variable to hold the global_step.
  global_step_tensor = tf.Variable(10, trainable=False, name='global_step')
  # Create a session.
  sess = tf.compat.v1.Session()
  # Initialize the variable.
  sess.run(global_step_tensor.initializer)
  # Get the variable value.
  print('global_step: %s' %
        tf.compat.v1.train.global_step(sess, global_step_tensor))

  global_step: 10
  ```

  Args:
    sess: A TensorFlow `Session` object.
    global_step_tensor: `Tensor` or the `name` of the operation that contains
      the global step.

  Returns:
    The global step value.
  """
  if context.executing_eagerly():
    return int(global_step_tensor.numpy())
  return int(sess.run(global_step_tensor))


@tf_export(v1=['train.get_global_step'])
def get_global_step(graph=None):
  """Get the global step tensor.

  The global step tensor must be an integer variable. We first try to find it
  in the collection `GLOBAL_STEP`, or by name `global_step:0`.

  Args:
    graph: The graph to find the global step in. If missing, use default graph.

  Returns:
    The global step variable, or `None` if none was found.

  Raises:
    TypeError: If the global step tensor has a non-integer type, or if it is
      not a `Variable`.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions have been removed (`get_or_create_global_step`,
  `create_global_step`, `get_global_step`). You have two options for migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable` (see the sketch at the end
     of this note).

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   y = x * v
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.get_or_create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Using `get_global_step`:

  >>> with g.as_default():
  ...   print(sess.run(tf.compat.v1.train.get_global_step()))
  1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1
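
  Migrating by manually incrementing a `tf.Variable` (a minimal sketch of
  option 2 above; the variable name `my_step` is illustrative):

  >>> my_step = tf.Variable(0, trainable=False, dtype=tf.int64)
  >>> _ = my_step.assign_add(1)
  >>> print("after one step:", my_step.numpy())
  after one step: 1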

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = None
  global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
  if len(global_step_tensors) == 1:
    global_step_tensor = global_step_tensors[0]
  elif not global_step_tensors:
    try:
      global_step_tensor = graph.get_tensor_by_name('global_step:0')
    except KeyError:
      return None
  else:
    logging.error('Multiple tensors in global_step collection.')
    return None

  assert_global_step(global_step_tensor)
  return global_step_tensor


@tf_export(v1=['train.create_global_step'])
def create_global_step(graph=None):
  """Create global step tensor in graph.

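  For example, creating the step twice in the same graph raises an error (a
  minimal sketch):

  >>> with tf.Graph().as_default():
  ...   _ = tf.compat.v1.train.create_global_step()
  ...   try:
  ...     _ = tf.compat.v1.train.create_global_step()
  ...   except ValueError:
  ...     print("already exists")
  already exists
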
  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    Global step tensor.

  Raises:
    ValueError: If the global step tensor is already defined.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions have been removed (`get_or_create_global_step`,
  `create_global_step`, `get_global_step`). You have two options for migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable`.

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   y = x * v
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  if get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  if context.executing_eagerly():
    with ops.device('cpu:0'):
      return variable_scope.get_variable(
          ops.GraphKeys.GLOBAL_STEP,
          shape=[],
          dtype=dtypes.int64,
          initializer=init_ops.zeros_initializer(),
          trainable=False,
          aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
          collections=[
              ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP
          ])
  # Create in the proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return variable_scope.get_variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        aggregation=variables.VariableAggregation.ONLY_FIRST_REPLICA,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])


@tf_export(v1=['train.get_or_create_global_step'])
def get_or_create_global_step(graph=None):
  """Returns and creates (if necessary) the global step tensor.

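  For example (a minimal sketch; the printed name assumes a fresh graph):

  >>> with tf.Graph().as_default():
  ...   step = tf.compat.v1.train.get_or_create_global_step()
  ...   print(step.name)
  global_step:0
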
  Args:
    graph: The graph in which to create the global step tensor. If missing, use
      default graph.

  Returns:
    The global step tensor.

  @compatibility(TF2)
  With the deprecation of global graphs, TF no longer tracks variables in
  collections. In other words, there are no global variables in TF2. Thus, the
  global step functions have been removed (`get_or_create_global_step`,
  `create_global_step`, `get_global_step`). You have two options for migrating:

  1. Create a Keras optimizer, which generates an `iterations` variable. This
     variable is automatically incremented when calling `apply_gradients`.
  2. Manually create and increment a `tf.Variable`.

  Below is an example of migrating away from using a global step to using a
  Keras optimizer:

  Define a dummy model and loss:

  >>> def compute_loss(x):
  ...   v = tf.Variable(3.0)
  ...   y = x * v
  ...   loss = x * 5 - x * v
  ...   return loss, [v]

  Before migrating:

  >>> g = tf.Graph()
  >>> with g.as_default():
  ...   x = tf.compat.v1.placeholder(tf.float32, [])
  ...   loss, var_list = compute_loss(x)
  ...   global_step = tf.compat.v1.train.get_or_create_global_step()
  ...   global_init = tf.compat.v1.global_variables_initializer()
  ...   optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  ...   train_op = optimizer.minimize(loss, global_step, var_list)
  >>> sess = tf.compat.v1.Session(graph=g)
  >>> sess.run(global_init)
  >>> print("before training:", sess.run(global_step))
  before training: 0
  >>> sess.run(train_op, feed_dict={x: 3})
  >>> print("after training:", sess.run(global_step))
  after training: 1

  Migrating to a Keras optimizer:

  >>> optimizer = tf.keras.optimizers.SGD(.01)
  >>> print("before training:", optimizer.iterations.numpy())
  before training: 0
  >>> with tf.GradientTape() as tape:
  ...   loss, var_list = compute_loss(3)
  ...   grads = tape.gradient(loss, var_list)
  ...   optimizer.apply_gradients(zip(grads, var_list))
  >>> print("after training:", optimizer.iterations.numpy())
  after training: 1

  @end_compatibility
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    global_step_tensor = create_global_step(graph)
  return global_step_tensor


@tf_export(v1=['train.assert_global_step'])
def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

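  For example, a non-integer step is rejected (a minimal illustrative sketch):

  >>> try:
  ...   tf.compat.v1.train.assert_global_step(tf.Variable(1.0))
  ... except TypeError:
  ...   print("rejected: not an integer")
  rejected: not an integer
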
  Args:
    global_step_tensor: `Tensor` to test.

  Raises:
    TypeError: If `global_step_tensor` is not a `Variable` or `Tensor`, does
      not have an integer type, or is not scalar.
  """
  if not (isinstance(global_step_tensor, variables.Variable) or
          isinstance(global_step_tensor, ops.Tensor) or
          resource_variable_ops.is_resource_variable(global_step_tensor)):
    raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
                    global_step_tensor)

  if not global_step_tensor.dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    global_step_tensor.dtype)

  if (global_step_tensor.get_shape().ndims != 0 and
      global_step_tensor.get_shape().is_fully_defined()):
    raise TypeError('Existing "global_step" is not scalar: %s' %
                    global_step_tensor.get_shape())


def _get_global_step_read(graph=None):
  """Gets the global step read tensor in the graph.

  Args:
    graph: The graph in which to find the global step read tensor. If missing,
      use default graph.

  Returns:
    Global step read tensor, or `None` if none is found.

  Raises:
    RuntimeError: If multiple items are found in the collection
      GLOBAL_STEP_READ_KEY.
  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensors = graph.get_collection(GLOBAL_STEP_READ_KEY)
  if len(global_step_read_tensors) > 1:
    raise RuntimeError('There are multiple items in collection {}. '
                       'There should be only one.'.format(GLOBAL_STEP_READ_KEY))

  if len(global_step_read_tensors) == 1:
    return global_step_read_tensors[0]
  return None


def _get_or_create_global_step_read(graph=None):
  """Gets or creates the global step read tensor in the graph.

  Args:
    graph: The graph in which to create the global step read tensor. If
      missing, use default graph.

  Returns:
    The global step read tensor if the graph has a global step tensor,
    otherwise `None`.
  """
  graph = graph or ops.get_default_graph()
  global_step_read_tensor = _get_global_step_read(graph)
  if global_step_read_tensor is not None:
    return global_step_read_tensor
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    return None
  # Add zero so that a new Tensor holding a copy of the variable's value is
  # created.
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      # Use initialized_value to ensure that global_step is initialized before
      # this op runs. This is needed because, for example, Estimator builds
      # every model_fn under a dependency on global_step_read_tensor.
      global_step_value = global_step_tensor.initialized_value() if isinstance(
          global_step_tensor, variables.Variable) else global_step_tensor
      global_step_read_tensor = global_step_value + 0
      ops.add_to_collection(GLOBAL_STEP_READ_KEY, global_step_read_tensor)
  return _get_global_step_read(graph)


def _increment_global_step(increment, graph=None):
  """Creates an op that increments the global step tensor by `increment`.

  Args:
    increment: The amount to add to the global step tensor.
    graph: The graph in which the global step tensor lives. If missing, use
      default graph.

  Returns:
    An op that increments the global step tensor and yields its updated value.

  Raises:
    ValueError: If no global step tensor exists in the graph.
  """
  graph = graph or ops.get_default_graph()
  global_step_tensor = get_global_step(graph)
  if global_step_tensor is None:
    raise ValueError(
        'Global step tensor should be created by '
        'tf.train.get_or_create_global_step before calling increment.')
  global_step_read_tensor = _get_or_create_global_step_read(graph)
  with graph.as_default() as g, g.name_scope(None):
    with g.name_scope(global_step_tensor.op.name + '/'):
      with ops.control_dependencies([global_step_read_tensor]):
        return state_ops.assign_add(global_step_tensor, increment)
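
# Example (illustrative only; `_increment_global_step` is private, and the
# snippet below is a sketch rather than a supported usage pattern): in a
# v1-style graph, each `Session.run` of the returned op bumps the step by one.
#
#   with ops.Graph().as_default() as g:
#     step = get_or_create_global_step(g)
#     inc = _increment_global_step(1, g)
#     with tf.compat.v1.Session(graph=g) as sess:
#       sess.run(tf.compat.v1.global_variables_initializer())
#       sess.run(inc)
#       print(sess.run(step))  # -> 1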