• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Maintain moving averages of parameters."""
16from tensorflow.python.distribute import distribute_lib
17from tensorflow.python.distribute import distribution_strategy_context
18from tensorflow.python.distribute import reduce_util as ds_reduce_util
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import control_flow_ops
22from tensorflow.python.ops import init_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops import state_ops
25from tensorflow.python.ops import variable_scope
26from tensorflow.python.ops import variables
27from tensorflow.python.training import slot_creator
28from tensorflow.python.util.tf_export import tf_export
29from tensorflow.tools.docs import doc_controls
30
31
@tf_export("__internal__.train.assign_moving_average", v1=[])
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Update `variable` so that it holds a moving average fed by `value`.

  Conceptually the new content of `variable` is:
    variable * decay + value * (1 - decay)

  The update is carried out as an in-place subtraction:
     variable -= (1 - decay) * (variable - value)

  A variable that starts at `0` yields an average that is biased towards `0`.
  When `zero_debias` is set, the result is rescaled by the debiasing factor
    1 - decay ** num_updates
  as described in Section 3 of (Kingma et al., 2015).

  By default the names of the debias shadow variables contain both the scope
  in which they were created and the scope of the variable being debiased, and
  they receive a uniquifying suffix.

  E.g.:

  ```
    with tf.compat.v1.variable_scope('scope1'):
      with tf.compat.v1.variable_scope('scope2'):
        var = tf.compat.v1.get_variable('foo')
        update_1 = tf.assign_moving_average(var, 0.0, 1.0)
        update_2 = tf.assign_moving_average(var, 0.0, 0.9)

    # var.name: 'scope1/scope2/foo'
    # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
    #                   'scope1/scope2/scope1/scope2/foo/biased_1'
  ```

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float `Tensor` or float value. The moving average decay.
    zero_debias: A python bool. If true, assume the variable is 0-initialized
      and unbias it, as in (Kingma et al., 2015). See docstring in
        `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A tensor which if evaluated will compute and return the new moving average.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
  """
  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    # All helpers below consume the complement `1 - decay`, cast to the
    # variable's base dtype whenever the two dtypes differ.
    one_minus_decay = ops.convert_to_tensor(1.0 - decay, name="decay")
    if one_minus_decay.dtype != variable.dtype.base_dtype:
      one_minus_decay = math_ops.cast(one_minus_decay,
                                      variable.dtype.base_dtype)

    def _subtract_step(v, value):
      # Plain (biased) update: v -= (1 - decay) * (v - value).
      return state_ops.assign_sub(v, (v - value) * one_minus_decay, name=scope)

    def _apply(strategy, v, value):
      # Pick between the debiased and the plain update.
      if not zero_debias:
        return _update(strategy, v, _subtract_step, args=(value,))
      return _zero_debias(strategy, v, value, one_minus_decay)

    def _merge(strategy, v, value):
      # Replicas contribute the mean of their values to a single update.
      mean_value = strategy.extended.reduce_to(ds_reduce_util.ReduceOp.MEAN,
                                               value, v)
      return _apply(strategy, v, mean_value)

    replica_context = distribution_strategy_context.get_replica_context()
    if not replica_context:
      # Already in cross-replica context: update directly.
      strategy = distribution_strategy_context.get_cross_replica_context()
      return _apply(strategy, variable, value)

    # In a replica context the update uses the mean of `value` across replicas.
    return replica_context.merge_call(_merge, args=(variable, value))
