# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Weighted Alternating Least Squares (WALS) on the tf.learn API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.factorization.python.ops import factorization_ops
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util


class _SweepHook(session_run_hook.SessionRunHook):
  """Keeps track of row/col sweeps, and runs prep ops before each sweep."""

  def __init__(self, is_row_sweep_var, is_sweep_done_var, init_op,
               row_prep_ops, col_prep_ops, row_train_op, col_train_op,
               switch_op):
    """Initializes SweepHook.

    Args:
      is_row_sweep_var: A Boolean tf.Variable, determines whether we are
        currently doing a row or column sweep. It is updated by the hook.
      is_sweep_done_var: A Boolean tf.Variable, determines whether we are
        starting a new sweep (this is used to determine when to run the prep ops
        below).
      init_op: op to be run once before training. This is typically a local
        initialization op (such as cache initialization).
      row_prep_ops: A list of TensorFlow ops, to be run before the beginning of
        each row sweep (and during initialization), in the given order.
      col_prep_ops: A list of TensorFlow ops, to be run before the beginning of
        each column sweep (and during initialization), in the given order.
      row_train_op: A TensorFlow op to be run during row sweeps.
      col_train_op: A TensorFlow op to be run during column sweeps.
      switch_op: A TensorFlow op to be run before each sweep.
    """
    self._is_row_sweep_var = is_row_sweep_var
    self._is_sweep_done_var = is_sweep_done_var
    self._init_op = init_op
    self._row_prep_ops = row_prep_ops
    self._col_prep_ops = col_prep_ops
    self._row_train_op = row_train_op
    self._col_train_op = col_train_op
    self._switch_op = switch_op
    # Boolean variable that determines whether the init_op has been run.
    self._is_initialized = False

  def before_run(self, run_context):
    """Runs the appropriate prep ops, and requests running update ops.

    On the first call, `init_op` is run. If the previous step marked the
    current sweep as done, `switch_op` is run before the sweep direction is
    read back. The prep ops of the current direction are run on the first call
    and whenever a new sweep starts.

    Args:
      run_context: A `SessionRunContext` for the upcoming step.

    Returns:
      A `SessionRunArgs` fetching the train op of the current sweep direction.
    """
    sess = run_context.session
    is_sweep_done = sess.run(self._is_sweep_done_var)
    if not self._is_initialized:
      logging.info("SweepHook running init op.")
      sess.run(self._init_op)
    if is_sweep_done:
      logging.info("SweepHook starting the next sweep.")
      sess.run(self._switch_op)
    # Read the direction only after a possible switch, so the prep ops and the
    # requested train op below match the sweep that is about to run.
    is_row_sweep = sess.run(self._is_row_sweep_var)
    if is_sweep_done or not self._is_initialized:
      logging.info("SweepHook running prep ops for the {} sweep.".format(
          "row" if is_row_sweep else "col"))
      prep_ops = self._row_prep_ops if is_row_sweep else self._col_prep_ops
      for prep_op in prep_ops:
        sess.run(prep_op)
    self._is_initialized = True
    logging.info("Next fit step starting.")
    return session_run_hook.SessionRunArgs(
        fetches=[self._row_train_op if is_row_sweep else self._col_train_op])


class _IncrementGlobalStepHook(session_run_hook.SessionRunHook):
  """Hook that increments the global step."""

  def __init__(self):
    global_step = training_util.get_global_step()
    # `get_global_step` returns None when the graph has no global step. Compare
    # against None explicitly: putting a variable in a boolean context is not
    # a presence check, and for some variable types it may raise in graph mode.
    if global_step is not None:
      self._global_step_incr_op = state_ops.assign_add(
          global_step, 1, name="global_step_incr").op
    else:
      self._global_step_incr_op = None

  def before_run(self, run_context):
    """Increments the global step (if one exists) before each step."""
    if self._global_step_incr_op is not None:
      run_context.session.run(self._global_step_incr_op)


