# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Momentum for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.MomentumOptimizer"])
class MomentumOptimizer(optimizer.Optimizer):
  """Optimizer that implements the Momentum algorithm.

  Computes (if `use_nesterov = False`):

  ```
  accumulation = momentum * accumulation + gradient
  variable -= learning_rate * accumulation
  ```

  Note that in the dense version of this algorithm, `accumulation` is updated
  and applied regardless of a gradient's value, whereas the sparse version (when
  the gradient is an `IndexedSlices`, typically because of `tf.gather` or an
  embedding) only updates variable slices and corresponding `accumulation` terms
  when that part of the variable was used in the forward pass.
  """

  def __init__(self, learning_rate, momentum,
               use_locking=False, name="Momentum", use_nesterov=False):
    """Construct a new Momentum optimizer.

    Args:
      learning_rate: A `Tensor` or a floating point value. The learning rate.
      momentum: A `Tensor` or a floating point value. The momentum.
      use_locking: If `True` use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "Momentum".
      use_nesterov: If `True` use Nesterov Momentum.
        See (Sutskever et al., 2013).
        This implementation always computes gradients at the value of the
        variable(s) passed to the optimizer. Using Nesterov Momentum makes the
        variable(s) track the values called `theta_t + mu*v_t` in the paper.
        This implementation is an approximation of the original formula, valid
        for high values of momentum. It will compute the "adjusted gradient"
        in NAG by assuming that the new gradient will be estimated by the
        current average gradient plus the product of momentum and the change
        in the average gradient.

    References:
      On the importance of initialization and momentum in deep learning:
        [Sutskever et al., 2013]
        (http://proceedings.mlr.press/v28/sutskever13.html)
        ([pdf](http://proceedings.mlr.press/v28/sutskever13.pdf))

    @compatibility(eager)
    When eager execution is enabled, `learning_rate` and `momentum` can each be
    a callable that takes no arguments and returns the actual value to use. This
    can be useful for changing these values across different invocations of
    optimizer functions.
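
    For example, the following sketch (illustrative only; it assumes
    `import tensorflow as tf`, eager execution, and that this class is reached
    as `tf.compat.v1.train.MomentumOptimizer`) passes both values as
    zero-argument callables:

    ```python
    var = tf.Variable(1.0)
    opt = tf.compat.v1.train.MomentumOptimizer(
        learning_rate=lambda: 0.1, momentum=lambda: 0.9)
    # With eager execution enabled, the loss passed to `minimize` must also be
    # a zero-argument callable.
    opt.minimize(lambda: (var - 3.0) ** 2, var_list=[var])
    ```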
    @end_compatibility
    """
    super(MomentumOptimizer, self).__init__(use_locking, name)
    self._learning_rate = learning_rate
    self._momentum = momentum
    self._use_nesterov = use_nesterov

  def _create_slots(self, var_list):
    for v in var_list:
      self._zeros_slot(v, "momentum", self._name)

  def _prepare(self):
    learning_rate = self._learning_rate
    if callable(learning_rate):
      learning_rate = learning_rate()
    self._learning_rate_tensor = ops.convert_to_tensor(learning_rate,
                                                       name="learning_rate")
    momentum = self._momentum
    if callable(momentum):
      momentum = momentum()
    self._momentum_tensor = ops.convert_to_tensor(momentum, name="momentum")

  def _apply_dense(self, grad, var):
    mom = self.get_slot(var, "momentum")
    return training_ops.apply_momentum(
        var, mom,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        grad,
        math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov).op

  def _resource_apply_dense(self, grad, var):
    mom = self.get_slot(var, "momentum")
    return training_ops.resource_apply_momentum(
        var.handle, mom.handle,
        math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
        grad,
        math_ops.cast(self._momentum_tensor, grad.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov)

  def _apply_sparse(self, grad, var):
    mom = self.get_slot(var, "momentum")
    return training_ops.sparse_apply_momentum(
        var, mom,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        grad.values, grad.indices,
        math_ops.cast(self._momentum_tensor, var.dtype.base_dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov).op

  def _resource_apply_sparse(self, grad, var, indices):
    mom = self.get_slot(var, "momentum")
    return training_ops.resource_sparse_apply_momentum(
        var.handle, mom.handle,
        math_ops.cast(self._learning_rate_tensor, grad.dtype),
        grad, indices,
        math_ops.cast(self._momentum_tensor, grad.dtype),
        use_locking=self._use_locking,
        use_nesterov=self._use_nesterov)
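

# ------------------------------------------------------------------------------
# A minimal, self-contained sketch (not part of the TensorFlow API) of the dense
# `use_nesterov=False` update rule documented in the class docstring above,
# using plain Python scalars. The helper name `_momentum_step_sketch` and the
# sample values below are illustrative assumptions only.


def _momentum_step_sketch(variable, accumulation, gradient,
                          learning_rate, momentum):
  """Applies one dense momentum step, mirroring the docstring pseudo-code."""
  accumulation = momentum * accumulation + gradient
  variable -= learning_rate * accumulation
  return variable, accumulation


if __name__ == "__main__":
  # Two steps with a constant gradient of 0.5, learning_rate=0.1, momentum=0.9:
  #   step 1: accumulation = 0.5,  variable = 1.0  - 0.1 * 0.5  = 0.95
  #   step 2: accumulation = 0.95, variable = 0.95 - 0.1 * 0.95 = 0.855
  var, acc = 1.0, 0.0
  for grad in (0.5, 0.5):
    var, acc = _momentum_step_sketch(var, acc, grad,
                                     learning_rate=0.1, momentum=0.9)
  print(var, acc)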