• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops for matrix factorization."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import numbers
23
24from six.moves import xrange  # pylint: disable=redefined-builtin
25
26from tensorflow.contrib.factorization.python.ops import gen_factorization_ops
27from tensorflow.contrib.util import loader
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import sparse_tensor
32from tensorflow.python.ops import array_ops
33from tensorflow.python.ops import check_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import data_flow_ops
36from tensorflow.python.ops import embedding_ops
37from tensorflow.python.ops import linalg_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import random_ops
40from tensorflow.python.ops import sparse_ops
41from tensorflow.python.ops import state_ops
42from tensorflow.python.ops import variable_scope
43from tensorflow.python.ops import variables
44from tensorflow.python.platform import resource_loader
45
# Load the compiled op library "_factorization_ops.so" that ships next to this
# module and keep the returned module object at module level.
# NOTE(review): presumably loading it registers the custom kernels behind the
# ops used via gen_factorization_ops — confirm.
_factorization_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("_factorization_ops.so"))
48
49
50class WALSModel(object):
51  r"""A model for Weighted Alternating Least Squares matrix factorization.
52
53  It minimizes the following loss function over U, V:
54  $$
55   \|\sqrt W \odot (A - U V^T)\|_F^2 + \lambda (\|U\|_F^2 + \|V\|_F^2)
56  $$
57    where,
58    A: input matrix,
59    W: weight matrix. Note that the (element-wise) square root of the weights
60      is used in the objective function.
61    U, V: row_factors and column_factors matrices,
62    \\(\lambda)\\: regularization.
63  Also we assume that W is of the following special form:
64  \\( W_{ij} = W_0 + R_i * C_j \\)  if \\(A_{ij} \ne 0\\),
65  \\(W_{ij} = W_0\\) otherwise.
66  where,
67  \\(W_0\\): unobserved_weight,
68  \\(R_i\\): row_weights,
69  \\(C_j\\): col_weights.
70
71  Note that the current implementation supports two operation modes: The default
72  mode is for the condition where row_factors and col_factors can individually
73  fit into the memory of each worker and these will be cached. When this
74  condition can't be met, setting use_factors_weights_cache to False allows the
75  larger problem sizes with slight performance penalty as this will avoid
76  creating the worker caches and instead the relevant weight and factor values
77  are looked up from parameter servers at each step.
78
79  Loss computation: The loss can be computed efficiently by decomposing it into
80  a sparse term and a Gramian term, see wals.md.
81  The loss is returned by the update_{col, row}_factors(sp_input), and is
82  normalized as follows:
83    _, _, unregularized_loss, regularization, sum_weights =
84        update_row_factors(sp_input)
85  if sp_input contains the rows \\({A_i, i \in I}\\), and the input matrix A
86  has n total rows, then the minibatch loss = unregularized_loss +
87  regularization is
88   $$
89   (\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 + \lambda \|U_I\|_F^2) * n / |I| +
90   \lambda \|V\|_F^2
91   $$
92  The sum_weights tensor contains the normalized sum of weights
93  \\(sum(W_I) * n / |I|\\).
94
95  A typical usage example (pseudocode):
96
97    with tf.Graph().as_default():
98      # Set up the model object.
99      model = tf.contrib.factorization.WALSModel(....)
100
101      # To be run only once as part of session initialization. In distributed
102      # training setting, this should only be run by the chief trainer and all
103      # other trainers should block until this is done.
104      model_init_op = model.initialize_op
105
106      # To be run once per worker after session is available, prior to
107      # the prep_gramian_op for row(column) can be run.
108      worker_init_op = model.worker_init
109
110      # To be run once per iteration sweep before the row(column) update
111      # initialize ops can be run. Note that in the distributed training
112      # situations, this should only be run by the chief trainer. All other
113      # trainers need to block until this is done.
114      row_update_prep_gramian_op = model.row_update_prep_gramian_op
115      col_update_prep_gramian_op = model.col_update_prep_gramian_op
116
117      # To be run once per worker per iteration sweep. Must be run before
118      # any actual update ops can be run.
119      init_row_update_op = model.initialize_row_update_op
120      init_col_update_op = model.initialize_col_update_op
121
122      # Ops to update row(column). This can either take the entire sparse
123      # tensor or slices of sparse tensor. For distributed trainer, each
124      # trainer handles just part of the matrix.
125      _, row_update_op, unreg_row_loss, row_reg, _ = model.update_row_factors(
126           sp_input=matrix_slices_from_queue_for_worker_shard)
127      row_loss = unreg_row_loss + row_reg
128      _, col_update_op, unreg_col_loss, col_reg, _ = model.update_col_factors(
129           sp_input=transposed_matrix_slices_from_queue_for_worker_shard,
130           transpose_input=True)
131      col_loss = unreg_col_loss + col_reg
132
133      ...
134
135      # model_init_op is passed to Supervisor. Chief trainer runs it. Other
136      # trainers wait.
137      sv = tf.train.Supervisor(is_chief=is_chief,
138                         ...,
139                         init_op=tf.group(..., model_init_op, ...), ...)
140      ...
141
142      with sv.managed_session(...) as sess:
143        # All workers/trainers run it after session becomes available.
144        worker_init_op.run(session=sess)
145
146        ...
147
148        while i in iterations:
149
150          # All trainers need to sync up here.
151          while not_all_ready:
152            wait
153
154          # Row update sweep.
155          if is_chief:
156            row_update_prep_gramian_op.run(session=sess)
157          else:
158            wait_for_chief
159
160          # All workers run upate initialization.
161          init_row_update_op.run(session=sess)
162
163          # Go through the matrix.
164          reset_matrix_slices_queue_for_worker_shard
165          while_matrix_slices:
166            row_update_op.run(session=sess)
167
168          # All trainers need to sync up here.
169          while not_all_ready:
170            wait
171
172          # Column update sweep.
173          if is_chief:
174            col_update_prep_gramian_op.run(session=sess)
175          else:
176            wait_for_chief
177
178          # All workers run upate initialization.
179          init_col_update_op.run(session=sess)
180
181          # Go through the matrix.
182          reset_transposed_matrix_slices_queue_for_worker_shard
183          while_transposed_matrix_slices:
184            col_update_op.run(session=sess)
185  """
186
187  def __init__(self,
188               input_rows,
189               input_cols,
190               n_components,
191               unobserved_weight=0.1,
192               regularization=None,
193               row_init="random",
194               col_init="random",
195               num_row_shards=1,
196               num_col_shards=1,
197               row_weights=1,
198               col_weights=1,
199               use_factors_weights_cache=True,
200               use_gramian_cache=True,
201               use_scoped_vars=False):
202    """Creates model for WALS matrix factorization.
203
204    Args:
205      input_rows: total number of rows for input matrix.
206      input_cols: total number of cols for input matrix.
207      n_components: number of dimensions to use for the factors.
208      unobserved_weight: weight given to unobserved entries of matrix.
209      regularization: weight of L2 regularization term. If None, no
210        regularization is done.
211      row_init: initializer for row factor. Can be a tensor or numpy constant.
212        If set to "random", the value is initialized randomly.
213      col_init: initializer for column factor. See row_init for details.
214      num_row_shards: number of shards to use for row factors.
215      num_col_shards: number of shards to use for column factors.
216      row_weights: Must be in one of the following three formats: None, a list
217        of lists of non-negative real numbers (or equivalent iterables) or a
218        single non-negative real number.
219        - When set to None, w_ij = unobserved_weight, which simplifies to ALS.
220        Note that col_weights must also be set to "None" in this case.
221        - If it is a list of lists of non-negative real numbers, it needs to be
222        in the form of [[w_0, w_1, ...], [w_k, ... ], [...]], with the number of
223        inner lists matching the number of row factor shards and the elements in
224        each inner list are the weights for the rows of the corresponding row
225        factor shard. In this case,  w_ij = unobserved_weight +
226                                            row_weights[i] * col_weights[j].
227        - If this is a single non-negative real number, this value is used for
228        all row weights and \\(w_ij\\) = unobserved_weight + row_weights *
229                                   col_weights[j].
230        Note that it is allowed to have row_weights as a list while col_weights
231        a single number or vice versa.
232      col_weights: See row_weights.
233      use_factors_weights_cache: When True, the factors and weights will be
234        cached on the workers before the updates start. Defaults to True. Note
235        that the weights cache is initialized through `worker_init`, and the
236        row/col factors cache is initialized through
237        `initialize_{col/row}_update_op`. In the case where the weights are
238        computed outside and set before the training iterations start, it is
239        important to ensure the `worker_init` op is run afterwards for the
240        weights cache to take effect.
241      use_gramian_cache: When True, the Gramians will be cached on the workers
242        before the updates start. Defaults to True.
243      use_scoped_vars: When True, the factor and weight vars will also be nested
244        in a tf.name_scope.
245    """
246    self._input_rows = input_rows
247    self._input_cols = input_cols
248    self._num_row_shards = num_row_shards
249    self._num_col_shards = num_col_shards
250    self._n_components = n_components
251    self._unobserved_weight = unobserved_weight
252    self._regularization = regularization
253    self._regularization_matrix = (
254        regularization * linalg_ops.eye(self._n_components)
255        if regularization is not None else None)
256    assert (row_weights is None) == (col_weights is None)
257    self._use_factors_weights_cache = use_factors_weights_cache
258    self._use_gramian_cache = use_gramian_cache
259
260    if use_scoped_vars:
261      with ops.name_scope("row_weights"):
262        self._row_weights = WALSModel._create_weights(
263            row_weights, self._input_rows, self._num_row_shards, "row_weights")
264      with ops.name_scope("col_weights"):
265        self._col_weights = WALSModel._create_weights(
266            col_weights, self._input_cols, self._num_col_shards, "col_weights")
267      with ops.name_scope("row_factors"):
268        self._row_factors = self._create_factors(
269            self._input_rows, self._n_components, self._num_row_shards,
270            row_init, "row_factors")
271      with ops.name_scope("col_factors"):
272        self._col_factors = self._create_factors(
273            self._input_cols, self._n_components, self._num_col_shards,
274            col_init, "col_factors")
275    else:
276      self._row_weights = WALSModel._create_weights(
277          row_weights, self._input_rows, self._num_row_shards, "row_weights")
278      self._col_weights = WALSModel._create_weights(
279          col_weights, self._input_cols, self._num_col_shards, "col_weights")
280      self._row_factors = self._create_factors(
281          self._input_rows, self._n_components, self._num_row_shards, row_init,
282          "row_factors")
283      self._col_factors = self._create_factors(
284          self._input_cols, self._n_components, self._num_col_shards, col_init,
285          "col_factors")
286
287    self._row_gramian = self._create_gramian(self._n_components, "row_gramian")
288    self._col_gramian = self._create_gramian(self._n_components, "col_gramian")
289    with ops.name_scope("row_prepare_gramian"):
290      self._row_update_prep_gramian = self._prepare_gramian(
291          self._col_factors, self._col_gramian)
292    with ops.name_scope("col_prepare_gramian"):
293      self._col_update_prep_gramian = self._prepare_gramian(
294          self._row_factors, self._row_gramian)
295    with ops.name_scope("transient_vars"):
296      self._create_transient_vars()
297
298  @property
299  def row_factors(self):
300    """Returns a list of tensors corresponding to row factor shards."""
301    return self._row_factors
302
303  @property
304  def col_factors(self):
305    """Returns a list of tensors corresponding to column factor shards."""
306    return self._col_factors
307
308  @property
309  def row_weights(self):
310    """Returns a list of tensors corresponding to row weight shards."""
311    return self._row_weights
312
313  @property
314  def col_weights(self):
315    """Returns a list of tensors corresponding to col weight shards."""
316    return self._col_weights
317
318  @property
319  def initialize_op(self):
320    """Returns an op for initializing tensorflow variables."""
321    all_vars = self._row_factors + self._col_factors
322    all_vars.extend([self._row_gramian, self._col_gramian])
323    if self._row_weights is not None:
324      assert self._col_weights is not None
325      all_vars.extend(self._row_weights + self._col_weights)
326    return variables.variables_initializer(all_vars)
327
328  @classmethod
329  def _shard_sizes(cls, dims, num_shards):
330    """Helper function to split dims values into num_shards."""
331    shard_size, residual = divmod(dims, num_shards)
332    return [shard_size + 1] * residual + [shard_size] * (num_shards - residual)
333
334  @classmethod
335  def _create_factors(cls, rows, cols, num_shards, init, name):
336    """Helper function to create row and column factors."""
337    if callable(init):
338      init = init()
339    if isinstance(init, list):
340      assert len(init) == num_shards
341    elif isinstance(init, str) and init == "random":
342      pass
343    elif num_shards == 1:
344      init = [init]
345    sharded_matrix = []
346    sizes = cls._shard_sizes(rows, num_shards)
347    assert len(sizes) == num_shards
348
349    def make_initializer(i, size):
350
351      def initializer():
352        if init == "random":
353          return random_ops.random_normal([size, cols])
354        else:
355          return init[i]
356
357      return initializer
358
359    for i, size in enumerate(sizes):
360      var_name = "%s_shard_%d" % (name, i)
361      var_init = make_initializer(i, size)
362      sharded_matrix.append(
363          variable_scope.variable(
364              var_init, dtype=dtypes.float32, name=var_name))
365
366    return sharded_matrix
367
368  @classmethod
369  def _create_weights(cls, wt_init, num_wts, num_shards, name):
370    """Helper function to create sharded weight vector.
371
372    Args:
373      wt_init: init value for the weight. If None, weights are not created. This
374        can be one of the None, a list of non-negative real numbers or a single
375        non-negative real number (or equivalent iterables).
376      num_wts: total size of all the weight shards
377      num_shards: number of shards for the weights
378      name: name for the new Variables.
379
380    Returns:
381      A list of weight shard Tensors.
382
383    Raises:
384      ValueError: If wt_init is not the right format.
385    """
386
387    if wt_init is None:
388      return None
389
390    init_mode = "list"
391    if isinstance(wt_init, collections.Iterable):
392      if num_shards == 1 and len(wt_init) == num_wts:
393        wt_init = [wt_init]
394      assert len(wt_init) == num_shards
395    elif isinstance(wt_init, numbers.Real) and wt_init >= 0:
396      init_mode = "scalar"
397    else:
398      raise ValueError(
399          "Invalid weight initialization argument. Must be one of these: "
400          "None, a real non-negative real number, or a list of lists of "
401          "non-negative real numbers (or equivalent iterables) corresponding "
402          "to sharded factors.")
403
404    sizes = cls._shard_sizes(num_wts, num_shards)
405    assert len(sizes) == num_shards
406
407    def make_wt_initializer(i, size):
408
409      def initializer():
410        if init_mode == "scalar":
411          return wt_init * array_ops.ones([size])
412        else:
413          return wt_init[i]
414
415      return initializer
416
417    sharded_weight = []
418    for i, size in enumerate(sizes):
419      var_name = "%s_shard_%d" % (name, i)
420      var_init = make_wt_initializer(i, size)
421      sharded_weight.append(
422          variable_scope.variable(
423              var_init, dtype=dtypes.float32, name=var_name))
424
425    return sharded_weight
426
427  @staticmethod
428  def _create_gramian(n_components, name):
429    """Helper function to create the gramian variable.
430
431    Args:
432      n_components: number of dimensions of the factors from which the gramian
433        will be calculated.
434      name: name for the new Variables.
435
436    Returns:
437      A gramian Tensor with shape of [n_components, n_components].
438    """
439    return variable_scope.variable(
440        array_ops.zeros([n_components, n_components]),
441        dtype=dtypes.float32,
442        name=name)
443
444  @staticmethod
445  def _transient_var(name):
446    """Helper function to create a Variable."""
447    return variable_scope.variable(
448        1.0,
449        trainable=False,
450        collections=[ops.GraphKeys.LOCAL_VARIABLES],
451        validate_shape=False,
452        name=name)
453
454  def _prepare_gramian(self, factors, gramian):
455    """Helper function to create ops to prepare/calculate gramian.
456
457    Args:
458      factors: Variable or list of Variable representing (sharded) factors.
459        Used to compute the updated corresponding gramian value.
460      gramian: Variable storing the gramian calculated from the factors.
461
462    Returns:
463      An op that updates the gramian with the calculated value from the factors.
464    """
465    partial_gramians = []
466    for f in factors:
467      with ops.colocate_with(f):
468        partial_gramians.append(math_ops.matmul(f, f, transpose_a=True))
469
470    with ops.colocate_with(gramian):
471      prep_gramian = state_ops.assign(gramian,
472                                      math_ops.add_n(partial_gramians)).op
473
474    return prep_gramian
475
  def _cached_copy(self, var, name, pass_through=False):
    """Helper function to create a worker cached copy of a Variable.

    This assigns the var (either a single Variable or a list of Variables) to
    local transient cache Variable(s). Note that if var is a list of Variables,
    the assignment is done sequentially to minimize the memory overheads.
    Also note that if pass_through is set to True, this does not create new
    Variables but simply returns the input back.

    Args:
      var: A Variable, a non-empty list of Variables, or None.
      name: name of the cached Variable. For a list, shard i of the cache is
        named "<name>_shard_<i>".
      pass_through: when set to True, var itself is returned as the "cache"
        (no identity op and no new Variable is created) and both the init and
        reset ops are no-ops.

    Returns:
      Tuple consisting of following three entries (all None if var is None):
      cache: the new transient Variable or list of transient Variables
        corresponding one-to-one with var.
      cache_init: op to initialize the Variable or the list of Variables. For
        a list, this is the assign op of the last shard, which (through the
        chained control dependencies) runs after all earlier shard assigns.
      cache_reset: op to reset the Variable or the list of Variables to the
        scalar 1.0.
    """
    if var is None:
      return None, None, None
    elif pass_through:
      cache = var
      cache_init = control_flow_ops.no_op()
      cache_reset = control_flow_ops.no_op()
    elif isinstance(var, variables.Variable):
      cache = WALSModel._transient_var(name=name)
      with ops.colocate_with(cache):
        # validate_shape=False: the transient var starts as a scalar 1.0.
        cache_init = state_ops.assign(cache, var, validate_shape=False)
        cache_reset = state_ops.assign(cache, 1.0, validate_shape=False)
    else:
      assert isinstance(var, list)
      assert var
      cache = [
          WALSModel._transient_var(name="%s_shard_%d" % (name, i))
          for i in xrange(len(var))
      ]
      reset_ops = []
      for i, c in enumerate(cache):
        with ops.colocate_with(c):
          if i == 0:
            cache_init = state_ops.assign(c, var[i], validate_shape=False)
          else:
            # Chain each shard's copy after the previous one so the shards
            # are copied one at a time; cache_init ends up being the last
            # assign in the chain.
            with ops.control_dependencies([cache_init]):
              cache_init = state_ops.assign(c, var[i], validate_shape=False)
          reset_ops.append(state_ops.assign(c, 1.0, validate_shape=False))
      cache_reset = control_flow_ops.group(*reset_ops)

    return cache, cache_init, cache_reset