class _StopAtSweepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a given sweep."""

  def __init__(self, last_sweep):
    """Initializes a `StopAtSweepHook`.

    This hook requests stop at a given sweep. Relies on the tensor named
    COMPLETED_SWEEPS in the default graph.

    Args:
      last_sweep: Integer, number of the last sweep to run.
    """
    self._last_sweep = last_sweep
    # Looked up from the default graph in `begin()`.
    self._completed_sweeps_var = None

  def begin(self):
    """Looks up the completed-sweeps counter tensor in the default graph.

    Raises:
      RuntimeError: If no tensor named `COMPLETED_SWEEPS:0` exists.
    """
    try:
      self._completed_sweeps_var = ops.get_default_graph().get_tensor_by_name(
          WALSMatrixFactorization.COMPLETED_SWEEPS + ":0")
    except KeyError:
      raise RuntimeError(WALSMatrixFactorization.COMPLETED_SWEEPS +
                         " counter should be created to use StopAtSweepHook.")

  def before_run(self, run_context):
    """Requests the completed-sweeps counter as an extra fetch of the step."""
    return session_run_hook.SessionRunArgs(self._completed_sweeps_var)

  def after_run(self, run_context, run_values):
    """Requests stop once the fetched counter reaches `last_sweep`."""
    completed_sweeps = run_values.results
    if completed_sweeps >= self._last_sweep:
      run_context.request_stop()


def _wals_factorization_model_function(features, labels, mode, params):
  """Model function for the WALSFactorization estimator.

  Args:
    features: Dictionary of features. See WALSMatrixFactorization.
    labels: Must be None.
    mode: A model_fn.ModeKeys object.
    params: Dictionary of parameters containing arguments passed to the
      WALSMatrixFactorization constructor.

  Returns:
    A ModelFnOps object.

  Raises:
    ValueError: If `labels` is not None, or if `mode` is not recognized.
  """
  # Validate explicitly instead of using `assert`, which is stripped when
  # Python runs with optimizations enabled (-O).
  if labels is not None:
    raise ValueError("labels must be None, but got: %s" % str(labels))
  use_factors_weights_cache = (params["use_factors_weights_cache_for_training"]
                               and mode == model_fn.ModeKeys.TRAIN)
  use_gramian_cache = (params["use_gramian_cache_for_training"] and
                       mode == model_fn.ModeKeys.TRAIN)
  max_sweeps = params["max_sweeps"]
  model = factorization_ops.WALSModel(
      params["num_rows"],
      params["num_cols"],
      params["embedding_dimension"],
      unobserved_weight=params["unobserved_weight"],
      regularization=params["regularization_coeff"],
      row_init=params["row_init"],
      col_init=params["col_init"],
      num_row_shards=params["num_row_shards"],
      num_col_shards=params["num_col_shards"],
      row_weights=params["row_weights"],
      col_weights=params["col_weights"],
      use_factors_weights_cache=use_factors_weights_cache,
      use_gramian_cache=use_gramian_cache)

  # Get input rows and cols. We either update rows or columns depending on
  # the value of row_sweep, which is maintained using a session hook.
  input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
  input_cols = features[WALSMatrixFactorization.INPUT_COLS]

  # TRAIN mode:
  if mode == model_fn.ModeKeys.TRAIN:
    # Training consists of the following ops (controlled using a SweepHook).
    # Before a row sweep:
    #   row_update_prep_gramian_op
    #   initialize_row_update_op
    # During a row sweep:
    #   update_row_factors_op
    # Before a col sweep:
    #   col_update_prep_gramian_op
    #   initialize_col_update_op
    # During a col sweep:
    #   update_col_factors_op

    is_row_sweep_var = variable_scope.variable(
        True,
        trainable=False,
        name="is_row_sweep",
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    is_sweep_done_var = variable_scope.variable(
        False,
        trainable=False,
        name="is_sweep_done",
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    completed_sweeps_var = variable_scope.variable(
        0,
        trainable=False,
        name=WALSMatrixFactorization.COMPLETED_SWEEPS,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    loss_var = variable_scope.variable(
        0.,
        trainable=False,
        name=WALSMatrixFactorization.LOSS,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    # The root weighted squared error =
    #   \\(\sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )\\)
    rwse_var = variable_scope.variable(
        0.,
        trainable=False,
        name=WALSMatrixFactorization.RWSE,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])

    summary.scalar("loss", loss_var)
    summary.scalar("root_weighted_squared_error", rwse_var)
    summary.scalar("completed_sweeps", completed_sweeps_var)

    def create_axis_ops(sp_input, num_items, update_fn, axis_name):
      """Creates book-keeping and training ops for a given axis.

      Args:
        sp_input: A SparseTensor corresponding to the row or column batch.
        num_items: An integer, the total number of items of this axis.
        update_fn: A function that takes one argument (`sp_input`), and that
        returns a tuple of
          * new_factors: A float Tensor of the factor values after update.
          * update_op: a TensorFlow op which updates the factors.
          * loss: A float Tensor, the unregularized loss.
          * reg_loss: A float Tensor, the regularization loss.
          * sum_weights: A float Tensor, the sum of factor weights.
        axis_name: A string that specifies the name of the axis.

      Returns:
        A tuple consisting of:
          * reset_processed_items_op: A TensorFlow op, to be run before the
            beginning of any sweep. It marks all items as not-processed.
          * axis_train_op: A Tensorflow op, to be run during this axis' sweeps.
      """
      processed_items_init = array_ops.fill(dims=[num_items], value=False)
      with ops.colocate_with(processed_items_init):
        processed_items = variable_scope.variable(
            processed_items_init,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES],
            trainable=False,
            name="processed_" + axis_name)
      _, update_op, loss, reg, sum_weights = update_fn(sp_input)
      # The first column of the sparse indices is the item (row) index within
      # the batch being processed.
      input_indices = sp_input.indices[:, 0]
      with ops.control_dependencies([
          update_op,
          state_ops.assign(loss_var, loss + reg),
          state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]):
        with ops.colocate_with(processed_items):
          update_processed_items = state_ops.scatter_update(
              processed_items,
              input_indices,
              array_ops.ones_like(input_indices, dtype=dtypes.bool),
              name="update_processed_{}_indices".format(axis_name))
      # The sweep is done once every item of this axis has been processed.
      with ops.control_dependencies([update_processed_items]):
        is_sweep_done = math_ops.reduce_all(processed_items)
        axis_train_op = control_flow_ops.group(
            state_ops.assign(is_sweep_done_var, is_sweep_done),
            state_ops.assign_add(
                completed_sweeps_var,
                math_ops.cast(is_sweep_done, dtypes.int32)),
            name="{}_sweep_train_op".format(axis_name))
      return processed_items.initializer, axis_train_op

    reset_processed_rows_op, row_train_op = create_axis_ops(
        input_rows,
        params["num_rows"],
        lambda x: model.update_row_factors(sp_input=x, transpose_input=False),
        "rows")
    reset_processed_cols_op, col_train_op = create_axis_ops(
        input_cols,
        params["num_cols"],
        lambda x: model.update_col_factors(sp_input=x, transpose_input=True),
        "cols")
    # Flips the sweep direction and marks all rows and columns unprocessed.
    switch_op = control_flow_ops.group(
        state_ops.assign(
            is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)),
        reset_processed_rows_op,
        reset_processed_cols_op,
        name="sweep_switch_op")
    row_prep_ops = [
        model.row_update_prep_gramian_op, model.initialize_row_update_op]
    col_prep_ops = [
        model.col_update_prep_gramian_op, model.initialize_col_update_op]
    init_op = model.worker_init
    sweep_hook = _SweepHook(
        is_row_sweep_var, is_sweep_done_var, init_op,
        row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op)
    global_step_hook = _IncrementGlobalStepHook()
    training_hooks = [sweep_hook, global_step_hook]
    if max_sweeps is not None:
      training_hooks.append(_StopAtSweepHook(max_sweeps))

    # The actual training work is requested by the hooks; the estimator's
    # train_op is a no-op.
    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.TRAIN,
        predictions={},
        loss=loss_var,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=training_hooks)

  # INFER mode
  elif mode == model_fn.ModeKeys.INFER:
    projection_weights = features.get(
        WALSMatrixFactorization.PROJECTION_WEIGHTS)

    def get_row_projection():
      return model.project_row_factors(
          sp_input=input_rows,
          projection_weights=projection_weights,
          transpose_input=False)

    def get_col_projection():
      return model.project_col_factors(
          sp_input=input_cols,
          projection_weights=projection_weights,
          transpose_input=True)

    predictions = {
        WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond(
            features[WALSMatrixFactorization.PROJECT_ROW],
            get_row_projection,
            get_col_projection)
    }

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.INFER,
        predictions=predictions,
        loss=None,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=[])

  # EVAL mode
  elif mode == model_fn.ModeKeys.EVAL:
    # Only the loss terms returned by update_*_factors are used here; the
    # update op itself is discarded and never run.
    def get_row_loss():
      _, _, loss, reg, _ = model.update_row_factors(
          sp_input=input_rows, transpose_input=False)
      return loss + reg

    def get_col_loss():
      _, _, loss, reg, _ = model.update_col_factors(
          sp_input=input_cols, transpose_input=True)
      return loss + reg

    loss = control_flow_ops.cond(
        features[WALSMatrixFactorization.PROJECT_ROW],
        get_row_loss,
        get_col_loss)
    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.EVAL,
        predictions={},
        loss=loss,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=[])

  else:
    raise ValueError("mode=%s is not recognized." % str(mode))


class WALSMatrixFactorization(estimator.Estimator):
  """An Estimator for Weighted Matrix Factorization, using the WALS method.

  WALS (Weighted Alternating Least Squares) is an algorithm for weighted matrix
  factorization. It computes a low-rank approximation of a given sparse (n x m)
  matrix `A`, by a product of two matrices, `U * V^T`, where `U` is a (n x k)
  matrix and `V` is a (m x k) matrix. Here k is the rank of the approximation,
  also called the embedding dimension. We refer to `U` as the row factors, and
  `V` as the column factors.
  See tensorflow/contrib/factorization/g3doc/wals.md for the precise problem
  formulation.

  The training proceeds in sweeps: during a row sweep, we fix `V` and solve for
  `U`. During a column sweep, we fix `U` and solve for `V`. Each one of these
  problems is an unconstrained quadratic minimization problem and can be solved
  exactly (it can also be solved in mini-batches, since the solution decouples
  across rows of each matrix).
  The alternating between sweeps is achieved by using a hook during training,
  which is responsible for keeping track of the sweeps and running preparation
  ops at the beginning of each sweep. It also updates the global_step variable,
  which keeps track of the number of batches processed since the beginning of
  training.
  The current implementation assumes that the training is run on a single
  machine, and the constructor raises a `ValueError` if
  `config.num_worker_replicas` is greater than one.
  Training is done by calling `self.fit(input_fn=input_fn)`, where `input_fn`
  provides two tensors: one for rows of the input matrix, and one for rows of
  the transposed input matrix (i.e. columns of the original matrix). Note that
  during a row sweep, only row batches are processed (ignoring column batches)
  and vice-versa.
  Also note that every row (respectively every column) of the input matrix
  must be processed at least once for the sweep to be considered complete. In
  particular, training will not make progress if some rows are not generated by
  the `input_fn`.

  For prediction, given a new set of input rows `A'`, we compute a corresponding
  set of row factors `U'`, such that `U' * V^T` is a good approximation of `A'`.
  We call this operation a row projection. A similar operation is defined for
  columns. Projection is done by calling
  `self.get_projections(input_fn=input_fn)`, where `input_fn` satisfies the
  constraints given below.

  The input functions must satisfy the following constraints: Calling `input_fn`
  must return a tuple `(features, labels)` where `labels` is None, and
  `features` is a dict containing the following keys:

  TRAIN:
    * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
      Rows of the input matrix to process (or to project).
    * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
      Columns of the input matrix to process (or to project), transposed.

  INFER:
    * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
      Rows to project.
    * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
      Columns to project.
    * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
      the rows or columns.
    * `WALSMatrixFactorization.PROJECTION_WEIGHTS` (Optional): float32 Tensor
      (vector). The weights to use in the projection.

  EVAL:
    * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
      Rows on which to compute the loss.
    * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
      Columns on which to compute the loss.
    * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to compute
      the loss on the rows or on the columns.
  """
  # Keys to be used in model_fn
  # Features keys
  INPUT_ROWS = "input_rows"
  INPUT_COLS = "input_cols"
  PROJECT_ROW = "project_row"
  PROJECTION_WEIGHTS = "projection_weights"
  # Predictions key
  PROJECTION_RESULT = "projection"
  # Name of the completed_sweeps variable
  COMPLETED_SWEEPS = "completed_sweeps"
  # Name of the loss variable
  LOSS = "WALS_loss"
  # Name of the Root Weighted Squared Error variable
  RWSE = "WALS_RWSE"

  def __init__(self,
               num_rows,
               num_cols,
               embedding_dimension,
               unobserved_weight=0.1,
               regularization_coeff=None,
               row_init="random",
               col_init="random",
               num_row_shards=1,
               num_col_shards=1,
               row_weights=1,
               col_weights=1,
               use_factors_weights_cache_for_training=True,
               use_gramian_cache_for_training=True,
               max_sweeps=None,
               model_dir=None,
               config=None):
    r"""Creates a model for matrix factorization using the WALS method.

    Args:
      num_rows: Total number of rows for input matrix.
      num_cols: Total number of cols for input matrix.
      embedding_dimension: Dimension to use for the factors.
      unobserved_weight: Weight of the unobserved entries of matrix.
      regularization_coeff: Weight of the L2 regularization term. Defaults to
        None, in which case the problem is not regularized.
      row_init: Initializer for row factor. Must be either:
        - A tensor: The row factor matrix is initialized to this tensor,
        - A numpy constant,
        - "random": The rows are initialized using a normal distribution.
      col_init: Initializer for column factor. See row_init.
      num_row_shards: Number of shards to use for the row factors.
      num_col_shards: Number of shards to use for the column factors.
      row_weights: Must be in one of the following three formats:
        - None: In this case, the weight of every entry is the unobserved_weight
          and the problem simplifies to ALS. Note that, in this case,
          col_weights must also be set to "None".
        - List of lists of non-negative scalars, of the form
          \\([[w_0, w_1, ...], [w_k, ... ], [...]]\\),
          where the number of inner lists is equal to the number of row factor
          shards and the elements in each inner list are the weights for the
          rows of that shard. In this case,
          \\(w_ij = unobserved_weight + row_weights[i] * col_weights[j]\\).
        - A non-negative scalar: This value is used for all row weights.
        Note that it is allowed to have row_weights as a list and col_weights
        as a scalar, or vice-versa.
      col_weights: See row_weights.
      use_factors_weights_cache_for_training: Boolean, whether the factors and
        weights will be cached on the workers before the updates start, during
        training. Defaults to True.
        Note that caching is disabled during prediction.
      use_gramian_cache_for_training: Boolean, whether the Gramians will be
        cached on the workers before the updates start, during training.
        Defaults to True. Note that caching is disabled during prediction.
      max_sweeps: integer, optional. Specifies the number of sweeps for which
        to train the model, where a sweep is defined as a full update of all the
        row factors (resp. column factors).
        If `steps` or `max_steps` is also specified in model.fit(), training
        stops when either of the steps condition or sweeps condition is met.
      model_dir: The directory to save the model results and log files.
      config: A Configuration object. See Estimator.

    Raises:
      ValueError: If config.num_worker_replicas is strictly greater than one.
        The current implementation only supports running on a single worker.
    """
    # TODO(walidk): Support power-law based weight computation.
    # TODO(walidk): Add factor lookup by indices, with caching.
    # TODO(walidk): Support caching during prediction.
    # TODO(walidk): Provide input pipelines that handle missing rows.

    params = {
        "num_rows":
            num_rows,
        "num_cols":
            num_cols,
        "embedding_dimension":
            embedding_dimension,
        "unobserved_weight":
            unobserved_weight,
        "regularization_coeff":
            regularization_coeff,
        "row_init":
            row_init,
        "col_init":
            col_init,
        "num_row_shards":
            num_row_shards,
        "num_col_shards":
            num_col_shards,
        "row_weights":
            row_weights,
        "col_weights":
            col_weights,
        "max_sweeps":
            max_sweeps,
        "use_factors_weights_cache_for_training":
            use_factors_weights_cache_for_training,
        "use_gramian_cache_for_training":
            use_gramian_cache_for_training
    }
    # Checkpoint variable names of the factor shards, used by
    # get_row_factors() / get_col_factors().
    self._row_factors_names = [
        "row_factors_shard_%d" % i for i in range(num_row_shards)
    ]
    self._col_factors_names = [
        "col_factors_shard_%d" % i for i in range(num_col_shards)
    ]

    super(WALSMatrixFactorization, self).__init__(
        model_fn=_wals_factorization_model_function,
        params=params,
        model_dir=model_dir,
        config=config)

    if self._config is not None and self._config.num_worker_replicas > 1:
      raise ValueError("WALSMatrixFactorization must be run on a single worker "
                       "replica.")

  def get_row_factors(self):
    """Returns the row factors of the model, loading them from checkpoint.

    Should only be run after training.

    Returns:
      A list of the row factors of the model.
    """
    return [self.get_variable_value(name) for name in self._row_factors_names]

  def get_col_factors(self):
    """Returns the column factors of the model, loading them from checkpoint.

    Should only be run after training.

    Returns:
      A list of the column factors of the model.
    """
    return [self.get_variable_value(name) for name in self._col_factors_names]

  def get_projections(self, input_fn):
    """Computes the projections of the rows or columns given in input_fn.

    Runs predict() with the given input_fn, and returns the results. Should only
    be run after training.

    Args:
      input_fn: Input function which specifies the rows or columns to project.
    Returns:
      A generator of the projected factors.
    """
    return (result[WALSMatrixFactorization.PROJECTION_RESULT]
            for result in self.predict(input_fn=input_fn))