# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for matrix factorization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numbers

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.contrib.factorization.python.ops import gen_factorization_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import resource_loader

_factorization_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_factorization_ops.so"))


class WALSModel(object):
  r"""A model for Weighted Alternating Least Squares matrix factorization.

  It minimizes the following loss function over U, V:
  $$
   \|\sqrt W \odot (A - U V^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2)
  $$
  where,
  A: input matrix,
  W: weight matrix. Note that the (element-wise) square root of the weights
    is used in the objective function.
  U, V: row_factors and column_factors matrices,
  \\(\lambda\\): regularization.
  Also we assume that W is of the following special form:
  \\( W_{ij} = W_0 + R_i * C_j \\) if \\(A_{ij} \ne 0\\),
  \\( W_{ij} = W_0 \\) otherwise.
  where,
  \\(W_0\\): unobserved_weight,
  \\(R_i\\): row_weights,
  \\(C_j\\): col_weights.

  Note that the current implementation supports two operation modes: the
  default mode is for the case where row_factors and col_factors can each fit
  into the memory of a single worker, in which case they are cached. When this
  condition cannot be met, setting use_factors_weights_cache to False allows
  for larger problem sizes with a slight performance penalty, as the worker
  caches are not created and the relevant weight and factor values are instead
  looked up from the parameter servers at each step.

  Loss computation: The loss can be computed efficiently by decomposing it into
  a sparse term and a Gramian term, see wals.md.
  The loss is returned by update_{col, row}_factors(sp_input), and is
  normalized as follows:
    _, _, unregularized_loss, regularization, sum_weights =
        update_row_factors(sp_input)
  if sp_input contains the rows \\({A_i, i \in I}\\), and the input matrix A
  has n total rows, then the minibatch loss = unregularized_loss +
  regularization is
  $$
   (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 + \lambda \|U_I\|_F^2) * n / |I| +
   \lambda \|V\|_F^2
  $$
  The sum_weights tensor contains the normalized sum of weights
  \\(sum(W_I) * n / |I|\\).
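
  As a concrete illustration of this normalization (with made-up sizes): if A
  has n = 10000 rows and the minibatch covers |I| = 100 of them, the sparse
  terms are scaled by n / |I| = 100 so that the minibatch loss estimates the
  full loss, while the global term \\(\lambda \|V\|_F^2\\) is added once,
  unscaled.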

  A typical usage example (pseudocode):

    with tf.Graph().as_default():
      # Set up the model object.
      model = tf.contrib.factorization.WALSModel(....)

      # To be run only once as part of session initialization. In distributed
      # training settings, this should only be run by the chief trainer and
      # all other trainers should block until this is done.
      model_init_op = model.initialize_op

      # To be run once per worker after the session is available, before the
      # prep_gramian_op for rows (columns) can be run.
      worker_init_op = model.worker_init

      # To be run once per iteration sweep before the row (column) update
      # initialize ops can be run. Note that in distributed training
      # settings, this should only be run by the chief trainer. All other
      # trainers need to block until this is done.
      row_update_prep_gramian_op = model.row_update_prep_gramian_op
      col_update_prep_gramian_op = model.col_update_prep_gramian_op

      # To be run once per worker per iteration sweep. Must be run before
      # any actual update ops can be run.
      init_row_update_op = model.initialize_row_update_op
      init_col_update_op = model.initialize_col_update_op

      # Ops to update rows (columns). These can take either the entire sparse
      # tensor or slices of the sparse tensor. For a distributed trainer, each
      # trainer handles just part of the matrix.
      _, row_update_op, unreg_row_loss, row_reg, _ = model.update_row_factors(
          sp_input=matrix_slices_from_queue_for_worker_shard)
      row_loss = unreg_row_loss + row_reg
      _, col_update_op, unreg_col_loss, col_reg, _ = model.update_col_factors(
          sp_input=transposed_matrix_slices_from_queue_for_worker_shard,
          transpose_input=True)
      col_loss = unreg_col_loss + col_reg

      ...

      # model_init_op is passed to Supervisor. Chief trainer runs it. Other
      # trainers wait.
      sv = tf.train.Supervisor(is_chief=is_chief,
                               ...,
                               init_op=tf.group(..., model_init_op, ...), ...)
      ...

      with sv.managed_session(...) as sess:
        # All workers/trainers run it after the session becomes available.
        worker_init_op.run(session=sess)

        ...

        while i in iterations:

          # All trainers need to sync up here.
          while not_all_ready:
            wait

          # Row update sweep.
          if is_chief:
            row_update_prep_gramian_op.run(session=sess)
          else:
            wait_for_chief

          # All workers run update initialization.
          init_row_update_op.run(session=sess)

          # Go through the matrix.
          reset_matrix_slices_queue_for_worker_shard
          while_matrix_slices:
            row_update_op.run(session=sess)

          # All trainers need to sync up here.
          while not_all_ready:
            wait

          # Column update sweep.
          if is_chief:
            col_update_prep_gramian_op.run(session=sess)
          else:
            wait_for_chief

          # All workers run update initialization.
          init_col_update_op.run(session=sess)

          # Go through the matrix.
          reset_transposed_matrix_slices_queue_for_worker_shard
          while_transposed_matrix_slices:
            col_update_op.run(session=sess)
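
  For completeness, a minimal single-process sketch (no Supervisor and no
  sharding; `sp_input`, `num_rows`, `num_cols`, `k` and `num_sweeps` are
  hypothetical placeholders, not part of the API):

    model = tf.contrib.factorization.WALSModel(
        input_rows=num_rows, input_cols=num_cols, n_components=k,
        unobserved_weight=0.1, regularization=0.01)
    row_update_op = model.update_row_factors(sp_input=sp_input)[1]
    col_update_op = model.update_col_factors(sp_input=sp_input)[1]

    with tf.Session() as sess:
      sess.run(model.initialize_op)
      sess.run(model.worker_init)
      for _ in range(num_sweeps):
        # Row sweep: form the Gramian, initialize caches, then update.
        sess.run(model.row_update_prep_gramian_op)
        sess.run(model.initialize_row_update_op)
        sess.run(row_update_op)
        # Column sweep.
        sess.run(model.col_update_prep_gramian_op)
        sess.run(model.initialize_col_update_op)
        sess.run(col_update_op)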
  """

  def __init__(self,
               input_rows,
               input_cols,
               n_components,
               unobserved_weight=0.1,
               regularization=None,
               row_init="random",
               col_init="random",
               num_row_shards=1,
               num_col_shards=1,
               row_weights=1,
               col_weights=1,
               use_factors_weights_cache=True,
               use_gramian_cache=True,
               use_scoped_vars=False):
    """Creates model for WALS matrix factorization.

    Args:
      input_rows: total number of rows for input matrix.
      input_cols: total number of cols for input matrix.
      n_components: number of dimensions to use for the factors.
      unobserved_weight: weight given to unobserved entries of matrix.
      regularization: weight of L2 regularization term. If None, no
        regularization is done.
      row_init: initializer for row factor. Can be a tensor or numpy constant.
        If set to "random", the value is initialized randomly.
      col_init: initializer for column factor. See row_init for details.
      num_row_shards: number of shards to use for row factors.
      num_col_shards: number of shards to use for column factors.
      row_weights: Must be in one of the following three formats: None, a list
        of lists of non-negative real numbers (or equivalent iterables) or a
        single non-negative real number.
        - When set to None, w_ij = unobserved_weight, which simplifies to ALS.
          Note that col_weights must also be set to "None" in this case.
        - If it is a list of lists of non-negative real numbers, it needs to be
          in the form of [[w_0, w_1, ...], [w_k, ... ], [...]], with the number
          of inner lists matching the number of row factor shards and the
          elements of each inner list being the weights for the rows of the
          corresponding row factor shard. In this case,
          \\(w_{ij}\\) = unobserved_weight + row_weights[i] * col_weights[j].
        - If this is a single non-negative real number, this value is used for
          all row weights and
          \\(w_{ij}\\) = unobserved_weight + row_weights * col_weights[j].
        Note that it is allowed to have row_weights as a list while col_weights
        is a single number, or vice versa.
      col_weights: See row_weights.
      use_factors_weights_cache: When True, the factors and weights will be
        cached on the workers before the updates start. Defaults to True. Note
        that the weights cache is initialized through `worker_init`, and the
        row/col factors cache is initialized through
        `initialize_{col/row}_update_op`. In the case where the weights are
        computed outside and set before the training iterations start, it is
        important to ensure the `worker_init` op is run afterwards for the
        weights cache to take effect.
      use_gramian_cache: When True, the Gramians will be cached on the workers
        before the updates start. Defaults to True.
      use_scoped_vars: When True, the factor and weight vars will also be
        nested in a tf.name_scope.
    """
    self._input_rows = input_rows
    self._input_cols = input_cols
    self._num_row_shards = num_row_shards
    self._num_col_shards = num_col_shards
    self._n_components = n_components
    self._unobserved_weight = unobserved_weight
    self._regularization = regularization
    self._regularization_matrix = (
        regularization * linalg_ops.eye(self._n_components)
        if regularization is not None else None)
    assert (row_weights is None) == (col_weights is None)
    self._use_factors_weights_cache = use_factors_weights_cache
    self._use_gramian_cache = use_gramian_cache

    if use_scoped_vars:
      with ops.name_scope("row_weights"):
        self._row_weights = WALSModel._create_weights(
            row_weights, self._input_rows, self._num_row_shards, "row_weights")
      with ops.name_scope("col_weights"):
        self._col_weights = WALSModel._create_weights(
            col_weights, self._input_cols, self._num_col_shards, "col_weights")
      with ops.name_scope("row_factors"):
        self._row_factors = self._create_factors(
            self._input_rows, self._n_components, self._num_row_shards,
            row_init, "row_factors")
      with ops.name_scope("col_factors"):
        self._col_factors = self._create_factors(
            self._input_cols, self._n_components, self._num_col_shards,
            col_init, "col_factors")
    else:
      self._row_weights = WALSModel._create_weights(
          row_weights, self._input_rows, self._num_row_shards, "row_weights")
      self._col_weights = WALSModel._create_weights(
          col_weights, self._input_cols, self._num_col_shards, "col_weights")
      self._row_factors = self._create_factors(
          self._input_rows, self._n_components, self._num_row_shards, row_init,
          "row_factors")
      self._col_factors = self._create_factors(
          self._input_cols, self._n_components, self._num_col_shards, col_init,
          "col_factors")

    self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
    self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
    with ops.name_scope("row_prepare_gramian"):
      self._row_update_prep_gramian = self._prepare_gramian(
          self._col_factors, self._col_gramian)
    with ops.name_scope("col_prepare_gramian"):
      self._col_update_prep_gramian = self._prepare_gramian(
          self._row_factors, self._row_gramian)
    with ops.name_scope("transient_vars"):
      self._create_transient_vars()

  @property
  def row_factors(self):
    """Returns a list of tensors corresponding to row factor shards."""
    return self._row_factors

  @property
  def col_factors(self):
    """Returns a list of tensors corresponding to column factor shards."""
    return self._col_factors

  @property
  def row_weights(self):
    """Returns a list of tensors corresponding to row weight shards."""
    return self._row_weights

  @property
  def col_weights(self):
    """Returns a list of tensors corresponding to col weight shards."""
    return self._col_weights

  @property
  def initialize_op(self):
    """Returns an op for initializing tensorflow variables."""
    all_vars = self._row_factors + self._col_factors
    all_vars.extend([self._row_gramian, self._col_gramian])
    if self._row_weights is not None:
      assert self._col_weights is not None
      all_vars.extend(self._row_weights + self._col_weights)
    return variables.variables_initializer(all_vars)

  @classmethod
  def _shard_sizes(cls, dims, num_shards):
    """Helper function to split dims values into num_shards."""
    shard_size, residual = divmod(dims, num_shards)
    return [shard_size + 1] * residual + [shard_size] * (num_shards - residual)
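
  # For example (hypothetical sizes), _shard_sizes(10, 3) returns [4, 3, 3]:
  # divmod(10, 3) == (3, 1), so the single residual element goes to the first
  # shard and the remaining shards get the base size of 3.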

  @classmethod
  def _create_factors(cls, rows, cols, num_shards, init, name):
    """Helper function to create row and column factors."""
    if callable(init):
      init = init()
    if isinstance(init, list):
      assert len(init) == num_shards
    elif isinstance(init, str) and init == "random":
      pass
    elif num_shards == 1:
      init = [init]
    sharded_matrix = []
    sizes = cls._shard_sizes(rows, num_shards)
    assert len(sizes) == num_shards

    def make_initializer(i, size):

      def initializer():
        if init == "random":
          return random_ops.random_normal([size, cols])
        else:
          return init[i]

      return initializer

    for i, size in enumerate(sizes):
      var_name = "%s_shard_%d" % (name, i)
      var_init = make_initializer(i, size)
      sharded_matrix.append(
          variable_scope.variable(
              var_init, dtype=dtypes.float32, name=var_name))

    return sharded_matrix

  @classmethod
  def _create_weights(cls, wt_init, num_wts, num_shards, name):
    """Helper function to create sharded weight vector.

    Args:
      wt_init: init value for the weight. If None, weights are not created.
        This can be one of: None, a list of non-negative real numbers (or
        equivalent iterables), or a single non-negative real number.
      num_wts: total size of all the weight shards.
      num_shards: number of shards for the weights.
      name: name for the new Variables.

    Returns:
      A list of weight shard Tensors.

    Raises:
      ValueError: If wt_init is not the right format.
    """

    if wt_init is None:
      return None

    init_mode = "list"
    if isinstance(wt_init, collections.Iterable):
      if num_shards == 1 and len(wt_init) == num_wts:
        wt_init = [wt_init]
      assert len(wt_init) == num_shards
    elif isinstance(wt_init, numbers.Real) and wt_init >= 0:
      init_mode = "scalar"
    else:
      raise ValueError(
          "Invalid weight initialization argument. Must be one of these: "
          "None, a non-negative real number, or a list of lists of "
          "non-negative real numbers (or equivalent iterables) corresponding "
          "to sharded factors.")

    sizes = cls._shard_sizes(num_wts, num_shards)
    assert len(sizes) == num_shards

    def make_wt_initializer(i, size):

      def initializer():
        if init_mode == "scalar":
          return wt_init * array_ops.ones([size])
        else:
          return wt_init[i]

      return initializer

    sharded_weight = []
    for i, size in enumerate(sizes):
      var_name = "%s_shard_%d" % (name, i)
      var_init = make_wt_initializer(i, size)
      sharded_weight.append(
          variable_scope.variable(
              var_init, dtype=dtypes.float32, name=var_name))

    return sharded_weight
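
  # For example (hypothetical values): with num_wts=5 and num_shards=2 the
  # shard sizes are [3, 2], so a scalar wt_init=0.4 yields shards of
  # [0.4, 0.4, 0.4] and [0.4, 0.4], while a list init must be given per shard,
  # e.g. [[1., 2., 3.], [4., 5.]].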

  @staticmethod
  def _create_gramian(n_components, name):
    """Helper function to create the gramian variable.

    Args:
      n_components: number of dimensions of the factors from which the gramian
        will be calculated.
      name: name for the new Variables.

    Returns:
      A gramian Tensor with shape of [n_components, n_components].
    """
    return variable_scope.variable(
        array_ops.zeros([n_components, n_components]),
        dtype=dtypes.float32,
        name=name)

  @staticmethod
  def _transient_var(name):
    """Helper function to create a Variable."""
    return variable_scope.variable(
        1.0,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        validate_shape=False,
        name=name)

  def _prepare_gramian(self, factors, gramian):
    """Helper function to create ops to prepare/calculate gramian.

    Args:
      factors: Variable or list of Variable representing (sharded) factors.
        Used to compute the updated corresponding gramian value.
      gramian: Variable storing the gramian calculated from the factors.

    Returns:
      An op that updates the gramian with the calculated value from the
      factors.
    """
    partial_gramians = []
    for f in factors:
      with ops.colocate_with(f):
        partial_gramians.append(math_ops.matmul(f, f, transpose_a=True))

    with ops.colocate_with(gramian):
      prep_gramian = state_ops.assign(gramian,
                                      math_ops.add_n(partial_gramians)).op

    return prep_gramian

  def _cached_copy(self, var, name, pass_through=False):
    """Helper function to create a worker cached copy of a Variable.

    This assigns the var (either a single Variable or a list of Variables) to
    local transient cache Variable(s). Note that if var is a list of
    Variables, the assignment is done sequentially to minimize the memory
    overhead. Also note that if pass_through is set to True, this does not
    create new Variables but simply returns the input back.

    Args:
      var: A Variable or a list of Variables to cache.
      name: name of cached Variable.
      pass_through: when set to True, this simply passes the var back and does
        not actually create a cache.

    Returns:
      Tuple consisting of following three entries:
      cache: the new transient Variable or list of transient Variables
        corresponding one-to-one with var.
      cache_init: op to initialize the Variable or the list of Variables.
      cache_reset: op to reset the Variable or the list of Variables to some
        default value.
    """
    if var is None:
      return None, None, None
    elif pass_through:
      cache = var
      cache_init = control_flow_ops.no_op()
      cache_reset = control_flow_ops.no_op()
    elif isinstance(var, variables.Variable):
      cache = WALSModel._transient_var(name=name)
      with ops.colocate_with(cache):
        cache_init = state_ops.assign(cache, var, validate_shape=False)
        cache_reset = state_ops.assign(cache, 1.0, validate_shape=False)
    else:
      assert isinstance(var, list)
      assert var
      cache = [
          WALSModel._transient_var(name="%s_shard_%d" % (name, i))
          for i in xrange(len(var))
      ]
      reset_ops = []
      for i, c in enumerate(cache):
        with ops.colocate_with(c):
          if i == 0:
            cache_init = state_ops.assign(c, var[i], validate_shape=False)
          else:
            with ops.control_dependencies([cache_init]):
              cache_init = state_ops.assign(c, var[i], validate_shape=False)
          reset_ops.append(state_ops.assign(c, 1.0, validate_shape=False))
      cache_reset = control_flow_ops.group(*reset_ops)

    return cache, cache_init, cache_reset

  def _create_transient_vars(self):
    """Creates local cache of factors, weights and gramian for rows and columns.

    Note that currently the caching strategy is as follows:
    When initiating a row (resp. column) update:
      - The column (resp. row) gramian is computed.
      - Optionally, if use_gramian_cache is True, the column (resp. row)
        Gramian is cached, while the row (resp. column) gramian is reset.
      - Optionally, if use_factors_weights_cache is True, the column (resp.
        row) factors and weights are cached, while the row (resp. column)
        factors and weights are reset.
    """

    (self._row_factors_cache, row_factors_cache_init,
     row_factors_cache_reset) = self._cached_copy(
         self._row_factors,
         "row_factors_cache",
         pass_through=not self._use_factors_weights_cache)
    (self._col_factors_cache, col_factors_cache_init,
     col_factors_cache_reset) = self._cached_copy(
         self._col_factors,
         "col_factors_cache",
         pass_through=not self._use_factors_weights_cache)
    (self._row_wt_cache, row_wt_cache_init, _) = self._cached_copy(
        self._row_weights,
        "row_wt_cache",
        pass_through=not self._use_factors_weights_cache)
    (self._col_wt_cache, col_wt_cache_init, _) = self._cached_copy(
        self._col_weights,
        "col_wt_cache",
        pass_through=not self._use_factors_weights_cache)
    (self._row_gramian_cache, row_gramian_cache_init,
     row_gramian_cache_reset) = self._cached_copy(
         self._row_gramian,
         "row_gramian_cache",
         pass_through=not self._use_gramian_cache)
    (self._col_gramian_cache, col_gramian_cache_init,
     col_gramian_cache_reset) = self._cached_copy(
         self._col_gramian,
         "col_gramian_cache",
         pass_through=not self._use_gramian_cache)

    self._row_updates_init = control_flow_ops.group(
        col_factors_cache_init, row_factors_cache_reset,
        col_gramian_cache_init, row_gramian_cache_reset)
    self._col_updates_init = control_flow_ops.group(
        row_factors_cache_init, col_factors_cache_reset,
        row_gramian_cache_init, col_gramian_cache_reset)

    if self._row_wt_cache is not None:
      assert self._col_wt_cache is not None
      self._worker_init = control_flow_ops.group(
          row_wt_cache_init, col_wt_cache_init, name="worker_init")
    else:
      self._worker_init = control_flow_ops.no_op(name="worker_init")

  @property
  def worker_init(self):
    """Op to initialize worker state once before starting any updates.

    Note that specifically this initializes the cache of the row and column
    weights on workers when `use_factors_weights_cache` is True. In this case,
    if these weights are being calculated and reset after the object is
    created, it is important to ensure this op is run afterwards so the cache
    reflects the correct values.
    """
    return self._worker_init

  @property
  def row_update_prep_gramian_op(self):
    """Op to form the gramian before starting row updates.

    Must be run before initialize_row_update_op and should only be run by one
    trainer (usually the chief) when doing distributed training.

    Returns:
      Op to form the gramian.
    """
    return self._row_update_prep_gramian

  @property
  def col_update_prep_gramian_op(self):
    """Op to form the gramian before starting col updates.

    Must be run before initialize_col_update_op and should only be run by one
    trainer (usually the chief) when doing distributed training.

    Returns:
      Op to form the gramian.
    """
    return self._col_update_prep_gramian

  @property
  def initialize_row_update_op(self):
    """Op to initialize worker state before starting row updates."""
    return self._row_updates_init

  @property
  def initialize_col_update_op(self):
    """Op to initialize worker state before starting column updates."""
    return self._col_updates_init

  @staticmethod
  def _get_sharding_func(size, num_shards):
    """Create sharding function for scatter update."""

    def func(ids):
      if num_shards == 1:
        return None, ids
      else:
        ids_per_shard = size // num_shards
        extras = size % num_shards
        assignments = math_ops.maximum(ids // (ids_per_shard + 1),
                                       (ids - extras) // ids_per_shard)
        new_ids = array_ops.where(assignments < extras,
                                  ids % (ids_per_shard + 1),
                                  (ids - extras) % ids_per_shard)
        return assignments, new_ids

    return func
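
  # For example (hypothetical sizes): with size=10 and num_shards=3 the shard
  # sizes are [4, 3, 3], so global id 5 maps to (shard 1, local id 1) and
  # global id 7 maps to (shard 2, local id 0), consistent with _shard_sizes.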

  @classmethod
  def scatter_update(cls, factor, indices, values, sharding_func, name=None):
    """Helper function for doing sharded scatter update."""
    assert isinstance(factor, list)
    if len(factor) == 1:
      with ops.colocate_with(factor[0]):
        # TODO(agarwal): assign instead of scatter update for full batch update.
        return state_ops.scatter_update(
            factor[0], indices, values, name=name).op
    else:
      num_shards = len(factor)
      assignments, new_ids = sharding_func(indices)
      assert assignments is not None
      assignments = math_ops.cast(assignments, dtypes.int32)
      sharded_ids = data_flow_ops.dynamic_partition(new_ids, assignments,
                                                    num_shards)
      sharded_values = data_flow_ops.dynamic_partition(values, assignments,
                                                       num_shards)
      updates = []
      for i in xrange(num_shards):
        updates.append(
            state_ops.scatter_update(factor[i], sharded_ids[i],
                                     sharded_values[i]))
      return control_flow_ops.group(*updates, name=name)

  def update_row_factors(self, sp_input=None, transpose_input=False):
    r"""Updates the row factors.

    Args:
      sp_input: A SparseTensor representing a subset of rows of the full input
        in any order. Please note that this SparseTensor must retain the
        indexing as the original input.
      transpose_input: If true, the input will be logically transposed and the
        rows corresponding to the transposed input are updated.

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row factors.
      update_op: An op that assigns the newly computed values to the row
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the
        input matrix A has n total rows, then the unregularized loss is:
        \\(\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 * n / |I|\\)
        The total loss is unregularized_loss + regularization.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
        If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the input
        matrix A has n total rows, then the regularization term is:
        \\(\lambda \|U_I\|_F^2 * n / |I| + \lambda \|V\|_F^2\\).
      sum_weights: The sum of the weights W_I corresponding to sp_input,
        normalized by a factor of \\(n / |I|\\). The root weighted squared
        error is: \\(\sqrt{unregularized\_loss / sum\_weights}\\).
    """
    return self._process_input_helper(
        True, sp_input=sp_input, transpose_input=transpose_input)

  def update_col_factors(self, sp_input=None, transpose_input=False):
    r"""Updates the column factors.

    Args:
      sp_input: A SparseTensor representing a subset of columns of the full
        input. Please refer to comments for update_row_factors for
        restrictions.
      transpose_input: If true, the input will be logically transposed and the
        columns corresponding to the transposed input are updated.

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the column factors.
      update_op: An op that assigns the newly computed values to the column
        factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and
        the input matrix A has m total columns, then the unregularized loss
        is:
        \\(\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 * m / |J|\\)
        The total loss is unregularized_loss + regularization.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
        If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and the
        input matrix A has m total columns, then the regularization term is:
        \\(\lambda \|V_J\|_F^2 * m / |J| + \lambda \|U\|_F^2\\).
      sum_weights: The sum of the weights W_J corresponding to sp_input,
        normalized by a factor of \\(m / |J|\\). The root weighted squared
        error is: \\(\sqrt{unregularized\_loss / sum\_weights}\\).
    """
    return self._process_input_helper(
        False, sp_input=sp_input, transpose_input=transpose_input)

  def project_row_factors(self,
                          sp_input=None,
                          transpose_input=False,
                          projection_weights=None):
    """Projects the row factors.

    This computes the row embedding \\(u_i\\) for an observed row \\(a_i\\) by
    solving one iteration of the update equations.

    Args:
      sp_input: A SparseTensor representing a set of rows. Please note that
        the column indices of this SparseTensor must match the model column
        feature indexing while the row indices are ignored. The returned
        results will be in the same ordering as the input rows.
      transpose_input: If true, the input will be logically transposed and the
        rows corresponding to the transposed input are projected.
      projection_weights: The row weights to be used for the projection. If
        None then 1.0 is used. This can be either a scalar or a rank-1 tensor
        with the number of elements matching the number of rows to be
        projected. Note that the column weights will be determined by the
        underlying WALS model.

    Returns:
      Projected row factors.
    """
    if projection_weights is None:
      projection_weights = 1
    return self._process_input_helper(
        True,
        sp_input=sp_input,
        transpose_input=transpose_input,
        row_weights=projection_weights)[0]

  def project_col_factors(self,
                          sp_input=None,
                          transpose_input=False,
                          projection_weights=None):
    """Projects the column factors.

    This computes the column embedding \\(v_j\\) for an observed column
    \\(a_j\\) by solving one iteration of the update equations.

    Args:
      sp_input: A SparseTensor representing a set of columns. Please note that
        the row indices of this SparseTensor must match the model row feature
        indexing while the column indices are ignored. The returned results
        will be in the same ordering as the input columns.
      transpose_input: If true, the input will be logically transposed and the
        columns corresponding to the transposed input are projected.
      projection_weights: The column weights to be used for the projection. If
        None then 1.0 is used. This can be either a scalar or a rank-1 tensor
        with the number of elements matching the number of columns to be
        projected. Note that the row weights will be determined by the
        underlying WALS model.

    Returns:
      Projected column factors.
    """
    if projection_weights is None:
      projection_weights = 1
    return self._process_input_helper(
        False,
        sp_input=sp_input,
        transpose_input=transpose_input,
        row_weights=projection_weights)[0]

  def _process_input_helper(self,
                            update_row_factors,
                            sp_input=None,
                            transpose_input=False,
                            row_weights=None):
    """Creates the graph for processing a sparse slice of input.

    Args:
      update_row_factors: if True, update or project the row_factors, else
        update or project the column factors.
      sp_input: Please refer to comments for update_row_factors,
        update_col_factors, project_row_factors, and project_col_factors for
        restrictions.
      transpose_input: If True, the input is logically transposed and then the
        corresponding rows/columns of the transposed input are updated.
      row_weights: If not None, this is the row/column weights to be used for
        the update or projection. If None, use the corresponding weights from
        the model. Note that the feature (column/row) weights will be
        determined by the model. When not None, it can either be a scalar or
        a rank-1 tensor with the same number of elements as the number of
        rows or columns to be updated/projected.

    Returns:
      A tuple consisting of the following elements:
      new_values: New values for the row/column factors.
      update_op: An op that assigns the newly computed values to the
        row/column factors.
      unregularized_loss: A tensor (scalar) that contains the normalized
        minibatch loss corresponding to sp_input, without the regularization
        term. Add the regularization term below to yield the loss.
      regularization: A tensor (scalar) that contains the normalized
        regularization term for the minibatch loss corresponding to sp_input.
      sum_weights: The sum of the weights corresponding to sp_input. This
        can be used with unregularized loss to calculate the root weighted
        squared error.
    """
    assert isinstance(sp_input, sparse_tensor.SparseTensor)

    if update_row_factors:
      left = self._row_factors
      right_factors = self._col_factors_cache
      row_wt = self._row_wt_cache
      col_wt = self._col_wt_cache
      total_rows = self._input_rows
      total_cols = self._input_cols
      sharding_func = WALSModel._get_sharding_func(self._input_rows,
                                                   self._num_row_shards)
      gramian = self._col_gramian_cache
    else:
      left = self._col_factors
      right_factors = self._row_factors_cache
      row_wt = self._col_wt_cache
      col_wt = self._row_wt_cache
      total_rows = self._input_cols
      total_cols = self._input_rows
      sharding_func = WALSModel._get_sharding_func(self._input_cols,
                                                   self._num_col_shards)
      gramian = self._row_gramian_cache
      transpose_input = not transpose_input

    # Note that the row indices of sp_input are based on the original full
    # input. Here we reindex the rows and give them contiguous ids starting
    # at 0. We use tf.unique to achieve this reindexing. Note that this is
    # done so that the downstream kernel can assume that the input is "dense"
    # along the row dimension.
    row_ids, col_ids = array_ops.split(
        value=sp_input.indices, num_or_size_splits=2, axis=1)
    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)

    if transpose_input:
      update_indices = update_col_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
      ]
      gather_indices = update_row_indices
    else:
      update_indices = update_row_indices
      row_shape = [
          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
      ]
      gather_indices = update_col_indices

    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
    col_shape = [num_rows]
    right = embedding_ops.embedding_lookup(
        right_factors, gather_indices, partition_strategy="div")
    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
                    if transpose_input else
                    array_ops.concat([col_shape, row_shape], 0))
    new_sp_input = sparse_tensor.SparseTensor(
        indices=new_sp_indices,
        values=sp_input.values,
        dense_shape=new_sp_shape)

    # Compute lhs and rhs of the normal equations.
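    # The unobserved_weight * gramian term accounts for every entry of the
    # rows/columns being updated at the baseline weight, and the custom kernel
    # below (used in the weighted case) adds a correction for the observed
    # entries only; this is the sparse term plus Gramian term decomposition
    # referenced in wals.md.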
    total_lhs = (self._unobserved_weight * gramian)
    if self._regularization_matrix is not None:
      total_lhs += self._regularization_matrix
    if self._row_weights is None:
      # Special case of ALS. Use a much simpler update rule.
      total_rhs = (
          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
              new_sp_input, right, adjoint_a=transpose_input))
      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
      # transposing explicitly.
      # TODO(rmlarsen): multi-thread tf.matrix_solve.
      new_left_values = array_ops.transpose(
          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
    else:
      if row_weights is None:
        # TODO(yifanchen): Add special handling for single shard without using
        # embedding_lookup and perform benchmarks for those cases. Same for
        # col_weights lookup below.
        row_weights_slice = embedding_ops.embedding_lookup(
            row_wt, update_indices, partition_strategy="div")
      else:
        num_indices = array_ops.shape(update_indices)[0]
        with ops.control_dependencies(
            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
          row_weights_slice = control_flow_ops.cond(
              math_ops.equal(array_ops.rank(row_weights), 0),
              lambda: (array_ops.ones([num_indices]) * row_weights),
              lambda: math_ops.cast(row_weights, dtypes.float32))

      col_weights = embedding_ops.embedding_lookup(
          col_wt, gather_indices, partition_strategy="div")
      partial_lhs, total_rhs = (
          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
              right,
              col_weights,
              self._unobserved_weight,
              row_weights_slice,
              new_sp_input.indices,
              new_sp_input.values,
              [],
              num_rows,
              transpose_input,
              name="wals_compute_partial_lhs_rhs"))
      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
      total_rhs = array_ops.expand_dims(total_rhs, -1)
      new_left_values = array_ops.squeeze(
          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])

    update_op_name = "row_update" if update_row_factors else "col_update"
    update_op = self.scatter_update(
        left,
        update_indices,
        new_left_values,
        sharding_func,
        name=update_op_name)

    # Create the loss subgraph.
    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
                     if transpose_input else new_sp_input)
    # sp_approx is the low rank estimate of the input matrix, formed by
    # computing the product <\\(u_i, v_j\\)> for (i, j) in
    # loss_sp_input.indices.
    sp_approx_vals = gen_factorization_ops.masked_matmul(
        new_left_values,
        right,
        loss_sp_input.indices,
        transpose_a=False,
        transpose_b=True)
    sp_approx = sparse_tensor.SparseTensor(
        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
    sp_approx_sq = math_ops.square(sp_approx)
    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
    sp_residual_sq = math_ops.square(sp_residual)
    row_wt_mat = (constant_op.constant(0.)
                  if self._row_weights is None else array_ops.expand_dims(
                      row_weights_slice, 1))
    col_wt_mat = (constant_op.constant(0.)
                  if self._col_weights is None else array_ops.expand_dims(
                      col_weights, 0))

    # We return the normalized loss.
    partial_row_gramian = math_ops.matmul(
        new_left_values, new_left_values, transpose_a=True)
    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)

    unregularized_loss = (
        self._unobserved_weight * (  # pyformat line break
            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
    ) * normalization_factor

    if self._regularization is not None:
      regularization = self._regularization * (
          math_ops.trace(partial_row_gramian) * normalization_factor +
          math_ops.trace(gramian))
    else:
      regularization = constant_op.constant(0.)
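
    # sum_weights counts every cell of the rows/columns being updated once at
    # the baseline unobserved_weight, and adds the (normalized) products of
    # the observed row and column weights, matching the weight matrix W
    # defined in the class docstring.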
    sum_weights = self._unobserved_weight * math_ops.cast(
        total_rows * total_cols, dtypes.float32)
    if self._row_weights is not None and self._col_weights is not None:
      ones = sparse_tensor.SparseTensor(
          indices=loss_sp_input.indices,
          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
          dense_shape=loss_sp_input.dense_shape)
      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
          ones * col_wt_mat)) * normalization_factor

    return (new_left_values, update_op, unregularized_loss, regularization,
            sum_weights)