# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""SGD optimizer implementation."""
16# pylint: disable=g-classes-have-attributes
17
18from tensorflow.python.framework import ops
19from tensorflow.python.keras.optimizer_v2 import optimizer_v2
20from tensorflow.python.ops import array_ops
21from tensorflow.python.ops import gen_resource_variable_ops
22from tensorflow.python.training import gen_training_ops
23from tensorflow.python.util.tf_export import keras_export
24
25
@keras_export("keras.optimizers.SGD")
class SGD(optimizer_v2.OptimizerV2):
  r"""Gradient descent (with momentum) optimizer.

  Update rule for parameter `w` with gradient `g` when `momentum` is 0:

  ```python
  w = w - learning_rate * g
  ```

  Update rule when `momentum` is larger than 0:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + velocity
  ```

  When `nesterov=True`, this rule becomes:

  ```python
  velocity = momentum * velocity - learning_rate * g
  w = w + momentum * velocity - learning_rate * g
  ```

  Args:
    learning_rate: The learning rate. A `Tensor`, a floating point value, a
      `tf.keras.optimizers.schedules.LearningRateSchedule`, or a callable
      that takes no arguments and returns the actual value to use.
      Defaults to 0.01.
    momentum: float hyperparameter in the range [0, 1] that accelerates
      gradient descent in the relevant direction and dampens oscillations.
      Defaults to 0, i.e., vanilla gradient descent.
    nesterov: boolean. Whether to apply Nesterov momentum. Defaults to
      `False`.
    name: Optional name prefix for the operations created when applying
      gradients. Defaults to `"SGD"`.
    **kwargs: Keyword arguments. Allowed to be one of `"clipnorm"` or
      `"clipvalue"`. `"clipnorm"` (float) clips gradients by norm;
      `"clipvalue"` (float) clips gradients by value.

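  The learning rate may also be given as a
  `tf.keras.optimizers.schedules.LearningRateSchedule`, and gradients can be
  clipped through the `clipnorm` / `clipvalue` keyword arguments. A minimal
  sketch (the schedule parameters and the clipping value below are
  illustrative only):

  ```python
  schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate=0.1, decay_steps=1000, decay_rate=0.9)
  opt = tf.keras.optimizers.SGD(learning_rate=schedule, momentum=0.9,
                                clipnorm=1.0)
  ```
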
  Usage:

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1)
  >>> var = tf.Variable(1.0)
  >>> loss = lambda: (var ** 2)/2.0         # d(loss)/d(var) = var
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> # Step is `- learning_rate * grad`
  >>> var.numpy()
  0.9

  >>> opt = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
  >>> var = tf.Variable(1.0)
  >>> val0 = var.value()
  >>> loss = lambda: (var ** 2)/2.0         # d(loss)/d(var) = var
  >>> # First step is `- learning_rate * grad`
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val1 = var.value()
  >>> (val0 - val1).numpy()
  0.1
  >>> # On later steps, step-size increases because of momentum
  >>> step_count = opt.minimize(loss, [var]).numpy()
  >>> val2 = var.value()
  >>> (val1 - val2).numpy()
  0.18

  Reference:
      - For `nesterov=True`, see [Sutskever et al., 2013](
        http://jmlr.org/proceedings/papers/v28/sutskever13.pdf).
  """

  _HAS_AGGREGATE_GRAD = True

  def __init__(self,
               learning_rate=0.01,
               momentum=0.0,
               nesterov=False,
               name="SGD",
               **kwargs):
    super(SGD, self).__init__(name, **kwargs)
    self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
    self._set_hyper("decay", self._initial_decay)

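    # Momentum slots are created only when `momentum` is a Tensor, a callable,
    # or a positive number; a plain zero keeps this as vanilla gradient
    # descent.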
    self._momentum = False
    if isinstance(momentum, ops.Tensor) or callable(momentum) or momentum > 0:
      self._momentum = True
    if isinstance(momentum, (int, float)) and (momentum < 0 or momentum > 1):
      raise ValueError("`momentum` must be in the range [0, 1].")
    self._set_hyper("momentum", momentum)

    self.nesterov = nesterov

  def _create_slots(self, var_list):
    if self._momentum:
      for var in var_list:
        self.add_slot(var, "momentum")

  def _prepare_local(self, var_device, var_dtype, apply_state):
    super(SGD, self)._prepare_local(var_device, var_dtype, apply_state)
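    # Cache the momentum coefficient per (device, dtype) so the apply ops
    # below reuse a single tensor instead of re-reading the hyperparameter.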
    apply_state[(var_device, var_dtype)]["momentum"] = array_ops.identity(
        self._get_hyper("momentum", var_dtype))

  def _resource_apply_dense(self, grad, var, apply_state=None):
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

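    # With momentum, delegate to the fused Keras momentum kernel; otherwise
    # apply a plain resource gradient-descent update.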
    if self._momentum:
      momentum_var = self.get_slot(var, "momentum")
      return gen_training_ops.ResourceApplyKerasMomentum(
          var=var.handle,
          accum=momentum_var.handle,
          lr=coefficients["lr_t"],
          grad=grad,
          momentum=coefficients["momentum"],
          use_locking=self._use_locking,
          use_nesterov=self.nesterov)
    else:
      return gen_training_ops.ResourceApplyGradientDescent(
          var=var.handle,
          alpha=coefficients["lr_t"],
          delta=grad,
          use_locking=self._use_locking)

  def _resource_apply_sparse_duplicate_indices(self, grad, var, indices,
                                               **kwargs):
    if self._momentum:
      return super(SGD, self)._resource_apply_sparse_duplicate_indices(
          grad, var, indices, **kwargs)
    else:
      var_device, var_dtype = var.device, var.dtype.base_dtype
      coefficients = (kwargs.get("apply_state", {}).get((var_device, var_dtype))
                      or self._fallback_apply_state(var_device, var_dtype))

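      # Without momentum, duplicate indices need no special handling: the
      # scatter-add below simply accumulates -lr * grad for each occurrence.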
      return gen_resource_variable_ops.ResourceScatterAdd(
          resource=var.handle,
          indices=indices,
          updates=-grad * coefficients["lr_t"])

  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
    # This method is only needed for momentum optimization.
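    # The momentum-free sparse case is handled directly in
    # `_resource_apply_sparse_duplicate_indices` above.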
    var_device, var_dtype = var.device, var.dtype.base_dtype
    coefficients = ((apply_state or {}).get((var_device, var_dtype))
                    or self._fallback_apply_state(var_device, var_dtype))

    momentum_var = self.get_slot(var, "momentum")
    return gen_training_ops.ResourceSparseApplyKerasMomentum(
        var=var.handle,
        accum=momentum_var.handle,
        lr=coefficients["lr_t"],
        grad=grad,
        indices=indices,
        momentum=coefficients["momentum"],
        use_locking=self._use_locking,
        use_nesterov=self.nesterov)

  def get_config(self):
    config = super(SGD, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._initial_decay,
        "momentum": self._serialize_hyperparameter("momentum"),
        "nesterov": self.nesterov,
    })
    return config