# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""The Shampoo Optimizer.

Variant of Adagrad using one preconditioner matrix per variable dimension.
For details, see https://arxiv.org/abs/1802.09568
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.contrib.opt.python.training import matrix_functions
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import optimizer


def GetParam(var, timestep):
  """Returns var(timestep) if var is callable, otherwise var itself."""
  if callable(var):
    return var(timestep)
  else:
    return var


class ShampooOptimizer(optimizer.Optimizer):
  """The Shampoo Optimizer.

  Variant of Adagrad using one preconditioner matrix per variable dimension.
  For details, see https://arxiv.org/abs/1802.09568

  gbar is the time-weighted accumulated gradient:
    gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]

  mat_gbar is the time-weighted accumulated gradient square:
    mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
                    + mat_gbar_weight[t] * gg_j[t]
  where, if g[t] = g_abcd, then gg_a[t] = g_abcd g_a'bcd (Einstein notation).

  Update rule:
    w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t]
  Here mat_gbar_j[t]^(-alpha/n) gbar[t] is a tensor contraction along the
  j'th dimension of gbar[t] with the first dimension of
  mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter
  and n is the rank of the variable.
  Prod_j represents doing this contraction for all j in 0..n-1.

  Typically learning_rate is constant, but it can be made time dependent by
  passing a lambda function of the step.
  """

  def __init__(self,
               global_step=0,
               max_matrix_size=768,
               gbar_decay=0.0,
               gbar_weight=1.0,
               mat_gbar_decay=1.0,
               mat_gbar_weight=1.0,
               learning_rate=1.0,
               svd_interval=1,
               precond_update_interval=1,
               epsilon=1e-4,
               alpha=0.5,
               use_iterative_root=False,
               use_locking=False,
               name="Shampoo"):
    """Default values of the various hyper-parameters.

    gbar_decay, gbar_weight etc. can be a float or a time-varying parameter.
    For time-varying parameters use e.g. "lambda T: T / (T + 1.0)",
    where the expression in the lambda is a TensorFlow expression.

    Args:
      global_step: tensorflow variable indicating the step.
      max_matrix_size: We do not perform SVD for matrices larger than this.
      gbar_decay: Decay used in the gbar update below.
      gbar_weight: Used together with gbar_decay to update gbar:
        gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
      mat_gbar_decay: Decay used in the mat_gbar update below.
      mat_gbar_weight: Used together with mat_gbar_decay to update mat_gbar:
        mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
                        + mat_gbar_weight[t] * gg_j[t]
      learning_rate: Learning rate, as in SGD.
      svd_interval: We should do SVD after this many steps. Default = 1, i.e.
        every step. Usually 20 leads to no loss of accuracy, and 50 or 100 is
        also OK. You may also want it more often early and less often later;
        set it in the caller, for example:
        "svd_interval = lambda T: tf.cond(
            T < 2000, lambda: 20.0, lambda: 1000.0)"
      precond_update_interval: We should update the preconditioners after this
        many steps. Default = 1. Usually less than svd_interval.
      epsilon: epsilon * I_n is added to each mat_gbar_j for stability in the
        non-diagonal version of Shampoo.
      alpha: Total power of the preconditioners.
      use_iterative_root: Should the optimizer use SVD (faster) or the
        iterative root method (needed on TPU) for finding the roots of PSD
        matrices. If True, the iterative method is used.
      use_locking: If True, use locks for update operations.
      name: Name of the optimizer.
    """

    super(ShampooOptimizer, self).__init__(use_locking, name)

    self._global_step = math_ops.cast(global_step, dtypes.float32)
    self._max_matrix_size = max_matrix_size
    self._gbar_decay = gbar_decay
    self._gbar_weight = gbar_weight
    self._mat_gbar_decay = mat_gbar_decay
    self._mat_gbar_weight = mat_gbar_weight
    self._learning_rate = learning_rate
    self._svd_interval = svd_interval
    self._precond_update_interval = precond_update_interval
    self._epsilon = epsilon
    self._alpha = alpha
    self._use_iterative_root = use_iterative_root
    self._name = name

  def _create_slots(self, var_list):
    for v in var_list:
      with ops.colocate_with(v):
        _ = self._zeros_slot(v, "gbar", self._name)
        shape = np.array(v.get_shape())
        for i, d in enumerate(shape):
          d_tensor = ops.convert_to_tensor(d)
          if d <= self._max_matrix_size:
            mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor))
            if self._svd_interval > 1:
              _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor),
                                         "H_" + str(i), self._name)
          else:
            mat_g_init = array_ops.zeros([d_tensor])

          _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i),
                                     self._name)

  def _resource_apply_dense(self, grad, var):
    return self._apply_dense(grad, var)

  def _apply_dense(self, grad, var):
    return self._apply_gradient(grad, var)

  def _resource_apply_sparse(self, grad_values, var, grad_indices):
    return self._apply_sparse_shared(grad_values, grad_indices, var)

  def _apply_sparse(self, grad, var):
    return self._apply_sparse_shared(grad.values, grad.indices, var)

  def _apply_sparse_shared(self, grad_values, grad_indices, var):
    if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0:
      # The dimension is small enough that we can make the variable dense and
      # do a dense update.
      dense_grad = array_ops.scatter_nd(
          array_ops.expand_dims(grad_indices, axis=1), grad_values,
          array_ops.shape(var, out_type=grad_indices.dtype))
      return self._apply_gradient(dense_grad, var)
    return self._apply_gradient(grad_values, var, grad_indices)

  def _weighted_average(self, var, weight, weight_t, rest):
    """Computes exponential weighted average: var = weight_t * var + rest.

    It is important to ensure that var does not occur in rest, otherwise
    we can get race conditions in a distributed setting.

    Args:
      var: variable to be updated.
      weight: parameter to be checked. If it is a constant, we can optimize.
      weight_t: current value of the parameter, used for weighting.
      rest: the remaining tensor to be added.

    Returns:
      updated variable.
    """
    if weight == 0.0:
      return rest  # no need to update var, we will never use it.
    if weight == 1.0:  # common case
      return state_ops.assign_add(var, rest)
    # The op below can cause race conditions in a distributed setting,
    # since computing weight_t * var + rest can take some time, during
    # which var may be set by another worker. To prevent this, it should
    # be implemented as a C++ op.
    return var.assign_add((weight_t - 1) * var + rest)

  def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay,
                    mat_gbar_weight, i):
    """Updates the cumulative outer products of the gradients.

    Args:
      mat_g: the matrix to be updated.
      grad: the gradient of the variable.
      axes: the list of k-1 axis indices in 0..k-1, excluding i.
      mat_gbar_decay: constant for the weighted average:
        mat_g = mat_g * decay + grad * weight
      mat_gbar_weight: constant for the weighted average.
      i: index of the dimension to be updated.

    Returns:
      updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight

    In Einstein notation, if i = 0: grad_outer_aa' = g_abcd g_a'bcd,
    thus grad_outer is a d_i x d_i matrix, where d_i is the size of the
    i'th dimension of g.
    Alternate view: if mat_i(grad) is the flattening of grad to a
    d_i x (d_1 d_2 ... d_{i-1} d_{i+1} ... d_k) matrix, then
      grad_outer = mat_i(grad) mat_i(grad).transpose
    """
    grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes),
                                    name="grad_outer_" + str(i))
    return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay,
                                  mat_gbar_weight * grad_outer)

  def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name):
    """Computes mat_h = mat_g^alpha using SVD. mat_g is a symmetric PSD matrix.

    Args:
      var: the variable we are updating.
      mat_g: the symmetric PSD matrix whose power is to be computed.
      mat_g_size: size of mat_g.
      alpha: a real number.
      mat_h_slot_name: name of the slot to store the power, if needed.

    Returns:
      mat_h = mat_g^alpha

    Stores mat_h in the appropriate slot, if it exists.
    Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig.
246 """ 247 if mat_g_size == 1: 248 mat_h = math_ops.pow(mat_g + self._epsilon, alpha) 249 else: 250 damping = self._epsilon * linalg_ops.eye( 251 math_ops.cast(mat_g_size, dtypes.int32)) 252 diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True) 253 mat_h = math_ops.matmul( 254 mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha), 255 array_ops.transpose(mat_u)) 256 if mat_h_slot_name is not None: 257 return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) 258 return mat_h 259 260 def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name, 261 iter_count=100, epsilon=1e-6): 262 """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.""" 263 264 mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size, 265 iter_count, self._epsilon) 266 mat_h = matrix_functions.matrix_inverse_pth_root( 267 mat_g_sqrt, 268 mat_g_size, 269 2 * alpha, 270 iter_count, 271 epsilon, 272 ridge_epsilon=0.0) 273 274 if mat_h_slot_name is not None: 275 return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h) 276 return mat_h 277 278 def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None): 279 """Just a switch between the iterative power vs svd.""" 280 with ops.name_scope("matrix_iterative_power"): 281 if self._use_iterative_root: 282 return self._compute_power_iter(var, mat_g, mat_g_size, alpha, 283 mat_h_slot_name) 284 else: 285 return self._compute_power_svd(var, mat_g, mat_g_size, alpha, 286 mat_h_slot_name) 287 288 def _apply_gradient(self, grad, var, indices=None): 289 """The main function to update a variable. 290 291 Args: 292 grad: A Tensor containing gradient to apply. 293 var: A Tensor containing the variable to update. 294 indices: An array of integers, for sparse update. 295 296 Returns: 297 Updated variable var = var - learning_rate * preconditioner * grad 298 299 If the gradient is dense, var and grad have the same shape. 300 If the update is sparse, then the first dimension of the gradient and var 301 may differ, others are all the same. In this case the indices array 302 provides the set of indices of the variable which are to be updated with 303 each row of the gradient. 304 """ 305 global_step = self._global_step + 1 306 307 # Update accumulated weighted average of gradients 308 gbar = self.get_slot(var, "gbar") 309 gbar_decay_t = GetParam(self._gbar_decay, global_step) 310 gbar_weight_t = GetParam(self._gbar_weight, global_step) 311 if indices is not None: 312 # Note - the sparse update is not easily implemented, since the 313 # algorithm needs all indices of gbar to be updated 314 # if mat_gbar_decay != 1 or mat_gbar_decay != 0. 315 # One way to make mat_gbar_decay = 1 is by rescaling. 316 # If we want the update: 317 # G_{t+1} = a_{t+1} G_t + b_{t+1} w_t 318 # define: 319 # r_{t+1} = a_{t+1} * r_t 320 # h_t = G_t / r_t 321 # Then: 322 # h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t 323 # So we get the mat_gbar_decay = 1 as desired. 324 # We can implement this in a future version as needed. 325 # However we still need gbar_decay = 0, otherwise all indices 326 # of the variable will need to be updated. 
      if self._gbar_decay != 0.0:
        tf_logging.warning("Not applying momentum for variable: %s" % var.name)
      gbar_updated = grad
    else:
      gbar_updated = self._weighted_average(gbar, self._gbar_decay,
                                            gbar_decay_t,
                                            gbar_weight_t * grad)

    # Update the preconditioners and compute the preconditioned gradient.
    shape = var.get_shape()
    mat_g_list = []
    for i in range(len(shape)):
      mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
    mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
    mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)

    preconditioned_grad = gbar_updated
    v_rank = len(mat_g_list)
    neg_alpha = - GetParam(self._alpha, global_step) / v_rank
    svd_interval = GetParam(self._svd_interval, global_step)
    precond_update_interval = GetParam(self._precond_update_interval,
                                       global_step)
    for i, mat_g in enumerate(mat_g_list):
      # axes is the list of indices to reduce - everything but the current i.
      axes = list(range(i)) + list(range(i + 1, v_rank))
      if shape[i] <= self._max_matrix_size:
        # If the tensor size is sufficiently small, perform the full Shampoo
        # update. Note if precond_update_interval > 1 and
        # mat_gbar_decay_t != 1, this is not strictly correct. However we will
        # use it for now, and fix it if needed.
        # (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)

        # pylint: disable=g-long-lambda,cell-var-from-loop
        mat_g_updated = control_flow_ops.cond(
            math_ops.mod(global_step, precond_update_interval) < 1,
            lambda: self._update_mat_g(
                mat_g, grad, axes, mat_gbar_decay_t,
                mat_gbar_weight_t * precond_update_interval, i),
            lambda: mat_g)

        mat_g_updated = mat_g_updated / float(shape[i].value)

        if self._svd_interval == 1:
          mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
        else:
          mat_h = control_flow_ops.cond(
              math_ops.mod(global_step, svd_interval) < 1,
              lambda: self._compute_power(var, mat_g_updated, shape[i],
                                          neg_alpha, "H_" + str(i)),
              lambda: self.get_slot(var, "H_" + str(i)))

        # mat_h is a square matrix of size d_i x d_i.
        # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor.
        # After contraction with a d_i x d_i tensor
        # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor
        # (the first dimension is contracted out, and the second dimension of
        # mat_h is appended). After going through all the indices, it becomes
        # a d_0 x ... x d_n tensor again.
        preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h,
                                                 axes=([0], [0]),
                                                 name="precond_" + str(i))
      else:
        # Tensor size is too large -- perform diagonal Shampoo update.
        # Only normalize non-vector cases.
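        # (In this branch mat_g stores only the diagonal of G_i, i.e. the
        # accumulated per-coordinate squared gradients, so the preconditioner
        # reduces to an elementwise power of mat_g, much as in Adagrad.)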
        if axes:
          normalizer = 1.0 if indices is not None else float(shape[i].value)
          grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer
        else:
          grad_outer = grad * grad

        if i == 0 and indices is not None:
          assert self._mat_gbar_decay == 1.0
          mat_g_updated = state_ops.scatter_add(mat_g, indices,
                                                mat_gbar_weight_t * grad_outer)
          mat_g_updated_slice = array_ops.gather(mat_g_updated, indices)
          mat_h = array_ops.where(
              math_ops.greater(mat_g_updated_slice, 0),
              math_ops.pow(mat_g_updated_slice, neg_alpha),
              array_ops.zeros_like(mat_g_updated_slice))
        else:
          mat_g_updated = self._weighted_average(mat_g,
                                                 self._mat_gbar_decay,
                                                 mat_gbar_decay_t,
                                                 mat_gbar_weight_t * grad_outer)
          mat_h = array_ops.where(
              math_ops.greater(mat_g_updated, 0),
              math_ops.pow(mat_g_updated, neg_alpha),
              array_ops.zeros_like(mat_g_updated))

        # Need to do the transpose to ensure that the tensor becomes
        # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
        preconditioned_grad = array_ops.transpose(
            preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h

    # Update the variable based on the Shampoo update.
    learning_rate_t = GetParam(self._learning_rate, global_step)
    if indices is not None:
      var_updated = state_ops.scatter_add(
          var, indices, -learning_rate_t * preconditioned_grad)
    else:
      var_updated = state_ops.assign_sub(var,
                                         learning_rate_t * preconditioned_grad)
    return var_updated
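
# A minimal usage sketch (illustrative only, not part of the library). It
# assumes a TF1-style graph in which a scalar `loss` tensor has already been
# built; `loss` and the hyper-parameter values below are hypothetical.
#
#   global_step = tf.train.get_or_create_global_step()
#   opt = ShampooOptimizer(global_step=global_step,
#                          learning_rate=0.1,
#                          svd_interval=20,
#                          precond_update_interval=10)
#   train_op = opt.minimize(loss, global_step=global_step)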