# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Maintain moving averages of parameters."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export


# TODO(touts): switch to variables.Variable.
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Compute the moving average of a variable.

  The moving average of 'variable' updated with 'value' is:

    variable * decay + value * (1 - decay)

  The returned Operation sets 'variable' to the newly computed moving average,
  by performing this subtraction:

    variable -= (1 - decay) * (variable - value)

  Since variables that are initialized to a `0` value will be biased toward
  `0`, `zero_debias` optionally enables scaling by the mathematically correct
  debiasing factor of

    1 - decay ** num_updates

  See Section 3 of (Kingma et al., 2015) for more details.

  The names of the debias shadow variables, by default, include both the scope
  they were created in and the scope of the variables they debias. They are
  also given a uniquifying suffix.

  E.g.:

  ```
  with tf.compat.v1.variable_scope('scope1'):
    with tf.compat.v1.variable_scope('scope2'):
      var = tf.compat.v1.get_variable('foo')
      update_1 = tf.assign_moving_average(var, 0.0, 1.0)
      update_2 = tf.assign_moving_average(var, 0.0, 0.9)

  # var.name: 'scope1/scope2/foo'
  # shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
  #                   'scope1/scope2/scope1/scope2/foo/biased_1'
  ```

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float Tensor or float value. The moving average decay.
    zero_debias: A python bool. If true, assume the variable is 0-initialized
      and unbias it, as in (Kingma et al., 2015). See docstring in
      `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A tensor which, when evaluated, will compute and return the new moving
    average.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
  """
  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    decay = ops.convert_to_tensor(1.0 - decay, name="decay")
    if decay.dtype != variable.dtype.base_dtype:
      decay = math_ops.cast(decay, variable.dtype.base_dtype)

    def update_fn(v, value):
      return state_ops.assign_sub(v, (v - value) * decay, name=scope)

    def update(strategy, v, value):
      if zero_debias:
        return _zero_debias(strategy, v, value, decay)
      else:
        return _update(strategy, v, update_fn, args=(value,))

    replica_context = distribution_strategy_context.get_replica_context()
    if replica_context:
      # In a replica context, we update variable using the mean of value across
      # replicas.
      def merge_fn(strategy, v, value):
        value = strategy.extended.reduce_to(ds_reduce_util.ReduceOp.MEAN,
                                            value, v)
        return update(strategy, v, value)

      return replica_context.merge_call(merge_fn, args=(variable, value))
    else:
      strategy = distribution_strategy_context.get_cross_replica_context()
      return update(strategy, variable, value)
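

# Illustrative sketch (not part of the TensorFlow API): how
# `assign_moving_average` evolves a variable numerically. The helper name
# `_example_assign_moving_average` is hypothetical, exists only for
# documentation, and is never called by this module.
def _example_assign_moving_average():
  """Sketch: fold the constant 1.0 into a 0-initialized moving average."""
  var = variables.Variable(0.0, trainable=False)
  # Each evaluation moves `var` a fraction (1 - decay) toward `value`:
  #   var <- var - (1 - 0.9) * (var - 1.0)
  update = assign_moving_average(var, 1.0, decay=0.9, zero_debias=False)
  # Under TF1-style graph execution, run `update` once per step:
  #   with tf.compat.v1.Session() as sess:
  #     sess.run(tf.compat.v1.global_variables_initializer())
  #     for _ in range(3):
  #       sess.run(update)  # var: 0.0 -> 0.1 -> 0.19 -> 0.271
  return update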


def weighted_moving_average(value,
                            decay,
                            weight,
                            truediv=True,
                            collections=None,
                            name=None):
  """Compute the weighted moving average of `value`.

  Conceptually, the weighted moving average is:
    `moving_average(value * weight) / moving_average(weight)`,
  where a moving average updates by the rule
    `new_value = decay * old_value + (1 - decay) * update`
  Internally, this Op keeps moving average variables of both `value * weight`
  and `weight`.

  Args:
    value: A numeric `Tensor`.
    decay: A float `Tensor` or float value. The moving average decay.
    weight: `Tensor` that keeps the current value of a weight. Its shape must
      be broadcastable with the shape of `value`.
    truediv: Boolean. If `True`, dividing by `moving_average(weight)` is
      floating point division. If `False`, use the division implied by the
      dtypes.
    collections: List of graph collection keys to add the internal variables
      `value * weight` and `weight` to. Defaults to
      `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation. Defaults to
      "WeightedMovingAvg".

  Returns:
    An Operation that updates and returns the weighted moving average.
  """
  # Unlike assign_moving_average, the weighted moving average doesn't modify
  # user-visible variables. It is the ratio of two internal variables, which
  # are moving averages of the updates. Thus, the signature of this function
  # is quite different from assign_moving_average.
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  with variable_scope.variable_scope(name, "WeightedMovingAvg",
                                     [value, weight, decay]) as scope:
    value_x_weight_var = variable_scope.get_variable(
        "value_x_weight",
        shape=value.get_shape(),
        dtype=value.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    weight_var = variable_scope.get_variable(
        "weight",
        shape=weight.get_shape(),
        dtype=weight.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    numerator = assign_moving_average(
        value_x_weight_var, value * weight, decay, zero_debias=False)
    denominator = assign_moving_average(
        weight_var, weight, decay, zero_debias=False)

    if truediv:
      return math_ops.truediv(numerator, denominator, name=scope.name)
    else:
      return math_ops.divide(numerator, denominator, name=scope.name)
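

# Illustrative sketch (not part of the TensorFlow API): a minimal use of
# `weighted_moving_average`, assuming TF1-style graph construction. The
# helper name `_example_weighted_moving_average` is hypothetical.
def _example_weighted_moving_average():
  """Sketch: a decayed, weight-corrected mean of per-batch values."""
  value = ops.convert_to_tensor(3.0)   # e.g. a per-batch metric
  weight = ops.convert_to_tensor(2.0)  # e.g. the number of samples in a batch
  # Internally this keeps EMAs of `value * weight` and of `weight` and
  # returns their ratio, so batches with a larger `weight` contribute more.
  return weighted_moving_average(value, decay=0.99, weight=weight)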


def _update(strategy, var, update_fn, args):
  """Applies updates depending on the context."""
  assert distribution_strategy_context.in_cross_replica_context(), (
      "_update can only be called in cross-replica context")
  if distribute_lib.get_update_replica_id() is not None:
    # Call update_fn on var to delegate the implementation. We expect `var`
    # will do the right thing in update context, e.g., if `var` is a
    # MirroredVariable, it should pick its component variable based on
    # `update_replica_id` and only update that.
    return update_fn(var, *args)
  else:
    return strategy.extended.update(var, update_fn, args)


def _zero_debias(strategy, unbiased_var, value, decay):
  """Compute the delta required for a debiased Variable.

  All exponential moving averages initialized with Tensors are initialized
  to 0, and therefore are biased toward 0. Variables initialized to 0 and
  used as EMAs are similarly biased. This function creates the debias update
  amount according to a scale factor, as in (Kingma et al., 2015).

  To demonstrate the bias that results from 0-initialization, take an EMA
  that was initialized to `0` with decay `b`. After `t` timesteps of seeing
  the constant `c`, the variable has the following value:

  ```
  EMA = 0*b^(t) + c*(1 - b)*b^(t-1) + c*(1 - b)*b^(t-2) + ...
      = c*(1 - b^t)
  ```

  To recover the true value `c`, we would divide by the scale factor
  `1 - b^t`.

  In order to perform debiasing, we use two shadow variables. One keeps track
  of the biased estimate, and the other keeps track of the number of updates
  that have occurred.

  Args:
    strategy: `Strategy` used to create and update variables.
    unbiased_var: A Variable representing the current value of the unbiased
      EMA.
    value: A Tensor representing the most recent value.
    decay: A Tensor representing `1-decay` for the EMA.

  Returns:
    The amount that the unbiased variable should be updated. Computing this
    tensor will also update the shadow variables appropriately.

  References:
    Adam - A Method for Stochastic Optimization:
      [Kingma et al., 2015](https://arxiv.org/abs/1412.6980)
      ([pdf](https://arxiv.org/pdf/1412.6980.pdf))
  """
  with variable_scope.variable_scope(
      unbiased_var.name[:-len(":0")], values=[unbiased_var, value, decay]):
    with ops.init_scope():
      biased_initializer = init_ops.zeros_initializer()
      local_step_initializer = init_ops.zeros_initializer()

    def _maybe_get_unique(name):
      """Get name for a unique variable, if not `reuse=True`."""
      if variable_scope.get_variable_scope().reuse:
        return name
      vs_vars = [
          x.op.name
          for x in variable_scope.get_variable_scope().global_variables()
      ]
      full_name = variable_scope.get_variable_scope().name + "/" + name
      if full_name not in vs_vars:
        return name
      idx = 1
      while full_name + ("_%d" % idx) in vs_vars:
        idx += 1
      return name + ("_%d" % idx)

    with strategy.extended.colocate_vars_with(unbiased_var):
      biased_var = variable_scope.get_variable(
          _maybe_get_unique("biased"),
          initializer=biased_initializer,
          shape=unbiased_var.get_shape(),
          dtype=unbiased_var.dtype,
          trainable=False)
      local_step = variable_scope.get_variable(
          _maybe_get_unique("local_step"),
          shape=[],
          dtype=unbiased_var.dtype,
          initializer=local_step_initializer,
          trainable=False)

    def update_fn(v, value, biased_var, local_step):
      update_biased = state_ops.assign_sub(biased_var,
                                           (biased_var - value) * decay)
      update_local_step = local_step.assign_add(1)

      # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
      bias_factor = 1 - math_ops.pow(1.0 - decay, update_local_step)
      return state_ops.assign(
          v, update_biased / bias_factor, name=ops.get_name_scope() + "/")

    return _update(
        strategy, unbiased_var, update_fn,
        args=(value, biased_var, local_step))
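

# Illustrative sketch (not part of the TensorFlow API): plain-Python
# arithmetic showing why dividing a 0-initialized EMA by `1 - decay**t`
# recovers the true value, which is what `_zero_debias` does with its shadow
# variables. The helper name `_example_zero_debias_arithmetic` is
# hypothetical.
def _example_zero_debias_arithmetic():
  """Sketch: the biased EMA of a constant equals `c * (1 - decay**t)`."""
  decay, c = 0.9, 5.0
  ema = 0.0  # 0-initialized, hence biased toward 0.
  for t in range(1, 4):
    ema = decay * ema + (1 - decay) * c
    # Dividing by the debiasing factor recovers `c` at every step.
    assert abs(ema / (1 - decay**t) - c) < 1e-9
  return ema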


@tf_export("train.ExponentialMovingAverage")
class ExponentialMovingAverage(object):
  """Maintains moving averages of variables by employing an exponential decay.

  When training a model, it is often beneficial to maintain moving averages of
  the trained parameters. Evaluations that use averaged parameters sometimes
  produce significantly better results than the final trained values.

  The `apply()` method adds shadow copies of trained variables and adds ops
  that maintain a moving average of the trained variables in their shadow
  copies. It is used when building the training model. The ops that maintain
  moving averages are typically run after each training step.
  The `average()` and `average_name()` methods give access to the shadow
  variables and their names. They are useful when building an evaluation
  model, or when restoring a model from a checkpoint file. They help use the
  moving averages in place of the last trained values for evaluations.

  The moving averages are computed using exponential decay. You specify the
  decay value when creating the `ExponentialMovingAverage` object. The shadow
  variables are initialized with the same initial values as the trained
  variables. When you run the ops to maintain the moving averages, each
  shadow variable is updated with the formula:

    `shadow_variable -= (1 - decay) * (shadow_variable - variable)`

  This is mathematically equivalent to the classic formula below, but the use
  of an `assign_sub` op (the `"-="` in the formula) allows concurrent lockless
  updates to the variables:

    `shadow_variable = decay * shadow_variable + (1 - decay) * variable`

  Reasonable values for `decay` are close to 1.0, typically in the
  multiple-nines range: 0.999, 0.9999, etc.

  Example usage when creating a training model:

  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...
  ...
  # Create an op that applies the optimizer. This is what we usually
  # would use as a training op.
  opt_op = opt.minimize(my_loss, [var0, var1])

  # Create an ExponentialMovingAverage object
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)

  with tf.control_dependencies([opt_op]):
    # Create the shadow variables, and add ops to maintain moving averages
    # of var0 and var1. This also creates an op that will update the moving
    # averages after each training step. This is what we will use in place
    # of the usual training op.
    training_op = ema.apply([var0, var1])

  ...train the model by running training_op...
  ```

  There are two ways to use the moving averages for evaluations:

  * Build a model that uses the shadow variables instead of the variables.
    For this, use the `average()` method which returns the shadow variable
    for a given variable.
  * Build a model normally but load the checkpoint files to evaluate by using
    the shadow variable names. For this use the `average_name()` method. See
    `tf.compat.v1.train.Saver` for more information on restoring saved
    variables.

  Example of restoring the shadow variable values:

  ```python
  # Create a Saver that loads variables from their saved shadow values.
  shadow_var0_name = ema.average_name(var0)
  shadow_var1_name = ema.average_name(var1)
  saver = tf.compat.v1.train.Saver({shadow_var0_name: var0,
                                    shadow_var1_name: var1})
  saver.restore(...checkpoint filename...)
  # var0 and var1 now hold the moving average values
  ```
  """

  def __init__(self,
               decay,
               num_updates=None,
               zero_debias=False,
               name="ExponentialMovingAverage"):
    """Creates a new ExponentialMovingAverage object.

    The `apply()` method has to be called to create shadow variables and add
    ops to maintain moving averages.

    The optional `num_updates` parameter allows one to tweak the decay rate
    dynamically. It is typical to pass the count of training steps, usually
    kept in a variable that is incremented at each step, in which case the
    decay rate is lower at the start of training. This makes moving averages
    move faster. If passed, the actual decay rate used is:

      `min(decay, (1 + num_updates) / (10 + num_updates))`

    Args:
      decay: Float. The decay to use.
      num_updates: Optional count of the number of updates applied to the
        variables.
      zero_debias: If `True`, zero-debias moving averages that are initialized
        with tensors.
      name: String. Optional prefix name to use for the name of ops added in
        `apply()`.
391 """ 392 self._decay = decay 393 self._num_updates = num_updates 394 self._zero_debias = zero_debias 395 self._name = name 396 self._averages = {} 397 398 @property 399 def name(self): 400 """The name of this ExponentialMovingAverage object.""" 401 return self._name 402 403 def apply(self, var_list=None): 404 """Maintains moving averages of variables. 405 406 `var_list` must be a list of `Variable` or `Tensor` objects. This method 407 creates shadow variables for all elements of `var_list`. Shadow variables 408 for `Variable` objects are initialized to the variable's initial value. 409 They will be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection. 410 For `Tensor` objects, the shadow variables are initialized to 0 and zero 411 debiased (see docstring in `assign_moving_average` for more details). 412 413 shadow variables are created with `trainable=False` and added to the 414 `GraphKeys.ALL_VARIABLES` collection. They will be returned by calls to 415 `tf.compat.v1.global_variables()`. 416 417 Returns an op that updates all shadow variables from the current value of 418 their associated variables. 419 420 Note that `apply()` can be called multiple times. When eager execution is 421 enabled each call to apply will update the variables once, so this needs to 422 be called in a loop. 423 424 Args: 425 var_list: A list of Variable or Tensor objects. The variables and Tensors 426 must be of types bfloat16, float16, float32, or float64. 427 428 Returns: 429 An Operation that updates the moving averages. 430 431 Raises: 432 TypeError: If the arguments are not an allowed type. 433 """ 434 # TODO(touts): op_scope 435 if var_list is None: 436 var_list = variables.trainable_variables() 437 for v in var_list: 438 if isinstance(v, ops.EagerTensor): 439 raise TypeError( 440 "tf.train.ExponentialMovingAverage does not support non-Variable" 441 " tensors when eager execution is enabled.") 442 zero_debias_true = set() # set of vars to set `zero_debias=True` 443 for var in var_list: 444 if var.dtype.base_dtype not in [ 445 dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64 446 ]: 447 raise TypeError("The variables must be half, float, or double: %s" % 448 var.name) 449 450 if var.ref() not in self._averages: 451 # For variables: to lower communication bandwidth across devices we keep 452 # the moving averages on the same device as the variables. For other 453 # tensors, we rely on the existing device allocation mechanism. 454 with ops.init_scope(): 455 if isinstance(var, variables.Variable): 456 with ops.device(var.device): 457 initialized_value = var.initialized_value() 458 avg = slot_creator.create_slot( 459 var, 460 initialized_value, 461 self.name, 462 colocate_with_primary=True, 463 copy_xla_sharding=True) 464 # NOTE(mrry): We only add `tf.Variable` objects to the 465 # `MOVING_AVERAGE_VARIABLES` collection. 
            ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
          else:
            avg = slot_creator.create_zeros_slot(
                var,
                self.name,
                colocate_with_primary=(var.op.type in [
                    "Variable", "VariableV2", "VarHandleOp"
                ]),
                copy_xla_sharding=True)
            if self._zero_debias:
              zero_debias_true.add(avg.ref())
        self._averages[var.ref()] = avg

    with ops.name_scope(self.name) as scope:
      decay = ops.convert_to_tensor(
          self._decay, dtype=dtypes.float32, name="decay")
      if self._num_updates is not None:
        num_updates = math_ops.cast(
            self._num_updates, dtypes.float32, name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      updates = []
      for var in var_list:
        avg = self._averages[var.ref()]
        zero_debias = avg.ref() in zero_debias_true
        updates.append(assign_moving_average(avg, var, decay, zero_debias))
      return control_flow_ops.group(*updates, name=scope)

  def average(self, var):
    """Returns the `Variable` holding the average of `var`.

    Args:
      var: A `Variable` object.

    Returns:
      A `Variable` object or `None` if the moving average of `var`
      is not maintained.
    """
    return self._averages.get(var.ref(), None)

  def average_name(self, var):
    """Returns the name of the `Variable` holding the average for `var`.

    The typical scenario for `ExponentialMovingAverage` is to compute moving
    averages of variables during training, and restore the variables from the
    computed moving averages during evaluations.

    To restore variables, you have to know the name of the shadow variables.
    That name and the original variable can then be passed to a `Saver()`
    object to restore the variable from the moving average value with:
      `saver = tf.compat.v1.train.Saver({ema.average_name(var): var})`

    `average_name()` can be called whether or not `apply()` has been called.

    Args:
      var: A `Variable` object.

    Returns:
      A string: The name of the variable that will be used or was used
      by the `ExponentialMovingAverage` class to hold the moving average of
      `var`.
    """
    if var.ref() in self._averages:
      return self._averages[var.ref()].op.name
    return ops.get_default_graph().unique_name(
        var.op.name + "/" + self.name, mark_as_used=False)

  def variables_to_restore(self, moving_avg_variables=None):
    """Returns a map of names to `Variables` to restore.

    If a variable has a moving average, use the moving average variable name
    as the restore name; otherwise, use the variable name.

    For example,

    ```python
    variables_to_restore = ema.variables_to_restore()
    saver = tf.compat.v1.train.Saver(variables_to_restore)
    ```

    Below is an example of such a mapping:

    ```
    conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
    conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
    global_step: global_step
    ```

    Args:
      moving_avg_variables: A list of variables that require the use of the
        moving average variable name to be restored. If None, it will default
        to variables.moving_average_variables() +
        variables.trainable_variables().

    Returns:
      A map from restore_names to variables. The restore_name is either the
      original or the moving average version of the variable name, depending
      on whether the variable name is in the `moving_avg_variables`.
563 """ 564 name_map = {} 565 if moving_avg_variables is None: 566 # Include trainable variables and variables which have been explicitly 567 # added to the moving_average_variables collection. 568 moving_avg_variables = variables.trainable_variables() 569 moving_avg_variables += variables.moving_average_variables() 570 # Remove duplicates 571 moving_avg_variables = set(moving_avg_variables) 572 # Collect all the variables with moving average, 573 for v in moving_avg_variables: 574 name_map[self.average_name(v)] = v 575 # Make sure we restore variables without moving averages as well. 576 moving_avg_variable_names = set(v.name for v in moving_avg_variables) 577 for v in list(set(variables.global_variables())): 578 if v.name not in moving_avg_variable_names and v.op.name not in name_map: 579 name_map[v.op.name] = v 580 return name_map 581