112
113
def weighted_moving_average(value,
                            decay,
                            weight,
                            truediv=True,
                            collections=None,
                            name=None):
  """Build the weighted moving average of `value`.

  The quantity computed is, conceptually:
    `moving_average(value * weight) / moving_average(weight)`,
  where each moving average follows the update rule
    `new_value = decay * old_value + (1 - decay) * update`
  Two internal variables are maintained for this purpose: one averaging
  `value * weight` and one averaging `weight`.

  Args:
    value: A numeric `Tensor`.
    decay: A float `Tensor` or float value. The moving average decay.
    weight:  `Tensor` that keeps the current value of a weight. Shape should be
      able to multiply `value`.
    truediv:  Boolean, if `True`, dividing by `moving_average(weight)` is
      floating point division.  If `False`, use division implied by dtypes.
    collections:  List of graph collections keys to add the internal variables
      `value * weight` and `weight` to. Defaults to
      `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation. Defaults to
      "WeightedMovingAvg".

  Returns:
    An Operation that updates and returns the weighted moving average.
  """
  # In contrast to assign_moving_average, no user-visible variable is modified
  # here: the result is the ratio of two internal moving-average variables,
  # which is why the signature differs from assign_moving_average.
  target_collections = (
      [ops.GraphKeys.GLOBAL_VARIABLES] if collections is None else collections)

  def _zero_accumulator(var_name, like):
    """Creates a non-trainable, zero-initialized variable shaped like `like`."""
    return variable_scope.get_variable(
        var_name,
        shape=like.get_shape(),
        dtype=like.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=target_collections)

  with variable_scope.variable_scope(name, "WeightedMovingAvg",
                                     [value, weight, decay]) as scope:
    weighted_value_avg = _zero_accumulator("value_x_weight", value)
    weight_avg = _zero_accumulator("weight", weight)

    numerator = assign_moving_average(
        weighted_value_avg, value * weight, decay, zero_debias=False)
    denominator = assign_moving_average(
        weight_avg, weight, decay, zero_debias=False)

    # `truediv` selects floating point division; otherwise the division is the
    # one implied by the operand dtypes.
    quotient_fn = math_ops.truediv if truediv else math_ops.divide
    return quotient_fn(numerator, denominator, name=scope.name)
176
177
def _update(strategy, var, update_fn, args):
  """Runs `update_fn` on `var`, directly or through `strategy`.

  Args:
    strategy: Strategy whose `extended.update` is used when no update is
      already in progress.
    var: The variable to update.
    update_fn: Callable invoked as `update_fn(var, *args)`.
    args: Extra positional arguments for `update_fn`.

  Returns:
    Whatever `update_fn` or `strategy.extended.update` returns.
  """
  assert distribution_strategy_context.in_cross_replica_context(), (
      "_update can only be called in cross-replica context")
  inside_update = distribute_lib.get_update_replica_id() is not None
  if not inside_update:
    return strategy.extended.update(var, update_fn, args)
  # An update replica id is set, so call update_fn on var directly and let
  # `var` do the right thing: e.g. a MirroredVariable is expected to pick its
  # component variable based on `update_replica_id` and only update that one.
  return update_fn(var, *args)
