1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=invalid-name 16"""Built-in optimizer classes.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import six 22from six.moves import zip # pylint: disable=redefined-builtin 23 24from tensorflow.python.distribute import distribution_strategy_context 25from tensorflow.python.framework import ops 26from tensorflow.python.keras import backend as K 27from tensorflow.python.keras.optimizer_v2 import adadelta as adadelta_v2 28from tensorflow.python.keras.optimizer_v2 import adagrad as adagrad_v2 29from tensorflow.python.keras.optimizer_v2 import adam as adam_v2 30from tensorflow.python.keras.optimizer_v2 import adamax as adamax_v2 31from tensorflow.python.keras.optimizer_v2 import ftrl 32from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 33from tensorflow.python.keras.optimizer_v2 import nadam as nadam_v2 34from tensorflow.python.keras.optimizer_v2 import optimizer_v2 35from tensorflow.python.keras.optimizer_v2 import rmsprop as rmsprop_v2 36from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object 37from tensorflow.python.keras.utils.generic_utils import serialize_keras_object 38from tensorflow.python.ops import clip_ops 39from tensorflow.python.ops import math_ops 40from tensorflow.python.ops import state_ops 41from tensorflow.python.training import optimizer as tf_optimizer_module 42from tensorflow.python.training import training_util 43from tensorflow.python.training.tracking import base as trackable 44from tensorflow.python.util.tf_export import keras_export 45 46 47class Optimizer(object): 48 """Abstract optimizer base class. 49 50 Note: this is the parent class of all optimizers, not an actual optimizer 51 that can be used for training models. 52 53 All Keras optimizers support the following keyword arguments: 54 55 clipnorm: float >= 0. Gradients will be clipped 56 when their L2 norm exceeds this value. 57 clipvalue: float >= 0. Gradients will be clipped 58 when their absolute value exceeds this value. 59 """ 60 61 def __init__(self, **kwargs): 62 allowed_kwargs = {'clipnorm', 'clipvalue'} 63 for k in kwargs: 64 if k not in allowed_kwargs: 65 raise TypeError('Unexpected keyword argument ' 66 'passed to optimizer: ' + str(k)) 67 # checks that clipnorm >= 0 and clipvalue >= 0 68 if kwargs[k] < 0: 69 raise ValueError('Expected {} >= 0, received: {}'.format(k, kwargs[k])) 70 self.__dict__.update(kwargs) 71 self.updates = [] 72 self.weights = [] 73 74 def get_updates(self, loss, params): 75 raise NotImplementedError 76 77 def get_gradients(self, loss, params): 78 """Returns gradients of `loss` with respect to `params`. 79 80 Arguments: 81 loss: Loss tensor. 82 params: List of variables. 83 84 Returns: 85 List of gradient tensors. 86 87 Raises: 88 ValueError: In case any gradient cannot be computed (e.g. if gradient 89 function not implemented). 90 """ 91 grads = K.gradients(loss, params) 92 if any(g is None for g in grads): 93 raise ValueError('An operation has `None` for gradient. ' 94 'Please make sure that all of your ops have a ' 95 'gradient defined (i.e. are differentiable). ' 96 'Common ops without gradient: ' 97 'K.argmax, K.round, K.eval.') 98 if hasattr(self, 'clipnorm'): 99 grads = [clip_ops.clip_by_norm(g, self.clipnorm) for g in grads] 100 if hasattr(self, 'clipvalue'): 101 grads = [ 102 clip_ops.clip_by_value(g, -self.clipvalue, self.clipvalue) 103 for g in grads 104 ] 105 return grads 106 107 def set_weights(self, weights): 108 """Sets the weights of the optimizer, from Numpy arrays. 109 110 Should only be called after computing the gradients 111 (otherwise the optimizer has no weights). 112 113 Arguments: 114 weights: a list of Numpy arrays. The number of arrays and their shape 115 must match number of the dimensions of the weights of the optimizer 116 (i.e. it should match the output of `get_weights`). 117 118 Raises: 119 ValueError: in case of incompatible weight shapes. 120 """ 121 params = self.weights 122 if len(params) != len(weights): 123 raise ValueError('Length of the specified weight list (' + 124 str(len(weights)) + 125 ') does not match the number of weights ' 126 'of the optimizer (' + str(len(params)) + ')') 127 weight_value_tuples = [] 128 param_values = K.batch_get_value(params) 129 for pv, p, w in zip(param_values, params, weights): 130 if pv.shape != w.shape: 131 raise ValueError('Optimizer weight shape ' + str(pv.shape) + 132 ' not compatible with ' 133 'provided weight shape ' + str(w.shape)) 134 weight_value_tuples.append((p, w)) 135 K.batch_set_value(weight_value_tuples) 136 137 def get_weights(self): 138 """Returns the current value of the weights of the optimizer. 139 140 Returns: 141 A list of numpy arrays. 142 """ 143 return K.batch_get_value(self.weights) 144 145 def get_config(self): 146 config = {} 147 if hasattr(self, 'clipnorm'): 148 config['clipnorm'] = self.clipnorm 149 if hasattr(self, 'clipvalue'): 150 config['clipvalue'] = self.clipvalue 151 return config 152 153 @classmethod 154 def from_config(cls, config): 155 return cls(**config) 156 157 158class SGD(Optimizer): 159 """Stochastic gradient descent optimizer. 160 161 Includes support for momentum, 162 learning rate decay, and Nesterov momentum. 163 164 Arguments: 165 lr: float >= 0. Learning rate. 166 momentum: float >= 0. Parameter that accelerates SGD in the relevant 167 direction and dampens oscillations. 168 decay: float >= 0. Learning rate decay over each update. 169 nesterov: boolean. Whether to apply Nesterov momentum. 170 """ 171 172 def __init__(self, lr=0.01, momentum=0., decay=0., nesterov=False, **kwargs): 173 super(SGD, self).__init__(**kwargs) 174 with K.name_scope(self.__class__.__name__): 175 self.iterations = K.variable(0, dtype='int64', name='iterations') 176 self.lr = K.variable(lr, name='lr') 177 self.momentum = K.variable(momentum, name='momentum') 178 self.decay = K.variable(decay, name='decay') 179 self.initial_decay = decay 180 self.nesterov = nesterov 181 182 def get_updates(self, loss, params): 183 grads = self.get_gradients(loss, params) 184 self.updates = [state_ops.assign_add(self.iterations, 1)] 185 186 lr = self.lr 187 if self.initial_decay > 0: 188 lr = lr * ( # pylint: disable=g-no-augmented-assignment 189 1. / 190 (1. + 191 self.decay * math_ops.cast(self.iterations, K.dtype(self.decay)))) 192 # momentum 193 shapes = [K.int_shape(p) for p in params] 194 moments = [K.zeros(shape) for shape in shapes] 195 self.weights = [self.iterations] + moments 196 for p, g, m in zip(params, grads, moments): 197 v = self.momentum * m - lr * g # velocity 198 self.updates.append(state_ops.assign(m, v)) 199 200 if self.nesterov: 201 new_p = p + self.momentum * v - lr * g 202 else: 203 new_p = p + v 204 205 # Apply constraints. 206 if getattr(p, 'constraint', None) is not None: 207 new_p = p.constraint(new_p) 208 209 self.updates.append(state_ops.assign(p, new_p)) 210 return self.updates 211 212 def get_config(self): 213 config = { 214 'lr': float(K.get_value(self.lr)), 215 'momentum': float(K.get_value(self.momentum)), 216 'decay': float(K.get_value(self.decay)), 217 'nesterov': self.nesterov 218 } 219 base_config = super(SGD, self).get_config() 220 return dict(list(base_config.items()) + list(config.items())) 221 222 223class RMSprop(Optimizer): 224 """RMSProp optimizer. 225 226 It is recommended to leave the parameters of this optimizer 227 at their default values 228 (except the learning rate, which can be freely tuned). 229 230 Arguments: 231 lr: float >= 0. Learning rate. 232 rho: float >= 0. 233 epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 234 decay: float >= 0. Learning rate decay over each update. 235 """ 236 237 def __init__(self, lr=0.001, rho=0.9, epsilon=None, decay=0., **kwargs): 238 super(RMSprop, self).__init__(**kwargs) 239 with K.name_scope(self.__class__.__name__): 240 self.lr = K.variable(lr, name='lr') 241 self.rho = K.variable(rho, name='rho') 242 self.decay = K.variable(decay, name='decay') 243 self.iterations = K.variable(0, dtype='int64', name='iterations') 244 if epsilon is None: 245 epsilon = K.epsilon() 246 self.epsilon = epsilon 247 self.initial_decay = decay 248 249 def get_updates(self, loss, params): 250 grads = self.get_gradients(loss, params) 251 accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] 252 self.weights = accumulators 253 self.updates = [state_ops.assign_add(self.iterations, 1)] 254 255 lr = self.lr 256 if self.initial_decay > 0: 257 lr = lr * ( # pylint: disable=g-no-augmented-assignment 258 1. / 259 (1. + 260 self.decay * math_ops.cast(self.iterations, K.dtype(self.decay)))) 261 262 for p, g, a in zip(params, grads, accumulators): 263 # update accumulator 264 new_a = self.rho * a + (1. - self.rho) * math_ops.square(g) 265 self.updates.append(state_ops.assign(a, new_a)) 266 new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon) 267 268 # Apply constraints. 269 if getattr(p, 'constraint', None) is not None: 270 new_p = p.constraint(new_p) 271 272 self.updates.append(state_ops.assign(p, new_p)) 273 return self.updates 274 275 def get_config(self): 276 config = { 277 'lr': float(K.get_value(self.lr)), 278 'rho': float(K.get_value(self.rho)), 279 'decay': float(K.get_value(self.decay)), 280 'epsilon': self.epsilon 281 } 282 base_config = super(RMSprop, self).get_config() 283 return dict(list(base_config.items()) + list(config.items())) 284 285 286class Adagrad(Optimizer): 287 """Adagrad optimizer. 288 289 Adagrad is an optimizer with parameter-specific learning rates, 290 which are adapted relative to how frequently a parameter gets 291 updated during training. The more updates a parameter receives, 292 the smaller the updates. 293 294 It is recommended to leave the parameters of this optimizer 295 at their default values. 296 297 # Arguments 298 lr: float >= 0. Initial learning rate. 299 epsilon: float >= 0. If `None`, defaults to `K.epsilon()`. 300 decay: float >= 0. Learning rate decay over each update. 301 302 # References 303 - [Adaptive Subgradient Methods for Online Learning and Stochastic 304 Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) 305 """ 306 307 def __init__(self, lr=0.01, epsilon=None, decay=0., **kwargs): 308 super(Adagrad, self).__init__(**kwargs) 309 with K.name_scope(self.__class__.__name__): 310 self.lr = K.variable(lr, name='lr') 311 self.decay = K.variable(decay, name='decay') 312 self.iterations = K.variable(0, dtype='int64', name='iterations') 313 if epsilon is None: 314 epsilon = K.epsilon() 315 self.epsilon = epsilon 316 self.initial_decay = decay 317 318 def get_updates(self, loss, params): 319 grads = self.get_gradients(loss, params) 320 shapes = [K.int_shape(p) for p in params] 321 accumulators = [K.zeros(shape) for shape in shapes] 322 self.weights = accumulators 323 self.updates = [state_ops.assign_add(self.iterations, 1)] 324 325 lr = self.lr 326 if self.initial_decay > 0: 327 lr = lr * ( # pylint: disable=g-no-augmented-assignment 328 1. / 329 (1. + 330 self.decay * math_ops.cast(self.iterations, K.dtype(self.decay)))) 331 332 for p, g, a in zip(params, grads, accumulators): 333 new_a = a + math_ops.square(g) # update accumulator 334 self.updates.append(state_ops.assign(a, new_a)) 335 new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon) 336 337 # Apply constraints. 338 if getattr(p, 'constraint', None) is not None: 339 new_p = p.constraint(new_p) 340 341 self.updates.append(state_ops.assign(p, new_p)) 342 return self.updates 343 344 def get_config(self): 345 config = { 346 'lr': float(K.get_value(self.lr)), 347 'decay': float(K.get_value(self.decay)), 348 'epsilon': self.epsilon 349 } 350 base_config = super(Adagrad, self).get_config() 351 return dict(list(base_config.items()) + list(config.items())) 352 353 354class Adadelta(Optimizer): 355 """Adadelta optimizer. 356 357 Adadelta is a more robust extension of Adagrad 358 that adapts learning rates based on a moving window of gradient updates, 359 instead of accumulating all past gradients. This way, Adadelta continues 360 learning even when many updates have been done. Compared to Adagrad, in the 361 original version of Adadelta you don't have to set an initial learning 362 rate. In this version, initial learning rate and decay factor can 363 be set, as in most other Keras optimizers. 364 365 It is recommended to leave the parameters of this optimizer 366 at their default values. 367 368 # Arguments 369 lr: float >= 0. Initial learning rate, defaults to 1. 370 It is recommended to leave it at the default value. 371 rho: float >= 0. Adadelta decay factor, corresponding to fraction of 372 gradient to keep at each time step. 373 epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 374 decay: float >= 0. Initial learning rate decay. 375 376 # References 377 - [Adadelta - an adaptive learning rate 378 method](http://arxiv.org/abs/1212.5701) 379 """ 380 381 def __init__(self, lr=1.0, rho=0.95, epsilon=None, decay=0., **kwargs): 382 super(Adadelta, self).__init__(**kwargs) 383 with K.name_scope(self.__class__.__name__): 384 self.lr = K.variable(lr, name='lr') 385 self.decay = K.variable(decay, name='decay') 386 self.iterations = K.variable(0, dtype='int64', name='iterations') 387 if epsilon is None: 388 epsilon = K.epsilon() 389 self.rho = rho 390 self.epsilon = epsilon 391 self.initial_decay = decay 392 393 def get_updates(self, loss, params): 394 grads = self.get_gradients(loss, params) 395 shapes = [K.int_shape(p) for p in params] 396 accumulators = [K.zeros(shape) for shape in shapes] 397 delta_accumulators = [K.zeros(shape) for shape in shapes] 398 self.weights = accumulators + delta_accumulators 399 self.updates = [state_ops.assign_add(self.iterations, 1)] 400 401 lr = self.lr 402 if self.initial_decay > 0: 403 lr = lr * ( # pylint: disable=g-no-augmented-assignment 404 1. / 405 (1. + 406 self.decay * math_ops.cast(self.iterations, K.dtype(self.decay)))) 407 408 for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators): 409 # update accumulator 410 new_a = self.rho * a + (1. - self.rho) * math_ops.square(g) 411 self.updates.append(state_ops.assign(a, new_a)) 412 413 # use the new accumulator and the *old* delta_accumulator 414 update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon) 415 new_p = p - lr * update 416 417 # Apply constraints. 418 if getattr(p, 'constraint', None) is not None: 419 new_p = p.constraint(new_p) 420 421 self.updates.append(state_ops.assign(p, new_p)) 422 423 # update delta_accumulator 424 new_d_a = self.rho * d_a + (1 - self.rho) * math_ops.square(update) 425 self.updates.append(state_ops.assign(d_a, new_d_a)) 426 return self.updates 427 428 def get_config(self): 429 config = { 430 'lr': float(K.get_value(self.lr)), 431 'rho': self.rho, 432 'decay': float(K.get_value(self.decay)), 433 'epsilon': self.epsilon 434 } 435 base_config = super(Adadelta, self).get_config() 436 return dict(list(base_config.items()) + list(config.items())) 437 438 439class Adam(Optimizer): 440 """Adam optimizer. 441 442 Default parameters follow those provided in the original paper. 443 444 Arguments: 445 lr: float >= 0. Learning rate. 446 beta_1: float, 0 < beta < 1. Generally close to 1. 447 beta_2: float, 0 < beta < 1. Generally close to 1. 448 epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 449 decay: float >= 0. Learning rate decay over each update. 450 amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm 451 from the paper "On the Convergence of Adam and Beyond". 452 """ 453 454 def __init__(self, 455 lr=0.001, 456 beta_1=0.9, 457 beta_2=0.999, 458 epsilon=None, 459 decay=0., 460 amsgrad=False, 461 **kwargs): 462 super(Adam, self).__init__(**kwargs) 463 with K.name_scope(self.__class__.__name__): 464 self.iterations = K.variable(0, dtype='int64', name='iterations') 465 self.lr = K.variable(lr, name='lr') 466 self.beta_1 = K.variable(beta_1, name='beta_1') 467 self.beta_2 = K.variable(beta_2, name='beta_2') 468 self.decay = K.variable(decay, name='decay') 469 if epsilon is None: 470 epsilon = K.epsilon() 471 self.epsilon = epsilon 472 self.initial_decay = decay 473 self.amsgrad = amsgrad 474 475 def get_updates(self, loss, params): 476 grads = self.get_gradients(loss, params) 477 self.updates = [] 478 479 lr = self.lr 480 if self.initial_decay > 0: 481 lr = lr * ( # pylint: disable=g-no-augmented-assignment 482 1. / 483 (1. + 484 self.decay * math_ops.cast(self.iterations, K.dtype(self.decay)))) 485 486 with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]): 487 t = math_ops.cast(self.iterations, K.floatx()) 488 lr_t = lr * ( 489 K.sqrt(1. - math_ops.pow(self.beta_2, t)) / 490 (1. - math_ops.pow(self.beta_1, t))) 491 492 ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] 493 vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] 494 if self.amsgrad: 495 vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] 496 else: 497 vhats = [K.zeros(1) for _ in params] 498 self.weights = [self.iterations] + ms + vs + vhats 499 500 for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats): 501 m_t = (self.beta_1 * m) + (1. - self.beta_1) * g 502 v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g) 503 if self.amsgrad: 504 vhat_t = math_ops.maximum(vhat, v_t) 505 p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon) 506 self.updates.append(state_ops.assign(vhat, vhat_t)) 507 else: 508 p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) 509 510 self.updates.append(state_ops.assign(m, m_t)) 511 self.updates.append(state_ops.assign(v, v_t)) 512 new_p = p_t 513 514 # Apply constraints. 515 if getattr(p, 'constraint', None) is not None: 516 new_p = p.constraint(new_p) 517 518 self.updates.append(state_ops.assign(p, new_p)) 519 return self.updates 520 521 def get_config(self): 522 config = { 523 'lr': float(K.get_value(self.lr)), 524 'beta_1': float(K.get_value(self.beta_1)), 525 'beta_2': float(K.get_value(self.beta_2)), 526 'decay': float(K.get_value(self.decay)), 527 'epsilon': self.epsilon, 528 'amsgrad': self.amsgrad 529 } 530 base_config = super(Adam, self).get_config() 531 return dict(list(base_config.items()) + list(config.items())) 532 533 534class Adamax(Optimizer): 535 """Adamax optimizer from Adam paper's Section 7. 536 537 It is a variant of Adam based on the infinity norm. 538 Default parameters follow those provided in the paper. 539 540 Arguments: 541 lr: float >= 0. Learning rate. 542 beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1. 543 epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 544 decay: float >= 0. Learning rate decay over each update. 545 """ 546 547 def __init__(self, 548 lr=0.002, 549 beta_1=0.9, 550 beta_2=0.999, 551 epsilon=None, 552 decay=0., 553 **kwargs): 554 super(Adamax, self).__init__(**kwargs) 555 with K.name_scope(self.__class__.__name__): 556 self.iterations = K.variable(0, dtype='int64', name='iterations') 557 self.lr = K.variable(lr, name='lr') 558 self.beta_1 = K.variable(beta_1, name='beta_1') 559 self.beta_2 = K.variable(beta_2, name='beta_2') 560 self.decay = K.variable(decay, name='decay') 561 if epsilon is None: 562 epsilon = K.epsilon() 563 self.epsilon = epsilon 564 self.initial_decay = decay 565 566 def get_updates(self, loss, params): 567 grads = self.get_gradients(loss, params) 568 self.updates = [] 569 570 lr = self.lr 571 if self.initial_decay > 0: 572 lr = lr * ( # pylint: disable=g-no-augmented-assignment 573 1. / 574 (1. + 575 self.decay * math_ops.cast(self.iterations, K.dtype(self.decay)))) 576 577 with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]): 578 t = math_ops.cast(self.iterations, K.floatx()) 579 lr_t = lr / (1. - math_ops.pow(self.beta_1, t)) 580 581 shapes = [K.int_shape(p) for p in params] 582 # zero init of 1st moment 583 ms = [K.zeros(shape) for shape in shapes] 584 # zero init of exponentially weighted infinity norm 585 us = [K.zeros(shape) for shape in shapes] 586 self.weights = [self.iterations] + ms + us 587 588 for p, g, m, u in zip(params, grads, ms, us): 589 590 m_t = (self.beta_1 * m) + (1. - self.beta_1) * g 591 u_t = math_ops.maximum(self.beta_2 * u, math_ops.abs(g)) 592 p_t = p - lr_t * m_t / (u_t + self.epsilon) 593 594 self.updates.append(state_ops.assign(m, m_t)) 595 self.updates.append(state_ops.assign(u, u_t)) 596 new_p = p_t 597 598 # Apply constraints. 599 if getattr(p, 'constraint', None) is not None: 600 new_p = p.constraint(new_p) 601 602 self.updates.append(state_ops.assign(p, new_p)) 603 return self.updates 604 605 def get_config(self): 606 config = { 607 'lr': float(K.get_value(self.lr)), 608 'beta_1': float(K.get_value(self.beta_1)), 609 'beta_2': float(K.get_value(self.beta_2)), 610 'decay': float(K.get_value(self.decay)), 611 'epsilon': self.epsilon 612 } 613 base_config = super(Adamax, self).get_config() 614 return dict(list(base_config.items()) + list(config.items())) 615 616 617class Nadam(Optimizer): 618 """Nesterov Adam optimizer. 619 620 Much like Adam is essentially RMSprop with momentum, 621 Nadam is Adam RMSprop with Nesterov momentum. 622 623 Default parameters follow those provided in the paper. 624 It is recommended to leave the parameters of this optimizer 625 at their default values. 626 627 Arguments: 628 lr: float >= 0. Learning rate. 629 beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1. 630 epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`. 631 """ 632 633 def __init__(self, 634 lr=0.002, 635 beta_1=0.9, 636 beta_2=0.999, 637 epsilon=None, 638 schedule_decay=0.004, 639 **kwargs): 640 super(Nadam, self).__init__(**kwargs) 641 with K.name_scope(self.__class__.__name__): 642 self.iterations = K.variable(0, dtype='int64', name='iterations') 643 self.m_schedule = K.variable(1., name='m_schedule') 644 self.lr = K.variable(lr, name='lr') 645 self.beta_1 = K.variable(beta_1, name='beta_1') 646 self.beta_2 = K.variable(beta_2, name='beta_2') 647 if epsilon is None: 648 epsilon = K.epsilon() 649 self.epsilon = epsilon 650 self.schedule_decay = schedule_decay 651 652 def get_updates(self, loss, params): 653 grads = self.get_gradients(loss, params) 654 self.updates = [] 655 656 with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]): 657 t = math_ops.cast(self.iterations, K.floatx()) 658 659 # Due to the recommendations in [2], i.e. warming momentum schedule 660 momentum_cache_t = self.beta_1 * ( 661 1. - 0.5 * 662 (math_ops.pow(K.cast_to_floatx(0.96), t * self.schedule_decay))) 663 momentum_cache_t_1 = self.beta_1 * ( 664 1. - 0.5 * 665 (math_ops.pow(K.cast_to_floatx(0.96), (t + 1) * self.schedule_decay))) 666 m_schedule_new = self.m_schedule * momentum_cache_t 667 m_schedule_next = self.m_schedule * momentum_cache_t * momentum_cache_t_1 668 self.updates.append((self.m_schedule, m_schedule_new)) 669 670 shapes = [K.int_shape(p) for p in params] 671 ms = [K.zeros(shape) for shape in shapes] 672 vs = [K.zeros(shape) for shape in shapes] 673 674 self.weights = [self.iterations, self.m_schedule] + ms + vs 675 676 for p, g, m, v in zip(params, grads, ms, vs): 677 # the following equations given in [1] 678 g_prime = g / (1. - m_schedule_new) 679 m_t = self.beta_1 * m + (1. - self.beta_1) * g 680 m_t_prime = m_t / (1. - m_schedule_next) 681 v_t = self.beta_2 * v + (1. - self.beta_2) * math_ops.square(g) 682 v_t_prime = v_t / (1. - math_ops.pow(self.beta_2, t)) 683 m_t_bar = (1. - 684 momentum_cache_t) * g_prime + momentum_cache_t_1 * m_t_prime 685 686 self.updates.append(state_ops.assign(m, m_t)) 687 self.updates.append(state_ops.assign(v, v_t)) 688 689 p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon) 690 new_p = p_t 691 692 # Apply constraints. 693 if getattr(p, 'constraint', None) is not None: 694 new_p = p.constraint(new_p) 695 696 self.updates.append(state_ops.assign(p, new_p)) 697 return self.updates 698 699 def get_config(self): 700 config = { 701 'lr': float(K.get_value(self.lr)), 702 'beta_1': float(K.get_value(self.beta_1)), 703 'beta_2': float(K.get_value(self.beta_2)), 704 'epsilon': self.epsilon, 705 'schedule_decay': self.schedule_decay 706 } 707 base_config = super(Nadam, self).get_config() 708 return dict(list(base_config.items()) + list(config.items())) 709 710 711class TFOptimizer(Optimizer, trackable.Trackable): 712 """Wrapper class for native TensorFlow optimizers.""" 713 714 def __init__(self, optimizer, iterations=None): # pylint: disable=super-init-not-called 715 self.optimizer = optimizer 716 self._track_trackable(optimizer, name='optimizer') 717 if iterations is None: 718 with K.name_scope(self.__class__.__name__): 719 self.iterations = K.variable(0, dtype='int64', name='iterations') 720 else: 721 self.iterations = iterations 722 self._track_trackable(self.iterations, name='global_step') 723 724 def apply_gradients(self, grads): 725 self.optimizer.apply_gradients(grads, global_step=self.iterations) 726 727 def get_grads(self, loss, params): 728 return self.optimizer.compute_gradients(loss, params) 729 730 def get_updates(self, loss, params): 731 if distribution_strategy_context.has_strategy(): 732 self.updates = [] 733 734 if not params: 735 # After the model vars have been created, the second call to get_updates 736 # is called with params as an empty list. This ensures that we call 737 # compute_gradients with params=None. 738 grads = self.optimizer.compute_gradients(loss) 739 else: 740 grads = self.optimizer.compute_gradients(loss, params) 741 global_step = training_util.get_global_step() 742 opt_update = self.optimizer.apply_gradients(grads, global_step) 743 else: 744 if not params: 745 self.updates = [state_ops.assign_add(self.iterations, 1)] 746 return self.updates 747 748 # Updates list starts out empty because the iterations variable is 749 # incremented in optimizer.apply_gradients() 750 self.updates = [] 751 grads = self.optimizer.compute_gradients(loss, params) 752 opt_update = self.optimizer.apply_gradients( 753 grads, global_step=self.iterations) 754 755 self.updates.append(opt_update) 756 return self.updates 757 758 @property 759 def weights(self): 760 raise NotImplementedError 761 762 def get_config(self): 763 raise NotImplementedError 764 765 def from_config(self, config): 766 raise NotImplementedError 767 768 769# Aliases. 770 771sgd = SGD 772rmsprop = RMSprop 773adagrad = Adagrad 774adadelta = Adadelta 775adam = Adam 776adamax = Adamax 777nadam = Nadam 778 779 780@keras_export('keras.optimizers.serialize') 781def serialize(optimizer): 782 return serialize_keras_object(optimizer) 783 784 785@keras_export('keras.optimizers.deserialize') 786def deserialize(config, custom_objects=None): 787 """Inverse of the `serialize` function. 788 789 Arguments: 790 config: Optimizer configuration dictionary. 791 custom_objects: Optional dictionary mapping names (strings) to custom 792 objects (classes and functions) to be considered during deserialization. 793 794 Returns: 795 A Keras Optimizer instance. 796 """ 797 # loss_scale_optimizer has a direct dependency of optimizer, import here 798 # rather than top to avoid the cyclic dependency. 799 from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer # pylint: disable=g-import-not-at-top 800 all_classes = { 801 'adadelta': adadelta_v2.Adadelta, 802 'adagrad': adagrad_v2.Adagrad, 803 'adam': adam_v2.Adam, 804 'adamax': adamax_v2.Adamax, 805 'nadam': nadam_v2.Nadam, 806 'rmsprop': rmsprop_v2.RMSprop, 807 'sgd': gradient_descent_v2.SGD, 808 'ftrl': ftrl.Ftrl, 809 'lossscaleoptimizer': loss_scale_optimizer.LossScaleOptimizer, 810 } 811 812 # Make deserialization case-insensitive for built-in optimizers. 813 if config['class_name'].lower() in all_classes: 814 config['class_name'] = config['class_name'].lower() 815 return deserialize_keras_object( 816 config, 817 module_objects=all_classes, 818 custom_objects=custom_objects, 819 printable_module_name='optimizer') 820 821 822@keras_export('keras.optimizers.get') 823def get(identifier): 824 """Retrieves a Keras Optimizer instance. 825 826 Arguments: 827 identifier: Optimizer identifier, one of 828 - String: name of an optimizer 829 - Dictionary: configuration dictionary. - Keras Optimizer instance (it 830 will be returned unchanged). - TensorFlow Optimizer instance (it 831 will be wrapped as a Keras Optimizer). 832 833 Returns: 834 A Keras Optimizer instance. 835 836 Raises: 837 ValueError: If `identifier` cannot be interpreted. 838 """ 839 if isinstance(identifier, (Optimizer, optimizer_v2.OptimizerV2)): 840 return identifier 841 # Wrap TF optimizer instances 842 elif isinstance(identifier, tf_optimizer_module.Optimizer): 843 opt = TFOptimizer(identifier) 844 K.track_tf_optimizer(opt) 845 return opt 846 elif isinstance(identifier, dict): 847 return deserialize(identifier) 848 elif isinstance(identifier, six.string_types): 849 config = {'class_name': str(identifier), 'config': {}} 850 return deserialize(config) 851 else: 852 raise ValueError('Could not interpret optimizer identifier:', identifier) 853