1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for WALSMatrixFactorization.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import itertools 22import json 23import numpy as np 24 25from tensorflow.contrib.factorization.python.ops import factorization_ops_test_utils 26from tensorflow.contrib.factorization.python.ops import wals as wals_lib 27from tensorflow.contrib.learn.python.learn import run_config 28from tensorflow.contrib.learn.python.learn.estimators import model_fn 29from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib 30from tensorflow.python.framework import constant_op 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import sparse_tensor 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import embedding_ops 36from tensorflow.python.ops import math_ops 37from tensorflow.python.ops import sparse_ops 38from tensorflow.python.ops import state_ops 39from tensorflow.python.ops import variables 40from tensorflow.python.platform import test 41from tensorflow.python.training import input as input_lib 42from tensorflow.python.training import monitored_session 43 44 45class WALSMatrixFactorizationTest(test.TestCase): 46 INPUT_MATRIX = factorization_ops_test_utils.INPUT_MATRIX 47 48 def np_array_to_sparse(self, np_array): 49 """Transforms an np.array to a tf.SparseTensor.""" 50 return factorization_ops_test_utils.np_matrix_to_tf_sparse(np_array) 51 52 def calculate_loss(self): 53 """Calculates the loss of the current (trained) model.""" 54 current_rows = embedding_ops.embedding_lookup( 55 self._model.get_row_factors(), math_ops.range(self._num_rows), 56 partition_strategy='div') 57 current_cols = embedding_ops.embedding_lookup( 58 self._model.get_col_factors(), math_ops.range(self._num_cols), 59 partition_strategy='div') 60 row_wts = embedding_ops.embedding_lookup( 61 self._row_weights, math_ops.range(self._num_rows), 62 partition_strategy='div') 63 col_wts = embedding_ops.embedding_lookup( 64 self._col_weights, math_ops.range(self._num_cols), 65 partition_strategy='div') 66 sp_inputs = self.np_array_to_sparse(self.INPUT_MATRIX) 67 return factorization_ops_test_utils.calculate_loss( 68 sp_inputs, current_rows, current_cols, self._regularization_coeff, 69 self._unobserved_weight, row_wts, col_wts) 70 71 # TODO(walidk): Replace with input_reader_utils functions once open sourced. 72 def remap_sparse_tensor_rows(self, sp_x, row_ids, shape): 73 """Remaps the row ids of a tf.SparseTensor.""" 74 old_row_ids, old_col_ids = array_ops.split( 75 value=sp_x.indices, num_or_size_splits=2, axis=1) 76 new_row_ids = array_ops.gather(row_ids, old_row_ids) 77 new_indices = array_ops.concat([new_row_ids, old_col_ids], 1) 78 return sparse_tensor.SparseTensor( 79 indices=new_indices, values=sp_x.values, dense_shape=shape) 80 81 # TODO(walidk): Add an option to shuffle inputs. 82 def input_fn(self, np_matrix, batch_size, mode, 83 project_row=None, projection_weights=None, 84 remove_empty_rows_columns=False): 85 """Returns an input_fn that selects row and col batches from np_matrix. 86 87 This simple utility creates an input function from a numpy_array. The 88 following transformations are performed: 89 * The empty rows and columns in np_matrix are removed (if 90 remove_empty_rows_columns is true) 91 * np_matrix is converted to a SparseTensor. 92 * The rows of the sparse matrix (and the rows of its transpose) are batched. 93 * A features dictionary is created, which contains the row / column batches. 94 95 In TRAIN mode, one only needs to specify the np_matrix and the batch_size. 96 In INFER and EVAL modes, one must also provide project_row, a boolean which 97 specifies whether we are projecting rows or columns. 98 99 Args: 100 np_matrix: A numpy array. The input matrix to use. 101 batch_size: Integer. 102 mode: Can be one of model_fn.ModeKeys.{TRAIN, INFER, EVAL}. 103 project_row: A boolean. Used in INFER and EVAL modes. Specifies whether 104 to project rows or columns. 105 projection_weights: A float numpy array. Used in INFER mode. Specifies 106 the weights to use in the projection (the weights are optional, and 107 default to 1.). 108 remove_empty_rows_columns: A boolean. When true, this will remove empty 109 rows and columns in the np_matrix. Note that this will result in 110 modifying the indices of the input matrix. The mapping from new indices 111 to old indices is returned in the form of two numpy arrays. 112 113 Returns: 114 A tuple consisting of: 115 _fn: A callable. Calling _fn returns a features dict. 116 nz_row_ids: A numpy array of the ids of non-empty rows, such that 117 nz_row_ids[i] is the old row index corresponding to new index i. 118 nz_col_ids: A numpy array of the ids of non-empty columns, such that 119 nz_col_ids[j] is the old column index corresponding to new index j. 120 """ 121 if remove_empty_rows_columns: 122 np_matrix, nz_row_ids, nz_col_ids = ( 123 factorization_ops_test_utils.remove_empty_rows_columns(np_matrix)) 124 else: 125 nz_row_ids = np.arange(np.shape(np_matrix)[0]) 126 nz_col_ids = np.arange(np.shape(np_matrix)[1]) 127 128 def extract_features(row_batch, col_batch, num_rows, num_cols): 129 row_ids = row_batch[0] 130 col_ids = col_batch[0] 131 rows = self.remap_sparse_tensor_rows( 132 row_batch[1], row_ids, shape=[num_rows, num_cols]) 133 cols = self.remap_sparse_tensor_rows( 134 col_batch[1], col_ids, shape=[num_cols, num_rows]) 135 features = { 136 wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows, 137 wals_lib.WALSMatrixFactorization.INPUT_COLS: cols, 138 } 139 return features 140 141 def _fn(): 142 num_rows = np.shape(np_matrix)[0] 143 num_cols = np.shape(np_matrix)[1] 144 row_ids = math_ops.range(num_rows, dtype=dtypes.int64) 145 col_ids = math_ops.range(num_cols, dtype=dtypes.int64) 146 sp_mat = self.np_array_to_sparse(np_matrix) 147 sp_mat_t = sparse_ops.sparse_transpose(sp_mat) 148 row_batch = input_lib.batch( 149 [row_ids, sp_mat], 150 batch_size=min(batch_size, num_rows), 151 capacity=10, 152 enqueue_many=True) 153 col_batch = input_lib.batch( 154 [col_ids, sp_mat_t], 155 batch_size=min(batch_size, num_cols), 156 capacity=10, 157 enqueue_many=True) 158 159 features = extract_features(row_batch, col_batch, num_rows, num_cols) 160 161 if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL: 162 self.assertTrue( 163 project_row is not None, 164 msg='project_row must be specified in INFER or EVAL mode.') 165 features[wals_lib.WALSMatrixFactorization.PROJECT_ROW] = ( 166 constant_op.constant(project_row)) 167 168 if mode == model_fn.ModeKeys.INFER and projection_weights is not None: 169 weights_batch = input_lib.batch( 170 projection_weights, 171 batch_size=batch_size, 172 capacity=10, 173 enqueue_many=True) 174 features[wals_lib.WALSMatrixFactorization.PROJECTION_WEIGHTS] = ( 175 weights_batch) 176 177 labels = None 178 return features, labels 179 180 return _fn, nz_row_ids, nz_col_ids 181 182 @property 183 def input_matrix(self): 184 return self.INPUT_MATRIX 185 186 @property 187 def row_steps(self): 188 return np.ceil(self._num_rows / self.batch_size) 189 190 @property 191 def col_steps(self): 192 return np.ceil(self._num_cols / self.batch_size) 193 194 @property 195 def batch_size(self): 196 return 5 197 198 @property 199 def use_cache(self): 200 return False 201 202 @property 203 def max_sweeps(self): 204 return None 205 206 def setUp(self): 207 self._num_rows = 5 208 self._num_cols = 7 209 self._embedding_dimension = 3 210 self._unobserved_weight = 0.1 211 self._num_row_shards = 2 212 self._num_col_shards = 3 213 self._regularization_coeff = 0.01 214 self._col_init = [ 215 # Shard 0. 216 [[-0.36444709, -0.39077035, -0.32528427], 217 [1.19056475, 0.07231052, 2.11834812], 218 [0.93468881, -0.71099287, 1.91826844]], 219 # Shard 1. 220 [[1.18160152, 1.52490723, -0.50015002], 221 [1.82574749, -0.57515913, -1.32810032]], 222 # Shard 2. 223 [[-0.15515432, -0.84675711, 0.13097958], 224 [-0.9246484, 0.69117504, 1.2036494]], 225 ] 226 self._row_weights = [[0.1, 0.2, 0.3], [0.4, 0.5]] 227 self._col_weights = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]] 228 229 # Values of row and column factors after running one iteration or factor 230 # updates. 231 self._row_factors_0 = [[0.097689, -0.219293, -0.020780], 232 [0.50842, 0.64626, 0.22364], 233 [0.401159, -0.046558, -0.192854]] 234 self._row_factors_1 = [[1.20597, -0.48025, 0.35582], 235 [1.5564, 1.2528, 1.0528]] 236 self._col_factors_0 = [[2.4725, -1.2950, -1.9980], 237 [0.44625, 1.50771, 1.27118], 238 [1.39801, -2.10134, 0.73572]] 239 self._col_factors_1 = [[3.36509, -0.66595, -3.51208], 240 [0.57191, 1.59407, 1.33020]] 241 self._col_factors_2 = [[3.3459, -1.3341, -3.3008], 242 [0.57366, 1.83729, 1.26798]] 243 self._model = wals_lib.WALSMatrixFactorization( 244 self._num_rows, 245 self._num_cols, 246 self._embedding_dimension, 247 self._unobserved_weight, 248 col_init=self._col_init, 249 regularization_coeff=self._regularization_coeff, 250 num_row_shards=self._num_row_shards, 251 num_col_shards=self._num_col_shards, 252 row_weights=self._row_weights, 253 col_weights=self._col_weights, 254 max_sweeps=self.max_sweeps, 255 use_factors_weights_cache_for_training=self.use_cache, 256 use_gramian_cache_for_training=self.use_cache) 257 258 def test_fit(self): 259 # Row sweep. 260 input_fn = self.input_fn(np_matrix=self.input_matrix, 261 batch_size=self.batch_size, 262 mode=model_fn.ModeKeys.TRAIN, 263 remove_empty_rows_columns=True)[0] 264 self._model.fit(input_fn=input_fn, steps=self.row_steps) 265 row_factors = self._model.get_row_factors() 266 self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3) 267 self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3) 268 269 # Col sweep. 270 # Running fit a second time will resume training from the checkpoint. 271 input_fn = self.input_fn(np_matrix=self.input_matrix, 272 batch_size=self.batch_size, 273 mode=model_fn.ModeKeys.TRAIN, 274 remove_empty_rows_columns=True)[0] 275 self._model.fit(input_fn=input_fn, steps=self.col_steps) 276 col_factors = self._model.get_col_factors() 277 self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3) 278 self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3) 279 self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3) 280 281 def test_predict(self): 282 input_fn = self.input_fn(np_matrix=self.input_matrix, 283 batch_size=self.batch_size, 284 mode=model_fn.ModeKeys.TRAIN, 285 remove_empty_rows_columns=True, 286 )[0] 287 # Project rows 1 and 4 from the input matrix. 288 proj_input_fn = self.input_fn( 289 np_matrix=self.INPUT_MATRIX[[1, 4], :], 290 batch_size=2, 291 mode=model_fn.ModeKeys.INFER, 292 project_row=True, 293 projection_weights=[[0.2, 0.5]])[0] 294 295 self._model.fit(input_fn=input_fn, steps=self.row_steps) 296 projections = self._model.get_projections(proj_input_fn) 297 projected_rows = list(itertools.islice(projections, 2)) 298 299 self.assertAllClose( 300 projected_rows, 301 [self._row_factors_0[1], self._row_factors_1[1]], 302 atol=1e-3) 303 304 # Project columns 5, 3, 1 from the input matrix. 305 proj_input_fn = self.input_fn( 306 np_matrix=self.INPUT_MATRIX[:, [5, 3, 1]], 307 batch_size=3, 308 mode=model_fn.ModeKeys.INFER, 309 project_row=False, 310 projection_weights=[[0.6, 0.4, 0.2]])[0] 311 312 self._model.fit(input_fn=input_fn, steps=self.col_steps) 313 projections = self._model.get_projections(proj_input_fn) 314 projected_cols = list(itertools.islice(projections, 3)) 315 self.assertAllClose( 316 projected_cols, 317 [self._col_factors_2[0], self._col_factors_1[0], 318 self._col_factors_0[1]], 319 atol=1e-3) 320 321 def test_eval(self): 322 # Do a row sweep then evaluate the model on row inputs. 323 # The evaluate function returns the loss of the projected rows, but since 324 # projection is idempotent, the eval loss must match the model loss. 325 input_fn = self.input_fn(np_matrix=self.input_matrix, 326 batch_size=self.batch_size, 327 mode=model_fn.ModeKeys.TRAIN, 328 remove_empty_rows_columns=True, 329 )[0] 330 self._model.fit(input_fn=input_fn, steps=self.row_steps) 331 eval_input_fn_row = self.input_fn(np_matrix=self.input_matrix, 332 batch_size=1, 333 mode=model_fn.ModeKeys.EVAL, 334 project_row=True, 335 remove_empty_rows_columns=True)[0] 336 loss = self._model.evaluate( 337 input_fn=eval_input_fn_row, steps=self._num_rows)['loss'] 338 339 with self.cached_session(): 340 true_loss = self.calculate_loss() 341 342 self.assertNear( 343 loss, true_loss, err=.001, 344 msg="""After row update, eval loss = {}, does not match the true 345 loss = {}.""".format(loss, true_loss)) 346 347 # Do a col sweep then evaluate the model on col inputs. 348 self._model.fit(input_fn=input_fn, steps=self.col_steps) 349 eval_input_fn_col = self.input_fn(np_matrix=self.input_matrix, 350 batch_size=1, 351 mode=model_fn.ModeKeys.EVAL, 352 project_row=False, 353 remove_empty_rows_columns=True)[0] 354 loss = self._model.evaluate( 355 input_fn=eval_input_fn_col, steps=self._num_cols)['loss'] 356 357 with self.cached_session(): 358 true_loss = self.calculate_loss() 359 360 self.assertNear( 361 loss, true_loss, err=.001, 362 msg="""After col update, eval loss = {}, does not match the true 363 loss = {}.""".format(loss, true_loss)) 364 365 366class WALSMatrixFactorizationTestSweeps(WALSMatrixFactorizationTest): 367 368 @property 369 def max_sweeps(self): 370 return 2 371 372 # We set the column steps to None so that we rely only on max_sweeps to stop 373 # training. 374 @property 375 def col_steps(self): 376 return None 377 378 379class WALSMatrixFactorizationTestCached(WALSMatrixFactorizationTest): 380 381 @property 382 def use_cache(self): 383 return True 384 385 386class WALSMatrixFactorizaiontTestPaddedInput(WALSMatrixFactorizationTest): 387 PADDED_INPUT_MATRIX = np.pad( 388 WALSMatrixFactorizationTest.INPUT_MATRIX, 389 [(1, 0), (1, 0)], mode='constant') 390 391 @property 392 def input_matrix(self): 393 return self.PADDED_INPUT_MATRIX 394 395 396class WALSMatrixFactorizationUnsupportedTest(test.TestCase): 397 398 def setUp(self): 399 pass 400 401 def testDistributedWALSUnsupported(self): 402 tf_config = { 403 'cluster': { 404 run_config_lib.TaskType.PS: ['host1:1', 'host2:2'], 405 run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4'] 406 }, 407 'task': { 408 'type': run_config_lib.TaskType.WORKER, 409 'index': 1 410 } 411 } 412 with test.mock.patch.dict('os.environ', 413 {'TF_CONFIG': json.dumps(tf_config)}): 414 config = run_config.RunConfig() 415 self.assertEqual(config.num_worker_replicas, 2) 416 with self.assertRaises(ValueError): 417 self._model = wals_lib.WALSMatrixFactorization(1, 1, 1, config=config) 418 419 420class SweepHookTest(test.TestCase): 421 422 def test_sweeps(self): 423 is_row_sweep_var = variables.VariableV1(True) 424 is_sweep_done_var = variables.VariableV1(False) 425 init_done = variables.VariableV1(False) 426 row_prep_done = variables.VariableV1(False) 427 col_prep_done = variables.VariableV1(False) 428 row_train_done = variables.VariableV1(False) 429 col_train_done = variables.VariableV1(False) 430 431 init_op = state_ops.assign(init_done, True) 432 row_prep_op = state_ops.assign(row_prep_done, True) 433 col_prep_op = state_ops.assign(col_prep_done, True) 434 row_train_op = state_ops.assign(row_train_done, True) 435 col_train_op = state_ops.assign(col_train_done, True) 436 train_op = control_flow_ops.no_op() 437 switch_op = control_flow_ops.group( 438 state_ops.assign(is_sweep_done_var, False), 439 state_ops.assign(is_row_sweep_var, 440 math_ops.logical_not(is_row_sweep_var))) 441 mark_sweep_done = state_ops.assign(is_sweep_done_var, True) 442 443 with self.cached_session() as sess: 444 sweep_hook = wals_lib._SweepHook( 445 is_row_sweep_var, 446 is_sweep_done_var, 447 init_op, 448 [row_prep_op], 449 [col_prep_op], 450 row_train_op, 451 col_train_op, 452 switch_op) 453 mon_sess = monitored_session._HookedSession(sess, [sweep_hook]) 454 sess.run([variables.global_variables_initializer()]) 455 456 # Row sweep. 457 mon_sess.run(train_op) 458 self.assertTrue(sess.run(init_done), 459 msg='init op not run by the Sweephook') 460 self.assertTrue(sess.run(row_prep_done), 461 msg='row_prep_op not run by the SweepHook') 462 self.assertTrue(sess.run(row_train_done), 463 msg='row_train_op not run by the SweepHook') 464 self.assertTrue( 465 sess.run(is_row_sweep_var), 466 msg='Row sweep is not complete but is_row_sweep_var is False.') 467 # Col sweep. 468 mon_sess.run(mark_sweep_done) 469 mon_sess.run(train_op) 470 self.assertTrue(sess.run(col_prep_done), 471 msg='col_prep_op not run by the SweepHook') 472 self.assertTrue(sess.run(col_train_done), 473 msg='col_train_op not run by the SweepHook') 474 self.assertFalse( 475 sess.run(is_row_sweep_var), 476 msg='Col sweep is not complete but is_row_sweep_var is True.') 477 # Row sweep. 478 mon_sess.run(mark_sweep_done) 479 mon_sess.run(train_op) 480 self.assertTrue( 481 sess.run(is_row_sweep_var), 482 msg='Col sweep is complete but is_row_sweep_var is False.') 483 484 485class StopAtSweepHookTest(test.TestCase): 486 487 def test_stop(self): 488 hook = wals_lib._StopAtSweepHook(last_sweep=10) 489 completed_sweeps = variables.VariableV1( 490 8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS) 491 train_op = state_ops.assign_add(completed_sweeps, 1) 492 hook.begin() 493 494 with self.cached_session() as sess: 495 sess.run([variables.global_variables_initializer()]) 496 mon_sess = monitored_session._HookedSession(sess, [hook]) 497 mon_sess.run(train_op) 498 # completed_sweeps is 9 after running train_op. 499 self.assertFalse(mon_sess.should_stop()) 500 mon_sess.run(train_op) 501 # completed_sweeps is 10 after running train_op. 502 self.assertTrue(mon_sess.should_stop()) 503 504 505if __name__ == '__main__': 506 test.main() 507