# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Shampoo Optimizer.

Variant of Adagrad using one preconditioner matrix per variable dimension.
For details, see https://arxiv.org/abs/1802.09568
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.opt.python.training import matrix_functions
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import optimizer


def GetParam(var, timestep):
  if callable(var):
    return var(timestep)
  else:
    return var


class ShampooOptimizer(optimizer.Optimizer):
  """The Shampoo Optimizer.

  Variant of Adagrad using one preconditioner matrix per variable dimension.
  For details, see https://arxiv.org/abs/1802.09568

  gbar is the time-weighted accumulated gradient:
    gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]

  mat_gbar is the time-weighted accumulated gradient square:
    mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
                    + mat_gbar_weight[t] * gg_j[t]
  where if g[t] = g_abcd then gg_a[t] = g_abcd g_a'bcd (Einstein notation).

  Update rule:
    w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t]
  Here, mat_gbar_j[t]^(-alpha/n) gbar[t] is a tensor contraction along the
  j'th dimension of gbar[t] with the first dimension of
  mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter and
  n = rank of the variable.
  Prod_j represents doing this contraction for all j in 0..n-1.

  Typically learning_rate is constant, but it can be made time dependent by
  passing a lambda function that depends on step.
  """
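
  # Illustration only (not part of the original code): for a rank-2 variable
  # W of shape [m, n], the dense Shampoo step computed by this optimizer is
  # equivalent to
  #
  #   precond_grad = mat_gbar_0^(-alpha/2) @ gbar @ mat_gbar_1^(-alpha/2)
  #   W <- W - learning_rate * precond_grad
  #
  # where mat_gbar_0 is m x m, mat_gbar_1 is n x n, and both matrix powers
  # are symmetric, so the two tensordot contractions reduce to ordinary
  # matrix products.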

  def __init__(self,
               global_step=0,
               max_matrix_size=768,
               gbar_decay=0.0,
               gbar_weight=1.0,
               mat_gbar_decay=1.0,
               mat_gbar_weight=1.0,
               learning_rate=1.0,
               svd_interval=1,
               precond_update_interval=1,
               epsilon=1e-4,
               alpha=0.5,
               use_iterative_root=False,
               use_locking=False,
               name="Shampoo"):
    """Default values of the various hyper-parameters.

    gbar_decay, gbar_weight etc. can be a float or a time-varying parameter.
    For time-varying parameters use e.g. "lambda T: T / (T + 1.0)",
    where the expression in the lambda is a tensorflow expression.

    Args:
      global_step: tensorflow variable indicating the step.
      max_matrix_size: We do not perform SVD for matrices larger than this.
      gbar_decay:
      gbar_weight:  Used to update gbar:
        gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
      mat_gbar_decay:
      mat_gbar_weight:  Used to update mat_gbar:
        mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
                        + mat_gbar_weight[t] * gg_j[t]
      learning_rate: Similar to SGD.
      svd_interval: We should do SVD after this many steps. Default = 1, i.e.
        every step. Usually 20 leads to no loss of accuracy, and 50 or 100 is
        also OK. You may also want it more often early and less often later;
        set it in the caller, for example:
        "svd_interval = lambda T: tf.cond(
            T < 2000, lambda: 20.0, lambda: 1000.0)"
      precond_update_interval: We should update the preconditioners after
        this many steps. Default = 1. Usually less than svd_interval.
      epsilon: epsilon * I_n is added to each mat_gbar_j for stability in the
        non-diagonal version of Shampoo.
      alpha: total power of the preconditioners.
      use_iterative_root: should the optimizer use SVD (faster) or the
        iterative root method (for TPU) for finding the roots of PSD matrices.
      use_locking: If True, use locks for update operations.
      name: name of optimizer.
    """
    super(ShampooOptimizer, self).__init__(use_locking, name)
    self._global_step = math_ops.cast(global_step, dtypes.float32)
    self._max_matrix_size = max_matrix_size
    self._gbar_decay = gbar_decay
    self._gbar_weight = gbar_weight
    self._mat_gbar_decay = mat_gbar_decay
    self._mat_gbar_weight = mat_gbar_weight
    self._learning_rate = learning_rate
    self._svd_interval = svd_interval
    self._precond_update_interval = precond_update_interval
    self._epsilon = epsilon
    self._alpha = alpha
    self._use_iterative_root = use_iterative_root
    self._name = name

  def _create_slots(self, var_list):
    for v in var_list:
      with ops.colocate_with(v):
        _ = self._zeros_slot(v, "gbar", self._name)
        shape = np.array(v.get_shape())
        for i, d in enumerate(shape):
          d_tensor = ops.convert_to_tensor(d)
          if d <= self._max_matrix_size:
            # Full preconditioner: a d x d matrix accumulator per dimension.
            mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor))
            if self._svd_interval > 1:
              _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor),
                                         "H_" + str(i), self._name)
          else:
            # Diagonal preconditioner: a length-d vector accumulator.
            mat_g_init = array_ops.zeros([d_tensor])

          _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i),
                                     self._name)

  def _resource_apply_dense(self, grad, var):
    return self._apply_dense(grad, var)

  def _apply_dense(self, grad, var):
    return self._apply_gradient(grad, var)

  def _resource_apply_sparse(self, grad_values, var, grad_indices):
    return self._apply_sparse_shared(grad_values, grad_indices, var)

  def _apply_sparse(self, grad, var):
    return self._apply_sparse_shared(grad.values, grad.indices, var)

  def _apply_sparse_shared(self, grad_values, grad_indices, var):
    if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0:
      # The dimension is small enough, so we can make the variable dense and
      # do a dense update.
      dense_grad = array_ops.scatter_nd(
          array_ops.expand_dims(grad_indices, axis=1), grad_values,
          array_ops.shape(var, out_type=grad_indices.dtype))
      return self._apply_gradient(dense_grad, var)
    return self._apply_gradient(grad_values, var, grad_indices)
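
  # Illustration only (not part of the original code): _weighted_average
  # below implements var <- weight_t * var + rest with a single assign_add,
  # using var + ((weight_t - 1) * var + rest) == weight_t * var + rest.
  # For example, with weight_t = 0.9, var = 10.0 and rest = 2.0 the op adds
  # (0.9 - 1) * 10.0 + 2.0 = 1.0, giving 11.0 == 0.9 * 10.0 + 2.0.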

  def _weighted_average(self, var, weight, weight_t, rest):
    """Computes exponential weighted average: var = weight_t * var + rest.

    Important to ensure that var does not occur in rest, otherwise we can get
    race conditions in a distributed setting.

    Args:
      var: variable to be updated
      weight: parameter to be checked. If it is a constant, we can optimize.
      weight_t: current value of parameter, used for weighting
      rest: the remaining tensor to be added

    Returns:
      updated variable.
    """
    if weight == 0.0:
      return rest       # no need to update var, we will never use it.
    if weight == 1.0:   # common case
      return state_ops.assign_add(var, rest)
    # The op below can cause race conditions in a distributed setting,
    # since computing weight_t * var + rest can take some time, during
    # which var may be set by another worker. To prevent this, it should
    # be implemented as a C++ op.
    return var.assign_add((weight_t - 1) * var + rest)

  def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay,
                    mat_gbar_weight, i):
    """Updates the cumulative outer products of the gradients.

    Args:
      mat_g: the matrix to be updated
      grad: the gradient of the variable
      axes: a list of k-1 integers 0 to k-1, except i
      mat_gbar_decay: constant for weighted average:
          mat_g = mat_g * decay + grad * weight
      mat_gbar_weight: constant for weighted average
      i: index of dimension to be updated.

    Returns:
      updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight

    In Einstein notation if i = 0: grad_outer_aa' = g_abcd g_a'bcd
    thus grad_outer is a matrix d_i x d_i, where d_i is the size of the
    i'th dimension of g.
    Alternate view: If mat_i(grad) is the flattening of grad to a
    d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then grad_outer =
    mat_i(grad) mat_i(grad).transpose
    """
    grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes),
                                    name="grad_outer_" + str(i))
    return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay,
                                  mat_gbar_weight * grad_outer)

  def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name):
    """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix.

    Args:
      var: the variable we are updating.
      mat_g: the symmetric PSD matrix whose power is to be computed
      mat_g_size: size of mat_g
      alpha: a real number
      mat_h_slot_name: name of slot to store the power, if needed.

    Returns:
      mat_h = mat_g^alpha

    Stores mat_h in the appropriate slot, if it exists.
    Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig.
    """
    if mat_g_size == 1:
      mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
    else:
      damping = self._epsilon * linalg_ops.eye(
          math_ops.cast(mat_g_size, dtypes.int32))
      diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True)
      mat_h = math_ops.matmul(
          mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha),
          array_ops.transpose(mat_u))
    if mat_h_slot_name is not None:
      return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
    return mat_h

  def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name,
                          iter_count=100, epsilon=1e-6):
    """Computes mat_g^alpha, where alpha = -1/p, p a positive integer."""

    mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size,
                                                     iter_count, self._epsilon)
    mat_h = matrix_functions.matrix_inverse_pth_root(
        mat_g_sqrt,
        mat_g_size,
        2 * alpha,
        iter_count,
        epsilon,
        ridge_epsilon=0.0)

    if mat_h_slot_name is not None:
      return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
    return mat_h

  def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None):
    """Just a switch between the iterative power vs svd."""
    with ops.name_scope("matrix_iterative_power"):
      if self._use_iterative_root:
        return self._compute_power_iter(var, mat_g, mat_g_size, alpha,
                                        mat_h_slot_name)
      else:
        return self._compute_power_svd(var, mat_g, mat_g_size, alpha,
                                       mat_h_slot_name)
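
  # Illustration only (not part of the original code): in _apply_gradient
  # below, for a rank-3 gradient of shape [d0, d1, d2] the dense branch
  # applies tensordot(preconditioned_grad, mat_h, axes=([0], [0])) once per
  # dimension. Each contraction drops the leading dimension and appends the
  # second dimension of mat_h, so the shape evolves as
  #   [d0, d1, d2] -> [d1, d2, d0] -> [d2, d0, d1] -> [d0, d1, d2],
  # returning to the original layout after all dimensions are processed.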

  def _apply_gradient(self, grad, var, indices=None):
    """The main function to update a variable.

    Args:
      grad: A Tensor containing gradient to apply.
      var: A Tensor containing the variable to update.
      indices: An array of integers, for sparse update.

    Returns:
      Updated variable var = var - learning_rate * preconditioner * grad

    If the gradient is dense, var and grad have the same shape.
    If the update is sparse, then the first dimension of the gradient and var
    may differ, while the others are all the same. In this case the indices
    array provides the set of indices of the variable which are to be updated
    with each row of the gradient.
    """
    global_step = self._global_step + 1

    # Update accumulated weighted average of gradients
    gbar = self.get_slot(var, "gbar")
    gbar_decay_t = GetParam(self._gbar_decay, global_step)
    gbar_weight_t = GetParam(self._gbar_weight, global_step)
    if indices is not None:
      # Note - the sparse update is not easily implemented, since the
      # algorithm needs all indices of gbar to be updated
      # if mat_gbar_decay != 1 or mat_gbar_decay != 0.
      # One way to make mat_gbar_decay = 1 is by rescaling.
      # If we want the update:
      #         G_{t+1} = a_{t+1} G_t + b_{t+1} w_t
      # define:
      #         r_{t+1} = a_{t+1} * r_t
      #         h_t = G_t / r_t
      # Then:
      #         h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t
      # So we get the mat_gbar_decay = 1 as desired.
      # We can implement this in a future version as needed.
      # However we still need gbar_decay = 0, otherwise all indices
      # of the variable will need to be updated.
      if self._gbar_decay != 0.0:
        tf_logging.warning("Not applying momentum for variable: %s" % var.name)
      gbar_updated = grad
    else:
      gbar_updated = self._weighted_average(gbar, self._gbar_decay,
                                            gbar_decay_t,
                                            gbar_weight_t * grad)

    # Update the preconditioners and compute the preconditioned gradient
    shape = var.get_shape()
    mat_g_list = []
    for i in range(len(shape)):
      mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
    mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
    mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)

    preconditioned_grad = gbar_updated
    v_rank = len(mat_g_list)
    neg_alpha = - GetParam(self._alpha, global_step) / v_rank
    svd_interval = GetParam(self._svd_interval, global_step)
    precond_update_interval = GetParam(self._precond_update_interval,
                                       global_step)
    for i, mat_g in enumerate(mat_g_list):
      # axes is the list of indices to reduce - everything but the current i.
      axes = list(range(i)) + list(range(i + 1, v_rank))
      if shape[i] <= self._max_matrix_size:
        # If the tensor size is sufficiently small perform full Shampoo update
        # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
        # is not strictly correct. However we will use it for now, and
        # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)

        # pylint: disable=g-long-lambda,cell-var-from-loop
        mat_g_updated = control_flow_ops.cond(
            math_ops.mod(global_step, precond_update_interval) < 1,
            lambda: self._update_mat_g(
                mat_g, grad, axes, mat_gbar_decay_t,
                mat_gbar_weight_t * precond_update_interval, i),
            lambda: mat_g)

        mat_g_updated = mat_g_updated / float(shape[i].value)

        if self._svd_interval == 1:
          mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
        else:
          mat_h = control_flow_ops.cond(
              math_ops.mod(global_step, svd_interval) < 1,
              lambda: self._compute_power(var, mat_g_updated, shape[i],
                                          neg_alpha, "H_" + str(i)),
              lambda: self.get_slot(var, "H_" + str(i)))

        # mat_h is a square matrix of size d_i x d_i
        # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor
        # After contraction with a d_i x d_i tensor
        # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor
        # (the first dimension is contracted out, and the second dimension of
        # mat_h is appended).  After going through all the indices, it becomes
        # a d_0 x ... x d_n tensor again.
        preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h,
                                                 axes=([0], [0]),
                                                 name="precond_" + str(i))
      else:
        # Tensor size is too large -- perform diagonal Shampoo update
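        # Illustration only (not part of the original code): here mat_g holds
        # one accumulator value per coordinate of dimension i, so grad_outer
        # below is the vector of sums of grad * grad over the other axes
        # (normalized in the dense case), and mat_h is applied as an
        # elementwise scale rather than a tensor contraction.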
        # Only normalize non-vector cases.
        if axes:
          normalizer = 1.0 if indices is not None else float(shape[i].value)
          grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer
        else:
          grad_outer = grad * grad

        if i == 0 and indices is not None:
          assert self._mat_gbar_decay == 1.0
          mat_g_updated = state_ops.scatter_add(mat_g, indices,
                                                mat_gbar_weight_t * grad_outer)
          mat_g_updated_slice = array_ops.gather(mat_g_updated, indices)
          mat_h = array_ops.where(
              math_ops.greater(mat_g_updated_slice, 0),
              math_ops.pow(mat_g_updated_slice, neg_alpha),
              array_ops.zeros_like(mat_g_updated_slice))
        else:
          mat_g_updated = self._weighted_average(mat_g,
                                                 self._mat_gbar_decay,
                                                 mat_gbar_decay_t,
                                                 mat_gbar_weight_t * grad_outer)
          mat_h = array_ops.where(
              math_ops.greater(mat_g_updated, 0),
              math_ops.pow(mat_g_updated, neg_alpha),
              array_ops.zeros_like(mat_g_updated))

        # Need to do the transpose to ensure that the tensor becomes
        # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
        preconditioned_grad = array_ops.transpose(
            preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h

    # Update the variable based on the Shampoo update
    learning_rate_t = GetParam(self._learning_rate, global_step)
    if indices is not None:
      var_updated = state_ops.scatter_add(
          var, indices, -learning_rate_t * preconditioned_grad)
    else:
      var_updated = state_ops.assign_sub(var,
                                         learning_rate_t * preconditioned_grad)
    return var_updated
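

# Example usage (illustrative sketch, not part of the original file; `loss`
# is assumed to be a scalar training loss defined by the caller):
#
#   global_step = tf.train.get_or_create_global_step()
#   opt = ShampooOptimizer(global_step=global_step, learning_rate=0.1)
#   train_op = opt.minimize(loss, global_step=global_step)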