# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SGD optimizer implementation."""
# pylint: disable=g-classes-have-attributes

from tensorflow.python.framework import ops
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.training import gen_training_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.optimizers.SGD")
class SGD(optimizer_v2.OptimizerV2):
  r"""Gradient descent (with momentum) optimizer.

  Update rule for parameter `w` with gradient `g` when `momentum` is 0:

  ```python
  w = w - learning_rate * g
  ```

  Update rule when `momentum` is larger than 0:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + velocity
  ```

  When `nesterov=True`, this rule becomes:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + momentum * velocity - learning_rate * g
  ```

  Args:
    learning_rate: A `Tensor`, floating point value, or a schedule that is a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use. The
      learning rate. Defaults to 0.01.
    momentum: float hyperparameter >= 0 that accelerates gradient descent in
      the relevant direction and dampens oscillations. Defaults to 0, i.e.,
      vanilla gradient descent.
    nesterov: boolean. Whether to apply Nesterov momentum. Defaults to
      `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"SGD"`.
    **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or
      `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm;
      `"clipvalue"` (float) clips gradients by value.

  Usage:

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  >>> var = tf.Variable(1.0)
  >>> loss = lambda: (var ** 2) / 2.0  # d(loss)/d(var) = var
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> # Step is `- learning_rate * grad`
  >>> var.numpy()
  0.9

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
  >>> var = tf.Variable(1.0)
  >>> val0 = var.value()
  >>> loss = lambda: (var ** 2) / 2.0  # d(loss)/d(var) = var
  >>> # First step is `- learning_rate * grad`
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val1 = var.value()
  >>> (val0 - val1).numpy()
  0.1
  >>> # On later steps, the step size increases because of momentum
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val2 = var.value()
  >>> (val1 - val2).numpy()
  0.18

  Reference:
    - For `nesterov=True`, see [Sutskever et al., 2013](
      http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
96 """ 97 98 _HAS_AGGREGATE_GRAD = True 99 100 def __init__(self, 101 learning_rate=0.01, 102 momentum=0.0, 103 nesterov=False, 104 name="SGD", 105 **kwargs): 106 super(SGD, self).__init__(name, **kwargs) 107 self._set_hyper("learning_rate", kwargs.get("lr", learning_rate)) 108 self._set_hyper("decay", self._initial_decay) 109 110 self._momentum = False 111 if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0: 112 self._momentum = True 113 if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1): 114 raise ValueError("`momentum` must be between [0, 1].") 115 self._set_hyper("momentum", momentum) 116 117 self.nesterov = nesterov 118 119 def _create_slots(self, var_list): 120 if self._momentum: 121 for var in var_list: 122 self.add_slot(var, "momentum") 123 124 def _prepare_local(self, var_device, var_dtype, apply_state): 125 super(SGD, self)._prepare_local(var_device, var_dtype, apply_state) 126 apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity( 127 self._get_hyper("momentum", var_dtype)) 128 129 def _resource_apply_dense(self, grad, var, apply_state=None): 130 var_device, var_dtype = var.device, var.dtype.base_dtype 131 coefficients = ((apply_state or {}).get((var_device, var_dtype)) 132 or self._fallback_apply_state(var_device, var_dtype)) 133 134 if self._momentum: 135 momentum_var = self.get_slot(var, "momentum") 136 return gen_training_ops.ResourceApplyKerasMomentum( 137 var=var.handle, 138 accum=momentum_var.handle, 139 lr=coefficients["lr_t"], 140 grad=grad, 141 momentum=coefficients["momentum"], 142 use_locking=self._use_locking, 143 use_nesterov=self.nesterov) 144 else: 145 return gen_training_ops.ResourceApplyGradientDescent( 146 var=var.handle, 147 alpha=coefficients["lr_t"], 148 delta=grad, 149 use_locking=self._use_locking) 150 151 def _resource_apply_sparse_duplicate_indices(self, grad, var, indices, 152 **kwargs): 153 if self._momentum: 154 return super(SGD, self)._resource_apply_sparse_duplicate_indices( 155 grad, var, indices, **kwargs) 156 else: 157 var_device, var_dtype = var.device, var.dtype.base_dtype 158 coefficients = (kwargs.get("apply_state", {}).get((var_device, var_dtype)) 159 or self._fallback_apply_state(var_device, var_dtype)) 160 161 return gen_resource_variable_ops.ResourceScatterAdd( 162 resource=var.handle, 163 indices=indices, 164 updates=-grad * coefficients["lr_t"]) 165 166 def _resource_apply_sparse(self, grad, var, indices, apply_state=None): 167 # This method is only needed for momentum optimization. 168 var_device, var_dtype = var.device, var.dtype.base_dtype 169 coefficients = ((apply_state or {}).get((var_device, var_dtype)) 170 or self._fallback_apply_state(var_device, var_dtype)) 171 172 momentum_var = self.get_slot(var, "momentum") 173 return gen_training_ops.ResourceSparseApplyKerasMomentum( 174 var=var.handle, 175 accum=momentum_var.handle, 176 lr=coefficients["lr_t"], 177 grad=grad, 178 indices=indices, 179 momentum=coefficients["momentum"], 180 use_locking=self._use_locking, 181 use_nesterov=self.nesterov) 182 183 def get_config(self): 184 config = super(SGD, self).get_config() 185 config.update({ 186 "learning_rate": self._serialize_hyperparameter("learning_rate"), 187 "decay": self._initial_decay, 188 "momentum": self._serialize_hyperparameter("momentum"), 189 "nesterov": self.nesterov, 190 }) 191 return config 192