529
530  def _create_transient_vars(self):
531    """Creates local cache of factors, weights and gramian for rows and columns.
532
533    Note that currently the caching strategy is as follows:
534    When initiating a row (resp. column) update:
535      - The column (resp. row) gramian is computed.
536      - Optionally, if use_gramian_cache is True, the column (resp. row) Gramian
537        is cached, while the row (resp. column) gramian is reset.
538      - Optionally, if use_factors_weights_cache is True, the column (resp. row)
539        factors and weights are cached, while the row (resp. column) factors and
540        weights are reset.
541    """
542
543    (self._row_factors_cache, row_factors_cache_init,
544     row_factors_cache_reset) = self._cached_copy(
545         self._row_factors,
546         "row_factors_cache",
547         pass_through=not self._use_factors_weights_cache)
548    (self._col_factors_cache, col_factors_cache_init,
549     col_factors_cache_reset) = self._cached_copy(
550         self._col_factors,
551         "col_factors_cache",
552         pass_through=not self._use_factors_weights_cache)
553    (self._row_wt_cache, row_wt_cache_init, _) = self._cached_copy(
554        self._row_weights,
555        "row_wt_cache",
556        pass_through=not self._use_factors_weights_cache)
557    (self._col_wt_cache, col_wt_cache_init, _) = self._cached_copy(
558        self._col_weights,
559        "col_wt_cache",
560        pass_through=not self._use_factors_weights_cache)
561    (self._row_gramian_cache, row_gramian_cache_init,
562     row_gramian_cache_reset) = self._cached_copy(
563         self._row_gramian,
564         "row_gramian_cache",
565         pass_through=not self._use_gramian_cache)
566    (self._col_gramian_cache, col_gramian_cache_init,
567     col_gramian_cache_reset) = self._cached_copy(
568         self._col_gramian,
569         "col_gramian_cache",
570         pass_through=not self._use_gramian_cache)
571
572    self._row_updates_init = control_flow_ops.group(
573        col_factors_cache_init, row_factors_cache_reset, col_gramian_cache_init,
574        row_gramian_cache_reset)
575    self._col_updates_init = control_flow_ops.group(
576        row_factors_cache_init, col_factors_cache_reset, row_gramian_cache_init,
577        col_gramian_cache_reset)
578
579    if self._row_wt_cache is not None:
580      assert self._col_wt_cache is not None
581      self._worker_init = control_flow_ops.group(
582          row_wt_cache_init, col_wt_cache_init, name="worker_init")
583    else:
584      self._worker_init = control_flow_ops.no_op(name="worker_init")
585
586  @property
587  def worker_init(self):
588    """Op to initialize worker state once before starting any updates.
589
590    Note that specifically this initializes the cache of the row and column
591    weights on workers when `use_factors_weights_cache` is True. In this case,
592    if these weights are being calculated and reset after the object is created,
593    it is important to ensure this ops is run afterwards so the cache reflects
594    the correct values.
595    """
596    return self._worker_init
597
598  @property
599  def row_update_prep_gramian_op(self):
600    """Op to form the gramian before starting row updates.
601
602    Must be run before initialize_row_update_op and should only be run by one
603    trainer (usually the chief) when doing distributed training.
604
605    Returns:
606      Op to form the gramian.
607    """
608    return self._row_update_prep_gramian
609
610  @property
611  def col_update_prep_gramian_op(self):
612    """Op to form the gramian before starting col updates.
613
614    Must be run before initialize_col_update_op and should only be run by one
615    trainer (usually the chief) when doing distributed training.
616
617    Returns:
618      Op to form the gramian.
619    """
620    return self._col_update_prep_gramian
621
622  @property
623  def initialize_row_update_op(self):
624    """Op to initialize worker state before starting row updates."""
625    return self._row_updates_init
626
627  @property
628  def initialize_col_update_op(self):
629    """Op to initialize worker state before starting column updates."""
630    return self._col_updates_init
631
632  @staticmethod
633  def _get_sharding_func(size, num_shards):
634    """Create sharding function for scatter update."""
635
636    def func(ids):
637      if num_shards == 1:
638        return None, ids
639      else:
640        ids_per_shard = size // num_shards
641        extras = size % num_shards
642        assignments = math_ops.maximum(ids // (ids_per_shard + 1),
643                                       (ids - extras) // ids_per_shard)
644        new_ids = array_ops.where(assignments < extras,
645                                  ids % (ids_per_shard + 1),
646                                  (ids - extras) % ids_per_shard)
647        return assignments, new_ids
648
649    return func
650
651  @classmethod
652  def scatter_update(cls, factor, indices, values, sharding_func, name=None):
653    """Helper function for doing sharded scatter update."""
654    assert isinstance(factor, list)
655    if len(factor) == 1:
656      with ops.colocate_with(factor[0]):
657        # TODO(agarwal): assign instead of scatter update for full batch update.
658        return state_ops.scatter_update(
659            factor[0], indices, values, name=name).op
660    else:
661      num_shards = len(factor)
662      assignments, new_ids = sharding_func(indices)
663      assert assignments is not None
664      assignments = math_ops.cast(assignments, dtypes.int32)
665      sharded_ids = data_flow_ops.dynamic_partition(new_ids, assignments,
666                                                    num_shards)
667      sharded_values = data_flow_ops.dynamic_partition(values, assignments,
668                                                       num_shards)
669      updates = []
670      for i in xrange(num_shards):
671        updates.append(
672            state_ops.scatter_update(factor[i], sharded_ids[i], sharded_values[
673                i]))
674      return control_flow_ops.group(*updates, name=name)
675
676  def update_row_factors(self, sp_input=None, transpose_input=False):
677    r"""Updates the row factors.
678
679    Args:
680      sp_input: A SparseTensor representing a subset of rows of the full input
681        in any order. Please note that this SparseTensor must retain the
682        indexing as the original input.
683      transpose_input: If true, the input will be logically transposed and the
684        rows corresponding to the transposed input are updated.
685
686    Returns:
687      A tuple consisting of the following elements:
688      new_values: New values for the row factors.
689      update_op: An op that assigns the newly computed values to the row
690        factors.
691      unregularized_loss: A tensor (scalar) that contains the normalized
692        minibatch loss corresponding to sp_input, without the regularization
693        term. If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the
694        input matrix A has n total rows, then the unregularized loss is:
695        \\(\|\sqrt W_I \odot (A_I - U_I V^T)\|_F^2 * n / |I|\\)
696        The total loss is unregularized_loss + regularization.
697      regularization: A tensor (scalar) that contains the normalized
698        regularization term for the minibatch loss corresponding to sp_input.
699        If sp_input contains the rows \\({A_{i, :}, i \in I}\\), and the input
700        matrix A has n total rows, then the regularization term is:
701        \\(\lambda \|U_I\|_F^2) * n / |I| + \lambda \|V\|_F^2\\).
702      sum_weights: The sum of the weights W_I corresponding to sp_input,
703        normalized by a factor of \\(n / |I|\\). The root weighted squared
704        error is: \sqrt(unregularized_loss / sum_weights).
705    """
706    return self._process_input_helper(
707        True, sp_input=sp_input, transpose_input=transpose_input)
708
709  def update_col_factors(self, sp_input=None, transpose_input=False):
710    r"""Updates the column factors.
711
712    Args:
713      sp_input: A SparseTensor representing a subset of columns of the full
714        input. Please refer to comments for update_row_factors for
715        restrictions.
716      transpose_input: If true, the input will be logically transposed and the
717        columns corresponding to the transposed input are updated.
718
719    Returns:
720      A tuple consisting of the following elements:
721      new_values: New values for the column factors.
722      update_op: An op that assigns the newly computed values to the column
723        factors.
724      unregularized_loss: A tensor (scalar) that contains the normalized
725        minibatch loss corresponding to sp_input, without the regularization
726        term. If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and
727        the input matrix A has m total columns, then the unregularized loss is:
728        \\(\|\sqrt W_J \odot (A_J - U V_J^T)\|_F^2 * m / |I|\\)
729        The total loss is unregularized_loss + regularization.
730      regularization: A tensor (scalar) that contains the normalized
731        regularization term for the minibatch loss corresponding to sp_input.
732        If sp_input contains the columns \\({A_{:, j}, j \in J}\\), and the
733        input matrix A has m total columns, then the regularization term is:
734        \\(\lambda \|V_J\|_F^2) * m / |J| + \lambda \|U\|_F^2\\).
735      sum_weights: The sum of the weights W_J corresponding to sp_input,
736        normalized by a factor of \\(m / |J|\\). The root weighted squared
737        error is: \sqrt(unregularized_loss / sum_weights).
738    """
739    return self._process_input_helper(
740        False, sp_input=sp_input, transpose_input=transpose_input)
741
742  def project_row_factors(self,
743                          sp_input=None,
744                          transpose_input=False,
745                          projection_weights=None):
746    """Projects the row factors.
747
748    This computes the row embedding \\(u_i\\) for an observed row \\(a_i\\) by
749    solving one iteration of the update equations.
750
751    Args:
752      sp_input: A SparseTensor representing a set of rows. Please note that the
753        column indices of this SparseTensor must match the model column feature
754        indexing while the row indices are ignored. The returned results will be
755        in the same ordering as the input rows.
756      transpose_input: If true, the input will be logically transposed and the
757        rows corresponding to the transposed input are projected.
758      projection_weights: The row weights to be used for the projection. If None
759        then 1.0 is used. This can be either a scaler or a rank-1 tensor with
760        the number of elements matching the number of rows to be projected.
761        Note that the column weights will be determined by the underlying WALS
762        model.
763
764    Returns:
765      Projected row factors.
766    """
767    if projection_weights is None:
768      projection_weights = 1
769    return self._process_input_helper(
770        True,
771        sp_input=sp_input,
772        transpose_input=transpose_input,
773        row_weights=projection_weights)[0]
774
775  def project_col_factors(self,
776                          sp_input=None,
777                          transpose_input=False,
778                          projection_weights=None):
779    """Projects the column factors.
780
781    This computes the column embedding \\(v_j\\) for an observed column
782    \\(a_j\\) by solving one iteration of the update equations.
783
784    Args:
785      sp_input: A SparseTensor representing a set of columns. Please note that
786        the row indices of this SparseTensor must match the model row feature
787        indexing while the column indices are ignored. The returned results will
788        be in the same ordering as the input columns.
789      transpose_input: If true, the input will be logically transposed and the
790        columns corresponding to the transposed input are projected.
791      projection_weights: The column weights to be used for the projection. If
792        None then 1.0 is used. This can be either a scaler or a rank-1 tensor
793        with the number of elements matching the number of columns to be
794        projected. Note that the row weights will be determined by the
795        underlying WALS model.
796
797    Returns:
798      Projected column factors.
799    """
800    if projection_weights is None:
801      projection_weights = 1
802    return self._process_input_helper(
803        False,
804        sp_input=sp_input,
805        transpose_input=transpose_input,
806        row_weights=projection_weights)[0]
807
808  def _process_input_helper(self,
809                            update_row_factors,
810                            sp_input=None,
811                            transpose_input=False,
812                            row_weights=None):
813    """Creates the graph for processing a sparse slice of input.
814
815    Args:
816      update_row_factors: if True, update or project the row_factors, else
817        update or project the column factors.
818      sp_input: Please refer to comments for update_row_factors,
819        update_col_factors, project_row_factors, and project_col_factors for
820        restrictions.
821      transpose_input: If True, the input is logically transposed and then the
822        corresponding rows/columns of the transposed input are updated.
823      row_weights: If not None, this is the row/column weights to be used for
824        the update or projection. If None, use the corresponding weights from
825        the model. Note that the feature (column/row) weights will be
826        determined by the model. When not None, it can either be a scalar or
827        a rank-1 tensor with the same number of elements as the number of rows
828        of columns to be updated/projected.
829
830    Returns:
831      A tuple consisting of the following elements:
832      new_values: New values for the row/column factors.
833      update_op: An op that assigns the newly computed values to the row/column
834        factors.
835      unregularized_loss: A tensor (scalar) that contains the normalized
836        minibatch loss corresponding to sp_input, without the regularization
837        term. Add the regularization term below to yield the loss.
838      regularization: A tensor (scalar) that contains the normalized
839        regularization term for the minibatch loss corresponding to sp_input.
840      sum_weights: The sum of the weights corresponding to sp_input. This
841        can be used with unregularized loss to calculate the root weighted
842        squared error.
843    """
844    assert isinstance(sp_input, sparse_tensor.SparseTensor)
845
846    if update_row_factors:
847      left = self._row_factors
848      right_factors = self._col_factors_cache
849      row_wt = self._row_wt_cache
850      col_wt = self._col_wt_cache
851      total_rows = self._input_rows
852      total_cols = self._input_cols
853      sharding_func = WALSModel._get_sharding_func(self._input_rows,
854                                                   self._num_row_shards)
855      gramian = self._col_gramian_cache
856    else:
857      left = self._col_factors
858      right_factors = self._row_factors_cache
859      row_wt = self._col_wt_cache
860      col_wt = self._row_wt_cache
861      total_rows = self._input_cols
862      total_cols = self._input_rows
863      sharding_func = WALSModel._get_sharding_func(self._input_cols,
864                                                   self._num_col_shards)
865      gramian = self._row_gramian_cache
866      transpose_input = not transpose_input
867
868    # Note that the row indices of sp_input are based on the original full input
869    # Here we reindex the rows and give them contiguous ids starting at 0.
870    # We use tf.unique to achieve this reindexing. Note that this is done so
871    # that the downstream kernel can assume that the input is "dense" along the
872    # row dimension.
873    row_ids, col_ids = array_ops.split(
874        value=sp_input.indices, num_or_size_splits=2, axis=1)
875    update_row_indices, all_row_ids = array_ops.unique(row_ids[:, 0])
876    update_col_indices, all_col_ids = array_ops.unique(col_ids[:, 0])
877    col_ids = array_ops.expand_dims(math_ops.cast(all_col_ids, dtypes.int64), 1)
878    row_ids = array_ops.expand_dims(math_ops.cast(all_row_ids, dtypes.int64), 1)
879
880    if transpose_input:
881      update_indices = update_col_indices
882      row_shape = [
883          math_ops.cast(array_ops.shape(update_row_indices)[0], dtypes.int64)
884      ]
885      gather_indices = update_row_indices
886    else:
887      update_indices = update_row_indices
888      row_shape = [
889          math_ops.cast(array_ops.shape(update_col_indices)[0], dtypes.int64)
890      ]
891      gather_indices = update_col_indices
892
893    num_rows = math_ops.cast(array_ops.shape(update_indices)[0], dtypes.int64)
894    col_shape = [num_rows]
895    right = embedding_ops.embedding_lookup(
896        right_factors, gather_indices, partition_strategy="div")
897    new_sp_indices = array_ops.concat([row_ids, col_ids], 1)
898    new_sp_shape = (array_ops.concat([row_shape, col_shape], 0)
899                    if transpose_input else
900                    array_ops.concat([col_shape, row_shape], 0))
901    new_sp_input = sparse_tensor.SparseTensor(
902        indices=new_sp_indices,
903        values=sp_input.values,
904        dense_shape=new_sp_shape)
905
906    # Compute lhs and rhs of the normal equations
907    total_lhs = (self._unobserved_weight * gramian)
908    if self._regularization_matrix is not None:
909      total_lhs += self._regularization_matrix
910    if self._row_weights is None:
911      # Special case of ALS. Use a much simpler update rule.
912      total_rhs = (
913          self._unobserved_weight * sparse_ops.sparse_tensor_dense_matmul(
914              new_sp_input, right, adjoint_a=transpose_input))
915      # TODO(rmlarsen): handle transposing in tf.matrix_solve instead of
916      # transposing explicitly.
917      # TODO(rmlarsen): multi-thread tf.matrix_solve.
918      new_left_values = array_ops.transpose(
919          linalg_ops.matrix_solve(total_lhs, array_ops.transpose(total_rhs)))
920    else:
921      if row_weights is None:
922        # TODO(yifanchen): Add special handling for single shard without using
923        # embedding_lookup and perform benchmarks for those cases. Same for
924        # col_weights lookup below.
925        row_weights_slice = embedding_ops.embedding_lookup(
926            row_wt, update_indices, partition_strategy="div")
927      else:
928        num_indices = array_ops.shape(update_indices)[0]
929        with ops.control_dependencies(
930            [check_ops.assert_less_equal(array_ops.rank(row_weights), 1)]):
931          row_weights_slice = control_flow_ops.cond(
932              math_ops.equal(array_ops.rank(row_weights), 0),
933              lambda: (array_ops.ones([num_indices]) * row_weights),
934              lambda: math_ops.cast(row_weights, dtypes.float32))
935
936      col_weights = embedding_ops.embedding_lookup(
937          col_wt, gather_indices, partition_strategy="div")
938      partial_lhs, total_rhs = (
939          gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
940              right,
941              col_weights,
942              self._unobserved_weight,
943              row_weights_slice,
944              new_sp_input.indices,
945              new_sp_input.values,
946              [],
947              num_rows,
948              transpose_input,
949              name="wals_compute_partial_lhs_rhs"))
950      total_lhs = array_ops.expand_dims(total_lhs, 0) + partial_lhs
951      total_rhs = array_ops.expand_dims(total_rhs, -1)
952      new_left_values = array_ops.squeeze(
953          linalg_ops.matrix_solve(total_lhs, total_rhs), [2])
954
955    update_op_name = "row_update" if update_row_factors else "col_update"
956    update_op = self.scatter_update(
957        left,
958        update_indices,
959        new_left_values,
960        sharding_func,
961        name=update_op_name)
962
963    # Create the loss subgraph
964    loss_sp_input = (sparse_ops.sparse_transpose(new_sp_input)
965                     if transpose_input else new_sp_input)
966    # sp_approx is the low rank estimate of the input matrix, formed by
967    # computing the product <\\(u_i, v_j\\)> for (i, j) in loss_sp_input.indices.
968    sp_approx_vals = gen_factorization_ops.masked_matmul(
969        new_left_values,
970        right,
971        loss_sp_input.indices,
972        transpose_a=False,
973        transpose_b=True)
974    sp_approx = sparse_tensor.SparseTensor(
975        loss_sp_input.indices, sp_approx_vals, loss_sp_input.dense_shape)
976    sp_approx_sq = math_ops.square(sp_approx)
977    sp_residual = sparse_ops.sparse_add(loss_sp_input, sp_approx * (-1))
978    sp_residual_sq = math_ops.square(sp_residual)
979    row_wt_mat = (constant_op.constant(0.)
980                  if self._row_weights is None else array_ops.expand_dims(
981                      row_weights_slice, 1))
982    col_wt_mat = (constant_op.constant(0.)
983                  if self._col_weights is None else array_ops.expand_dims(
984                      col_weights, 0))
985
986    # We return the normalized loss
987    partial_row_gramian = math_ops.matmul(
988        new_left_values, new_left_values, transpose_a=True)
989    normalization_factor = total_rows / math_ops.cast(num_rows, dtypes.float32)
990
991    unregularized_loss = (
992        self._unobserved_weight * (  # pyformat line break
993            sparse_ops.sparse_reduce_sum(sp_residual_sq) -  # pyformat break
994            sparse_ops.sparse_reduce_sum(sp_approx_sq) +  # pyformat break
995            math_ops.trace(math_ops.matmul(partial_row_gramian, gramian))) +
996        sparse_ops.sparse_reduce_sum(row_wt_mat * (sp_residual_sq * col_wt_mat))
997    ) * normalization_factor
998
999    if self._regularization is not None:
1000      regularization = self._regularization * (
1001          math_ops.trace(partial_row_gramian) * normalization_factor +
1002          math_ops.trace(gramian))
1003    else:
1004      regularization = constant_op.constant(0.)
1005
1006    sum_weights = self._unobserved_weight * math_ops.cast(
1007        total_rows * total_cols, dtypes.float32)
1008    if self._row_weights is not None and self._col_weights is not None:
1009      ones = sparse_tensor.SparseTensor(
1010          indices=loss_sp_input.indices,
1011          values=array_ops.ones(array_ops.shape(loss_sp_input.values)),
1012          dense_shape=loss_sp_input.dense_shape)
1013      sum_weights += sparse_ops.sparse_reduce_sum(row_wt_mat * (
1014          ones * col_wt_mat)) * normalization_factor
1015
1016    return (new_left_values, update_op, unregularized_loss, regularization,
1017            sum_weights)
1018