190
191
def _zero_debias(strategy, unbiased_var, value, decay):
  """Update `unbiased_var` with a zero-debiased exponential moving average.

  All exponential moving averages initialized with Tensors are initialized to 0,
  and therefore are biased to 0. Variables initialized to 0 and used as EMAs are
  similarly biased. This function corrects for that bias with a scale factor,
  as in (Kingma et al., 2015).

  To demonstrate the bias that results from 0-initialization, take an EMA that
  was initialized to `0` with decay `b`. After `t` timesteps of seeing the
  constant `c`, the variable will have the following value:

  ```
    EMA = 0*b^(t) + c*(1 - b)*b^(t-1) + c*(1 - b)*b^(t-2) + ...
        = c*(1 - b^t)
  ```

  To have the true value `c`, we would divide by the scale factor `1 - b^t`.

  In order to perform debiasing, we use two shadow variables. One keeps track of
  the biased estimate, and the other keeps track of the number of updates that
  have occurred. Both are created inside a variable scope named after
  `unbiased_var` (its name without the trailing `":0"`).

  Args:
    strategy: `Strategy` used to create and update variables.
    unbiased_var: A Variable representing the current value of the unbiased EMA.
    value: A Tensor representing the most recent value.
    decay: A Tensor representing `1-decay` for the EMA. Note that this is the
      complement of the decay rate, not the decay rate itself.

  Returns:
    The result of `_update(...)`, whose update function assigns the debiased
    average (`biased estimate / bias factor`) to `unbiased_var`. Computing it
    also updates the two shadow variables.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))

  """
  # The scope name is the variable name with its ":0" suffix stripped.
  with variable_scope.variable_scope(
      unbiased_var.name[:-len(":0")], values=[unbiased_var, value, decay]):
    # The zero initializers are constructed under init_scope.
    with ops.init_scope():
      biased_initializer = init_ops.zeros_initializer()
      local_step_initializer = init_ops.zeros_initializer()

    def _maybe_get_unique(name):
      """Return `name`, suffixed with `_N` if it is already taken.

      When the current variable scope has `reuse` set, `name` is returned
      unchanged. Otherwise the op names of the scope's global variables are
      inspected and, if `<scope>/<name>` already exists, the smallest `N >= 1`
      for which `<scope>/<name>_N` is free is appended.
      """
      if variable_scope.get_variable_scope().reuse:
        return name
      vs_vars = [
          x.op.name
          for x in variable_scope.get_variable_scope().global_variables()
      ]
      full_name = variable_scope.get_variable_scope().name + "/" + name
      if full_name not in vs_vars:
        return name
      idx = 1
      while full_name + ("_%d" % idx) in vs_vars:
        idx += 1
      return name + ("_%d" % idx)

    # Both shadow variables are colocated with the variable they debias.
    with strategy.extended.colocate_vars_with(unbiased_var):
      # Biased (uncorrected) EMA; same shape and dtype as `unbiased_var`.
      biased_var = variable_scope.get_variable(
          _maybe_get_unique("biased"),
          initializer=biased_initializer,
          shape=unbiased_var.get_shape(),
          dtype=unbiased_var.dtype,
          trainable=False)
      # Scalar update counter, stored in the same dtype as `unbiased_var`.
      local_step = variable_scope.get_variable(
          _maybe_get_unique("local_step"),
          shape=[],
          dtype=unbiased_var.dtype,
          initializer=local_step_initializer,
          trainable=False)

  def update_fn(v, value, biased_var, local_step):
    # Standard EMA step on the biased estimate:
    #   biased -= (1 - decay_rate) * (biased - value)
    # (`decay` here already holds `1 - decay_rate`).
    update_biased = state_ops.assign_sub(biased_var,
                                         (biased_var - value) * decay)
    # Count this update; the incremented value is used as the exponent below.
    update_local_step = local_step.assign_add(1)

    # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
    bias_factor = 1 - math_ops.pow(1.0 - decay, update_local_step)
    # Store the debiased estimate in the user-visible variable.
    return state_ops.assign(
        v, update_biased / bias_factor, name=ops.get_name_scope() + "/")

  return _update(
      strategy, unbiased_var, update_fn, args=(value, biased_var, local_step))
279
280
@tf_export("train.ExponentialMovingAverage")
class ExponentialMovingAverage:
  """Maintains moving averages of variables by employing an exponential decay.

  When training a model, it is often beneficial to maintain moving averages of
  the trained parameters.  Evaluations that use averaged parameters sometimes
  produce significantly better results than the final trained values.

  The `apply()` method adds shadow copies of trained variables the first time
  it is called, and maintains a moving average of the trained variables in
  their shadow copies at every additional invocation.
  It should generally be called immediately after creating the model weights,
  and then after each training step.

  The `average()` method gives access to the shadow variables.
  It allows you to use the moving averages in place of the last trained values
  for evaluations, by loading the moving averages into your model via
  `var.assign(ema.average(var))`.
  Additionally, although `ExponentialMovingAverage`
  objects are not directly trackable by checkpoints,
  `average()` returns the moving average variables for your model weights,
  which you can then checkpoint. (There is an example
  of this near the bottom of this docstring).
  So, `average()` is useful when
  building an evaluation model, or when restoring a model from a checkpoint
  file.

  The moving averages are computed using exponential decay.  You specify the
  decay value (as a scalar float value, `Tensor`, or `Variable`) when creating
  the `ExponentialMovingAverage` object.  The shadow variables are initialized
  with the same initial values as the trained variables.  When you run `apply`
  to update the moving averages, each shadow variable is updated with the
  formula:

    `shadow_variable -= (1 - decay) * (shadow_variable - variable)`

  This is mathematically equivalent to the classic formula below, but the use
  of an `assign_sub` op (the `"-="` in the formula) allows concurrent lockless
  updates to the variables:

    `shadow_variable = decay * shadow_variable + (1 - decay) * variable`

  Reasonable values for `decay` are close to 1.0, typically in the
  multiple-nines range: 0.999, 0.9999, etc.

  To have fine-grained control over the value of the decay parameter during
  training, pass a scalar `tf.Variable` as the `decay` value to the constructor,
  and update the variable as needed.

  Example usage when creating a training model:

  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...

  # Create an ExponentialMovingAverage object
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)

  # The first `apply` creates the shadow variables that hold the moving averages
  ema.apply([var0, var1])

  # grab the moving averages for checkpointing purposes or to be able to
  # load the moving averages into the model weights
  averages = [ema.average(var0), ema.average(var1)]

  ...
  def train_step(...):
  ...
    # Apply the optimizer.
    opt.minimize(my_loss, [var0, var1])

    # Update the moving averages
    # of var0 and var1 with additional calls to `apply`
    ema.apply([var0, var1])

  ...train the model by running train_step multiple times...
  ```

  There are several ways to use the moving averages for evaluations:

  1. Assign the values of the shadow variables to your model variables with
     `Variable.assign(...)` before evaluating your
     model. You can use the `average()`
     method to get the shadow variable for a given variable. To continue
     training after using this approach, make sure to record the unaveraged
     weights and restore them before continuing to train. You can see the
     tensorflow-addons' MovingAverage optimizer's `swap_weights` method for
     one example of how to swap variables efficiently in distributed settings:
     https://github.com/tensorflow/addons/blob/v0.13.0/tensorflow_addons/optimizers/moving_average.py#L151
  2. Make sure to checkpoint out your moving average variables in your
     `tf.train.Checkpoint`. At evaluation time, create your shadow variables and
     use `tf.train.Checkpoint` to restore the moving averages into the shadow
     variables. Then, load the moving averages into the actual model weights via
     `var.assign(moving_avg)`.
  3. Checkpoint out your moving average variables in your `tf.train.Checkpoint`.
     For evaluation, restore your model weights directly from the moving
     averages instead of from the non-averaged weights.
     Caution: If you choose this approach, include only the object-graph paths
     to the averaged path in your checkpoint restore.
     If you point both the unaveraged and averaged paths in a checkpoint
     restore to the same variables, it is hard to reason about whether your
     model will restore the averaged or non-averaged variables.

  Example of saving out then restoring the shadow variable values:

  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...

  # Create an ExponentialMovingAverage object, create the shadow variables,
  # and grab the moving averages for checkpointing purposes.
  # (The ExponentialMovingAverage object itself is not checkpointable)
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)
  ema.apply([var0, var1])
  avg_var0 = ema.average(var0)
  avg_var1 = ema.average(var1)

  # Create a Checkpoint that will manage the model weights and the averages,
  checkpoint = tf.train.Checkpoint(model_weights=[var0, var1],
                                   averaged_weights=[avg_var0, avg_var1])
  ... # Do training

  # Save out the checkpoint including the model weights and the moving averages
  checkpoint.save(...)
  ```

  Restore option: restore all averaged & non-averaged weights, then load
  moving averages into the model via `var.assign()`
  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...

  # Create an ExponentialMovingAverage object, create the shadow variables,
  # and grab the moving averages for checkpoint restore purposes.
  # (The ExponentialMovingAverage object itself is not checkpointable)
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)
  ema.apply([var0, var1])
  avg_var0 = ema.average(var0)
  avg_var1 = ema.average(var1)

  # Create a Checkpoint that will manage the model weights and the averages,
  checkpoint = tf.train.Checkpoint(model_weights=[var0, var1],
                                   averaged_weights=[avg_var0, avg_var1])
  checkpoint.restore(...)
  var0.assign(avg_var0)
  var1.assign(avg_var1)
  # var0 and var1 now hold the moving average values
  ```

  Restore option: Directly restore the moving averages into the model weights.
  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...

  # Create a Checkpoint that will manage two objects with trackable state,
  checkpoint = tf.train.Checkpoint(averaged_weights=[var0, var1])
  checkpoint.restore(...)
  # var0 and var1 now hold the moving average values
  ```
  """

  def __init__(self,
               decay,
               num_updates=None,
               zero_debias=False,
               name="ExponentialMovingAverage"):
    """Creates a new ExponentialMovingAverage object.

    The `apply()` method has to be called to create shadow variables.
    Follow-on calls to the `apply()` method will update the moving averages
    in the shadow variables.
    (In TF 1.x graphs `apply()` will return an update op to update
    the moving averages which must be explicitly run).

    The optional `num_updates` parameter allows one to tweak the decay rate
    dynamically. It is typical to pass the count of training steps, usually
    kept in a variable that is incremented at each step, in which case the
    decay rate is lower at the start of training.  This makes moving averages
    move faster.  If passed, the actual decay rate used is:

      `min(decay, (1 + num_updates) / (10 + num_updates))`

    Args:
      decay: A scalar float value, `Tensor`, or `Variable`. The decay parameter.
      num_updates: Optional count of number of updates applied to variables.
      zero_debias: If `True`, zero debias moving-averages that are initialized
        with tensors. (Note: moving averages may not be initialized with
        non-variable tensors when eager execution is enabled). Averages of
        `Variable` objects are never zero-debiased by `apply()`.
      name: String. Optional prefix name to use for the name of ops added in
        `apply()`.
    """
    self._decay = decay
    self._num_updates = num_updates
    self._zero_debias = zero_debias
    self._name = name
    # Maps `var.ref()` of each averaged variable/tensor to its shadow variable.
    self._averages = {}

  @property
  def name(self):
    """The name of this ExponentialMovingAverage object."""
    return self._name

  def apply(self, var_list=None):
    """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` objects.  This method
    creates shadow variables (holding the moving averages)
    for all elements of `var_list`, and
    updates the moving averages using the current `var_list` values. Shadow
    variables for `Variable` objects are initialized to the variable's initial
    value.

    Shadow variables are created with `trainable=False`. To access them you
    can use the EMA object's `average` method. Note that `EMA` objects are
    not trackable by checkpoints, so if you want to checkpoint or restore the
    moving variables you will need to manually grab the shadow
    variables via `average()` and assign them as `tf.Module` properties or
    directly pass them to your `tf.train.Checkpoint`.

    Note that `apply()` can be called multiple times. When eager execution is
    enabled each call to apply will update the variables once, so this needs to
    be called in a loop.

    In legacy TF 1.x graphs, this method returns an op that updates all
    shadow variables from the current value of their associated variables. In
    TF 1.x graphs without automatic control dependencies this op needs to be
    manually run.

    Args:
      var_list: A list of Variable objects. The variables
        must be of types bfloat16, float16, float32, or float64.
        (In legacy TF 1.x graphs these may be tensors, but this is unsupported
        when eager execution is enabled.) If `None`, defaults to
        `variables.trainable_variables()`.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not an allowed type: an element has a
        dtype other than bfloat16, float16, float32 or float64, or a
        non-Variable tensor is passed while executing eagerly outside
        functions.
    """
    # TODO(touts): op_scope
    if var_list is None:
      var_list = variables.trainable_variables()
    # Reject plain tensors up front when running eagerly outside functions.
    for v in var_list:
      if (isinstance(v, ops.Tensor)
          and ops.executing_eagerly_outside_functions()):
        raise TypeError(
            "tf.train.ExponentialMovingAverage does not support non-Variable"
            " tensors when eager execution is enabled.")
    # NOTE(review): this set is rebuilt on every call and is only filled for
    # tensors whose average is created during this call. A later `apply()` on
    # the same tensor would therefore run without zero-debiasing. Confirm that
    # this is intended.
    zero_debias_true = set()  # set of vars to set `zero_debias=True`
    for var in var_list:
      # NOTE: bfloat16 is accepted as well, although the message below does not
      # mention it.
      if var.dtype.base_dtype not in [
          dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
      ]:
        raise TypeError("The variables must be half, float, or double: %s" %
                        var.name)

      # Create the shadow variable only the first time `var` is seen.
      if var.ref() not in self._averages:
        # For variables: to lower communication bandwidth across devices we keep
        # the moving averages on the same device as the variables. For other
        # tensors, we rely on the existing device allocation mechanism.
        with ops.init_scope():
          if isinstance(var, variables.Variable):
            # Shadow variable starts from the variable's initial value.
            with ops.device(var.device):
              initialized_value = var.initialized_value()
            avg = slot_creator.create_slot(
                var,
                initialized_value,
                self.name,
                colocate_with_primary=True,
                copy_xla_sharding=True)
            # NOTE(mrry): We only add `tf.Variable` objects to the
            # `MOVING_AVERAGE_VARIABLES` collection.
            ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
          else:
            # Non-Variable inputs get a zero-initialized slot; it is colocated
            # with `var` only when `var` is produced by a variable op.
            avg = slot_creator.create_zeros_slot(
                var,
                self.name,
                colocate_with_primary=(var.op.type in [
                    "Variable", "VariableV2", "VarHandleOp"
                ]),
                copy_xla_sharding=True)
            # Zero-initialized averages are the ones eligible for debiasing.
            if self._zero_debias:
              zero_debias_true.add(avg.ref())
        self._averages[var.ref()] = avg

    with ops.name_scope(self.name) as scope:
      decay = ops.convert_to_tensor(
          self._decay, dtype=dtypes.float32, name="decay")
      if self._num_updates is not None:
        # Effective decay: min(decay, (1 + num_updates) / (10 + num_updates)).
        num_updates = math_ops.cast(
            self._num_updates, dtypes.float32, name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      # One moving-average update per entry of `var_list`, grouped into a
      # single op.
      updates = []
      for var in var_list:
        avg = self._averages[var.ref()]
        zero_debias = avg.ref() in zero_debias_true
        updates.append(assign_moving_average(avg, var, decay, zero_debias))
      return control_flow_ops.group(*updates, name=scope)

  def average(self, var):
    """Returns the `Variable` holding the average of `var`.

    Args:
      var: A `Variable` object.

    Returns:
      A `Variable` object or `None` if the moving average of `var`
      is not maintained.
    """
    return self._averages.get(var.ref(), None)

  @doc_controls.do_not_generate_docs
  def average_name(self, var):
    """[Meant for TF1] Returns name of `Variable` holding the average for `var`.

    (Designed to work with legacy `tf.compat.v1.train.Saver`, it is sensitive to
    specific variable names and not recommended for TF2)

    The typical scenario for `ExponentialMovingAverage` is to compute moving
    averages of variables during training, and restore the variables from the
    computed moving averages during evaluations.

    To restore variables, you have to know the name of the shadow variables.
    That name and the original variable can then be passed to a `Saver()` object
    to restore the variable from the moving average value with:
      `saver = tf.compat.v1.train.Saver({ema.average_name(var): var})`

    `average_name()` can be called whether or not `apply()` has been called.

    Args:
      var: A `Variable` object.

    Returns:
      A string: The name of the variable that will be used or was used
      by the `ExponentialMovingAverage class` to hold the moving average of
      `var`.
    """
    # Known average: report the shadow variable's name without its ":0" suffix.
    if var.ref() in self._averages:
      return self._averages[var.ref()].name[:-len(":0")]
    # Otherwise predict the name from `<var name>/<ema name>` without marking
    # it as used in the default graph.
    return ops.get_default_graph().unique_name(
        var.name[:-len(":0")] + "/" + self.name, mark_as_used=False)

  @doc_controls.do_not_generate_docs
  def variables_to_restore(self, moving_avg_variables=None):
    """[Designed for TF 1.x] Returns a map of names to `Variables` to restore.

    (Designed to work with legacy `tf.compat.v1.train.Saver`, sensitive to
    specific variable names and not recommended for TF2)

    If a variable has a moving average, use the moving average variable name as
    the restore name; otherwise, use the variable name.

    For example,

    ```python
      variables_to_restore = ema.variables_to_restore()
      saver = tf.compat.v1.train.Saver(variables_to_restore)
    ```

    Below is an example of such mapping:

    ```
      conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
      conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
      global_step: global_step
    ```

    Args:
      moving_avg_variables: a list of variables that must be restored under
        their moving average variable name. If None, it will default to
        variables.moving_average_variables() + variables.trainable_variables()

    Returns:
      A map from restore_names to variables. The restore_name is either the
      original or the moving average version of the variable name, depending
      on whether the variable name is in the `moving_avg_variables`.
    """
    name_map = {}
    if moving_avg_variables is None:
      # Include trainable variables and variables which have been explicitly
      # added to the moving_average_variables collection.
      moving_avg_variables = variables.trainable_variables()
      moving_avg_variables += variables.moving_average_variables()
    # Remove duplicates
    moving_avg_variables = set(v.ref() for v in moving_avg_variables)
    # Collect all the variables with moving average,
    for v in moving_avg_variables:
      name_map[self.average_name(v.deref())] = v.deref()
    # Make sure we restore variables without moving averages as well.
    moving_avg_variable_names = set(
        v.deref().name for v in moving_avg_variables)
    for v in list(set(variables.global_variables())):
      # Skip variables already mapped above and never overwrite an existing
      # restore name.
      if v.name not in moving_avg_variable_names and v.op.name not in name_map:
        name_map[v.op.name] = v
    return name_map
686