1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base Estimator class (deprecated). 16 17This module and all its submodules are deprecated. See 18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 19for migration instructions. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import abc 27import collections 28import copy 29import os 30import tempfile 31 32import numpy as np 33import six 34 35from google.protobuf import message 36from tensorflow.contrib import layers 37from tensorflow.contrib.framework import deprecated 38from tensorflow.contrib.framework import deprecated_args 39from tensorflow.contrib.framework import list_variables 40from tensorflow.contrib.framework import load_variable 41from tensorflow.contrib.learn.python.learn import evaluable 42from tensorflow.contrib.learn.python.learn import metric_spec 43from tensorflow.contrib.learn.python.learn import monitors as monitor_lib 44from tensorflow.contrib.learn.python.learn import trainable 45from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn 46from tensorflow.contrib.learn.python.learn.estimators import constants 47from tensorflow.contrib.learn.python.learn.estimators import metric_key 48from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib 49from tensorflow.contrib.learn.python.learn.estimators import run_config 50from tensorflow.contrib.learn.python.learn.estimators import tensor_signature 51from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError 52from tensorflow.contrib.learn.python.learn.learn_io import data_feeder 53from tensorflow.contrib.learn.python.learn.utils import export 54from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils 55from tensorflow.contrib.meta_graph_transform import meta_graph_transform 56from tensorflow.contrib.training.python.training import evaluation 57from tensorflow.core.framework import summary_pb2 58from tensorflow.core.protobuf import config_pb2 59from tensorflow.python.client import session as tf_session 60from tensorflow.python.framework import ops 61from tensorflow.python.framework import random_seed 62from tensorflow.python.framework import sparse_tensor 63from tensorflow.python.framework import tensor_util 64from tensorflow.python.ops import control_flow_ops 65from tensorflow.python.ops import lookup_ops 66from tensorflow.python.ops import metrics as metrics_lib 67from tensorflow.python.ops import resources 68from tensorflow.python.ops import variables 69from tensorflow.python.platform import gfile 70from tensorflow.python.platform import tf_logging as logging 71from tensorflow.python.saved_model import builder as saved_model_builder 72from tensorflow.python.saved_model import tag_constants 73from tensorflow.python.summary import summary as core_summary 74from tensorflow.python.training import basic_session_run_hooks 75from tensorflow.python.training import checkpoint_management 76from tensorflow.python.training import device_setter 77from tensorflow.python.training import monitored_session 78from tensorflow.python.training import saver 79from tensorflow.python.training import training_util 80from tensorflow.python.util import compat 81from tensorflow.python.util import tf_decorator 82from tensorflow.python.util import tf_inspect 83 84AS_ITERABLE_DATE = '2016-09-15' 85AS_ITERABLE_INSTRUCTIONS = ( 86 'The default behavior of predict() is changing. The default value for\n' 87 'as_iterable will change to True, and then the flag will be removed\n' 88 'altogether. The behavior of this flag is described below.') 89SCIKIT_DECOUPLE_DATE = '2016-12-01' 90SCIKIT_DECOUPLE_INSTRUCTIONS = ( 91 'Estimator is decoupled from Scikit Learn interface by moving into\n' 92 'separate class SKCompat. Arguments x, y and batch_size are only\n' 93 'available in the SKCompat class, Estimator will only accept input_fn.\n' 94 'Example conversion:\n' 95 ' est = Estimator(...) -> est = SKCompat(Estimator(...))') 96 97 98def _verify_input_args(x, y, input_fn, feed_fn, batch_size): 99 """Verifies validity of co-existence of input arguments.""" 100 if input_fn is None: 101 if x is None: 102 raise ValueError('Either x or input_fn must be provided.') 103 104 if tensor_util.is_tensor(x) or y is not None and tensor_util.is_tensor(y): 105 raise ValueError('Inputs cannot be tensors. Please provide input_fn.') 106 107 if feed_fn is not None: 108 raise ValueError('Can not provide both feed_fn and x or y.') 109 else: 110 if (x is not None) or (y is not None): 111 raise ValueError('Can not provide both input_fn and x or y.') 112 if batch_size is not None: 113 raise ValueError('Can not provide both input_fn and batch_size.') 114 115 116def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1): 117 """Make inputs into input and feed functions. 118 119 Args: 120 x: Numpy, Pandas or Dask matrix or iterable. 121 y: Numpy, Pandas or Dask matrix or iterable. 122 input_fn: Pre-defined input function for training data. 123 feed_fn: Pre-defined data feeder function. 124 batch_size: Size to split data into parts. Must be >= 1. 125 shuffle: Whether to shuffle the inputs. 126 epochs: Number of epochs to run. 127 128 Returns: 129 Data input and feeder function based on training data. 130 131 Raises: 132 ValueError: Only one of `(x & y)` or `input_fn` must be provided. 133 """ 134 _verify_input_args(x, y, input_fn, feed_fn, batch_size) 135 if input_fn is not None: 136 return input_fn, feed_fn 137 df = data_feeder.setup_train_data_feeder( 138 x, 139 y, 140 n_classes=None, 141 batch_size=batch_size, 142 shuffle=shuffle, 143 epochs=epochs) 144 return df.input_builder, df.get_feed_dict_fn() 145 146 147@deprecated(None, 'Please specify feature columns explicitly.') 148def infer_real_valued_columns_from_input_fn(input_fn): 149 """Creates `FeatureColumn` objects for inputs defined by `input_fn`. 150 151 This interprets all inputs as dense, fixed-length float values. This creates 152 a local graph in which it calls `input_fn` to build the tensors, then discards 153 it. 154 155 Args: 156 input_fn: Input function returning a tuple of: 157 features - Dictionary of string feature name to `Tensor` or `Tensor`. 158 labels - `Tensor` of label values. 159 160 Returns: 161 List of `FeatureColumn` objects. 162 """ 163 with ops.Graph().as_default(): 164 features, _ = input_fn() 165 return layers.infer_real_valued_columns(features) 166 167 168@deprecated(None, 'Please specify feature columns explicitly.') 169def infer_real_valued_columns_from_input(x): 170 """Creates `FeatureColumn` objects for inputs defined by input `x`. 171 172 This interprets all inputs as dense, fixed-length float values. 173 174 Args: 175 x: Real-valued matrix of shape [n_samples, n_features...]. Can be 176 iterator that returns arrays of features. 177 178 Returns: 179 List of `FeatureColumn` objects. 180 """ 181 input_fn, _ = _get_input_fn( 182 x=x, y=None, input_fn=None, feed_fn=None, batch_size=None) 183 return infer_real_valued_columns_from_input_fn(input_fn) 184 185 186def _model_fn_args(fn): 187 """Get argument names for function-like object. 188 189 Args: 190 fn: Function, or function-like object (e.g., result of `functools.partial`). 191 192 Returns: 193 `tuple` of string argument names. 194 195 Raises: 196 ValueError: if partial function has positionally bound arguments 197 """ 198 _, fn = tf_decorator.unwrap(fn) 199 if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'): 200 # Handle functools.partial and similar objects. 201 return tuple([ 202 arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):] 203 if arg not in set(fn.keywords.keys()) 204 ]) 205 # Handle function. 206 return tuple(tf_inspect.getargspec(fn).args) 207 208 209def _get_replica_device_setter(config): 210 """Creates a replica device setter if required. 211 212 Args: 213 config: A RunConfig instance. 214 215 Returns: 216 A replica device setter, or None. 217 """ 218 ps_ops = [ 219 'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable', 220 'MutableHashTableV2', 'MutableHashTableOfTensors', 221 'MutableHashTableOfTensorsV2', 'MutableDenseHashTable', 222 'MutableDenseHashTableV2', 'VarHandleOp' 223 ] 224 225 if config.task_type: 226 worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id) 227 else: 228 worker_device = '/job:worker' 229 230 if config.num_ps_replicas > 0: 231 return device_setter.replica_device_setter( 232 ps_tasks=config.num_ps_replicas, 233 worker_device=worker_device, 234 merge_devices=True, 235 ps_ops=ps_ops, 236 cluster=config.cluster_spec) 237 else: 238 return None 239 240 241def _make_metrics_ops(metrics, features, labels, predictions): 242 """Add metrics based on `features`, `labels`, and `predictions`. 243 244 `metrics` contains a specification for how to run metrics. It is a dict 245 mapping friendly names to either `MetricSpec` objects, or directly to a metric 246 function (assuming that `predictions` and `labels` are single tensors), or to 247 `(pred_name, metric)` `tuple`, which passes `predictions[pred_name]` and 248 `labels` to `metric` (assuming `labels` is a single tensor). 249 250 Users are encouraged to use `MetricSpec` objects, which are more flexible and 251 cleaner. They also lead to clearer errors. 252 253 Args: 254 metrics: A dict mapping names to metrics specification, for example 255 `MetricSpec` objects. 256 features: A dict of tensors returned from an input_fn as features/inputs. 257 labels: A single tensor or a dict of tensors returned from an input_fn as 258 labels. 259 predictions: A single tensor or a dict of tensors output from a model as 260 predictions. 261 262 Returns: 263 A dict mapping the friendly given in `metrics` to the result of calling the 264 given metric function. 265 266 Raises: 267 ValueError: If metrics specifications do not work with the type of 268 `features`, `labels`, or `predictions` provided. Mostly, a dict is given 269 but no pred_name specified. 270 """ 271 metrics = metrics or {} 272 273 # If labels is a dict with a single key, unpack into a single tensor. 274 labels_tensor_or_dict = labels 275 if isinstance(labels, dict) and len(labels) == 1: 276 labels_tensor_or_dict = labels[list(labels.keys())[0]] 277 278 result = {} 279 # Iterate in lexicographic order, so the graph is identical among runs. 280 for name, metric in sorted(six.iteritems(metrics)): 281 if isinstance(metric, metric_spec.MetricSpec): 282 result[name] = metric.create_metric_ops(features, labels, predictions) 283 continue 284 285 # TODO(b/31229024): Remove the rest of this loop 286 logging.warning('Please specify metrics using MetricSpec. Using bare ' 287 'functions or (key, fn) tuples is deprecated and support ' 288 'for it will be removed on Oct 1, 2016.') 289 290 if isinstance(name, tuple): 291 # Multi-head metrics. 292 if len(name) != 2: 293 raise ValueError('Invalid metric for {}. It returned a tuple with ' 294 'len {}, expected 2.'.format(name, len(name))) 295 if not isinstance(predictions, dict): 296 raise ValueError('Metrics passed provide (name, prediction), ' 297 'but predictions are not dict. ' 298 'Metrics: %s, Predictions: %s.' % (metrics, 299 predictions)) 300 # Here are two options: labels are single Tensor or a dict. 301 if isinstance(labels, dict) and name[1] in labels: 302 # If labels are dict and the prediction name is in it, apply metric. 303 result[name[0]] = metric(predictions[name[1]], labels[name[1]]) 304 else: 305 # Otherwise pass the labels to the metric. 306 result[name[0]] = metric(predictions[name[1]], labels_tensor_or_dict) 307 else: 308 # Single head metrics. 309 if isinstance(predictions, dict): 310 raise ValueError('Metrics passed provide only name, no prediction, ' 311 'but predictions are dict. ' 312 'Metrics: %s, Labels: %s.' % (metrics, 313 labels_tensor_or_dict)) 314 result[name] = metric(predictions, labels_tensor_or_dict) 315 return result 316 317 318def _dict_to_str(dictionary): 319 """Get a `str` representation of a `dict`. 320 321 Args: 322 dictionary: The `dict` to be represented as `str`. 323 324 Returns: 325 A `str` representing the `dictionary`. 326 """ 327 results = [] 328 for k, v in sorted(dictionary.items()): 329 if isinstance(v, float) or isinstance(v, np.float32) or isinstance( 330 v, int) or isinstance(v, np.int64) or isinstance(v, np.int32): 331 results.append('%s = %s' % (k, v)) 332 else: 333 results.append('Type of %s = %s' % (k, type(v))) 334 335 return ', '.join(results) 336 337 338def _write_dict_to_summary(output_dir, dictionary, current_global_step): 339 """Writes a `dict` into summary file in given output directory. 340 341 Args: 342 output_dir: `str`, directory to write the summary file in. 343 dictionary: the `dict` to be written to summary file. 344 current_global_step: `int`, the current global step. 345 """ 346 logging.info('Saving dict for global step %d: %s', current_global_step, 347 _dict_to_str(dictionary)) 348 summary_writer = core_summary.FileWriterCache.get(output_dir) 349 summary_proto = summary_pb2.Summary() 350 for key in dictionary: 351 if dictionary[key] is None: 352 continue 353 if key == 'global_step': 354 continue 355 if (isinstance(dictionary[key], np.float32) or 356 isinstance(dictionary[key], float)): 357 summary_proto.value.add(tag=key, simple_value=float(dictionary[key])) 358 elif (isinstance(dictionary[key], np.int64) or 359 isinstance(dictionary[key], np.int32) or 360 isinstance(dictionary[key], int)): 361 summary_proto.value.add(tag=key, simple_value=int(dictionary[key])) 362 elif isinstance(dictionary[key], six.string_types): 363 try: 364 summ = summary_pb2.Summary.FromString(dictionary[key]) 365 for i, _ in enumerate(summ.value): 366 summ.value[i].tag = key 367 summary_proto.value.extend(summ.value) 368 except message.DecodeError: 369 logging.warn('Skipping summary for %s, cannot parse string to Summary.', 370 key) 371 continue 372 elif isinstance(dictionary[key], np.ndarray): 373 value = summary_proto.value.add() 374 value.tag = key 375 value.node_name = key 376 tensor_proto = tensor_util.make_tensor_proto(dictionary[key]) 377 value.tensor.CopyFrom(tensor_proto) 378 logging.info( 379 'Summary for np.ndarray is not visible in Tensorboard by default. ' 380 'Consider using a Tensorboard plugin for visualization (see ' 381 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md' 382 ' for more information).') 383 else: 384 logging.warn( 385 'Skipping summary for %s, must be a float, np.float32, np.int64, ' 386 'np.int32 or int or np.ndarray or a serialized string of Summary.', 387 key) 388 summary_writer.add_summary(summary_proto, current_global_step) 389 summary_writer.flush() 390 391 392GraphRewriteSpec = collections.namedtuple('GraphRewriteSpec', 393 ['tags', 'transforms']) 394 395 396class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, 397 trainable.Trainable): 398 """Abstract BaseEstimator class to train and evaluate TensorFlow models. 399 400 THIS CLASS IS DEPRECATED. See 401 [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 402 for general migration instructions. 403 404 Users should not instantiate or subclass this class. Instead, use an 405 `Estimator`. 406 """ 407 408 # Note that for Google users, this is overridden with 409 # learn_runner.EstimatorConfig. 410 # TODO(wicke): Remove this once launcher takes over config functionality 411 _Config = run_config.RunConfig # pylint: disable=invalid-name 412 413 @deprecated(None, 'Please replace uses of any Estimator from tf.contrib.learn' 414 ' with an Estimator from tf.estimator.*') 415 def __init__(self, model_dir=None, config=None): 416 """Initializes a BaseEstimator instance. 417 418 Args: 419 model_dir: Directory to save model parameters, graph and etc. This can 420 also be used to load checkpoints from the directory into a estimator to 421 continue training a previously saved model. If `None`, the model_dir in 422 `config` will be used if set. If both are set, they must be same. 423 config: A RunConfig instance. 424 """ 425 # Create a run configuration. 426 if config is None: 427 self._config = BaseEstimator._Config() 428 logging.info('Using default config.') 429 else: 430 self._config = config 431 432 if self._config.session_config is None: 433 self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) 434 else: 435 self._session_config = self._config.session_config 436 437 # Model directory. 438 if (model_dir is not None) and (self._config.model_dir is not None): 439 if model_dir != self._config.model_dir: 440 # TODO(b/9965722): remove this suppression after it is no longer 441 # necessary. 442 # pylint: disable=g-doc-exception 443 raise ValueError( 444 'model_dir are set both in constructor and RunConfig, but with ' 445 "different values. In constructor: '{}', in RunConfig: " 446 "'{}' ".format(model_dir, self._config.model_dir)) 447 # pylint: enable=g-doc-exception 448 449 self._model_dir = model_dir or self._config.model_dir 450 if self._model_dir is None: 451 self._model_dir = tempfile.mkdtemp() 452 logging.warning('Using temporary folder as model directory: %s', 453 self._model_dir) 454 if self._config.model_dir is None: 455 self._config = self._config.replace(model_dir=self._model_dir) 456 logging.info('Using config: %s', str(vars(self._config))) 457 458 # Set device function depending if there are replicas or not. 459 self._device_fn = _get_replica_device_setter(self._config) 460 461 # Features and labels TensorSignature objects. 462 # TODO(wicke): Rename these to something more descriptive 463 self._features_info = None 464 self._labels_info = None 465 466 self._graph = None 467 468 @property 469 def config(self): 470 # TODO(wicke): make RunConfig immutable, and then return it without a copy. 471 return copy.deepcopy(self._config) 472 473 @property 474 def model_fn(self): 475 """Returns the model_fn which is bound to self.params. 476 477 Returns: 478 The model_fn with the following signature: 479 `def model_fn(features, labels, mode, metrics)` 480 """ 481 482 def public_model_fn(features, labels, mode, config): 483 return self._call_model_fn(features, labels, mode, config=config) 484 485 return public_model_fn 486 487 @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 488 ('x', None), ('y', None), ('batch_size', None)) 489 def fit(self, 490 x=None, 491 y=None, 492 input_fn=None, 493 steps=None, 494 batch_size=None, 495 monitors=None, 496 max_steps=None): 497 # pylint: disable=g-doc-args,g-doc-return-or-yield 498 """See `Trainable`. 499 500 Raises: 501 ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`. 502 ValueError: If both `steps` and `max_steps` are not `None`. 503 """ 504 if (steps is not None) and (max_steps is not None): 505 raise ValueError('Can not provide both steps and max_steps.') 506 _verify_input_args(x, y, input_fn, None, batch_size) 507 if x is not None: 508 SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors) 509 return self 510 511 if max_steps is not None: 512 try: 513 start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP) 514 if max_steps <= start_step: 515 logging.info('Skipping training since max_steps has already saved.') 516 return self 517 except: # pylint: disable=bare-except 518 pass 519 520 hooks = monitor_lib.replace_monitors_with_hooks(monitors, self) 521 if steps is not None or max_steps is not None: 522 hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps)) 523 524 loss = self._train_model(input_fn=input_fn, hooks=hooks) 525 logging.info('Loss for final step: %s.', loss) 526 return self 527 528 @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 529 ('x', None), ('y', None), ('batch_size', None)) 530 def partial_fit(self, 531 x=None, 532 y=None, 533 input_fn=None, 534 steps=1, 535 batch_size=None, 536 monitors=None): 537 """Incremental fit on a batch of samples. 538 539 This method is expected to be called several times consecutively 540 on different or the same chunks of the dataset. This either can 541 implement iterative training or out-of-core/online training. 542 543 This is especially useful when the whole dataset is too big to 544 fit in memory at the same time. Or when model is taking long time 545 to converge, and you want to split up training into subparts. 546 547 Args: 548 x: Matrix of shape [n_samples, n_features...]. Can be iterator that 549 returns arrays of features. The training input samples for fitting the 550 model. If set, `input_fn` must be `None`. 551 y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be 552 iterator that returns array of labels. The training label values 553 (class labels in classification, real numbers in regression). If set, 554 `input_fn` must be `None`. 555 input_fn: Input function. If set, `x`, `y`, and `batch_size` must be 556 `None`. 557 steps: Number of steps for which to train model. If `None`, train forever. 558 batch_size: minibatch size to use on the input, defaults to first 559 dimension of `x`. Must be `None` if `input_fn` is provided. 560 monitors: List of `BaseMonitor` subclass instances. Used for callbacks 561 inside the training loop. 562 563 Returns: 564 `self`, for chaining. 565 566 Raises: 567 ValueError: If at least one of `x` and `y` is provided, and `input_fn` is 568 provided. 569 """ 570 logging.warning('The current implementation of partial_fit is not optimized' 571 ' for use in a loop. Consider using fit() instead.') 572 return self.fit( 573 x=x, 574 y=y, 575 input_fn=input_fn, 576 steps=steps, 577 batch_size=batch_size, 578 monitors=monitors) 579 580 @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 581 ('x', None), ('y', None), ('batch_size', None)) 582 def evaluate(self, 583 x=None, 584 y=None, 585 input_fn=None, 586 feed_fn=None, 587 batch_size=None, 588 steps=None, 589 metrics=None, 590 name=None, 591 checkpoint_path=None, 592 hooks=None, 593 log_progress=True): 594 # pylint: disable=g-doc-args,g-doc-return-or-yield 595 """See `Evaluable`. 596 597 Raises: 598 ValueError: If at least one of `x` or `y` is provided, and at least one of 599 `input_fn` or `feed_fn` is provided. 600 Or if `metrics` is not `None` or `dict`. 601 """ 602 _verify_input_args(x, y, input_fn, feed_fn, batch_size) 603 if x is not None: 604 return SKCompat(self).score(x, y, batch_size, steps, metrics, name) 605 606 if metrics is not None and not isinstance(metrics, dict): 607 raise ValueError('Metrics argument should be None or dict. ' 608 'Got %s.' % metrics) 609 eval_results, global_step = self._evaluate_model( 610 input_fn=input_fn, 611 feed_fn=feed_fn, 612 steps=steps, 613 metrics=metrics, 614 name=name, 615 checkpoint_path=checkpoint_path, 616 hooks=hooks, 617 log_progress=log_progress) 618 619 if eval_results is not None: 620 eval_results.update({'global_step': global_step}) 621 return eval_results 622 623 @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, 624 ('x', None), ('batch_size', None), ('as_iterable', True)) 625 def predict(self, 626 x=None, 627 input_fn=None, 628 batch_size=None, 629 outputs=None, 630 as_iterable=True, 631 iterate_batches=False): 632 """Returns predictions for given features. 633 634 Args: 635 x: Matrix of shape [n_samples, n_features...]. Can be iterator that 636 returns arrays of features. The training input samples for fitting the 637 model. If set, `input_fn` must be `None`. 638 input_fn: Input function. If set, `x` and 'batch_size' must be `None`. 639 batch_size: Override default batch size. If set, 'input_fn' must be 640 'None'. 641 outputs: list of `str`, name of the output to predict. 642 If `None`, returns all. 643 as_iterable: If True, return an iterable which keeps yielding predictions 644 for each example until inputs are exhausted. Note: The inputs must 645 terminate if you want the iterable to terminate (e.g. be sure to pass 646 num_epochs=1 if you are using something like read_batch_features). 647 iterate_batches: If True, yield the whole batch at once instead of 648 decomposing the batch into individual samples. Only relevant when 649 as_iterable is True. 650 651 Returns: 652 A numpy array of predicted classes or regression values if the 653 constructor's `model_fn` returns a `Tensor` for `predictions` or a `dict` 654 of numpy arrays if `model_fn` returns a `dict`. Returns an iterable of 655 predictions if as_iterable is True. 656 657 Raises: 658 ValueError: If x and input_fn are both provided or both `None`. 659 """ 660 _verify_input_args(x, None, input_fn, None, batch_size) 661 if x is not None and not as_iterable: 662 return SKCompat(self).predict(x, batch_size) 663 664 input_fn, feed_fn = _get_input_fn(x, None, input_fn, None, batch_size) 665 return self._infer_model( 666 input_fn=input_fn, 667 feed_fn=feed_fn, 668 outputs=outputs, 669 as_iterable=as_iterable, 670 iterate_batches=iterate_batches) 671 672 def get_variable_value(self, name): 673 """Returns value of the variable given by name. 674 675 Args: 676 name: string, name of the tensor. 677 678 Returns: 679 Numpy array - value of the tensor. 680 """ 681 return load_variable(self.model_dir, name) 682 683 def get_variable_names(self): 684 """Returns list of all variable names in this model. 685 686 Returns: 687 List of names. 688 """ 689 return [name for name, _ in list_variables(self.model_dir)] 690 691 @property 692 def model_dir(self): 693 return self._model_dir 694 695 @deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.') 696 def export( 697 self, 698 export_dir, 699 input_fn=export._default_input_fn, # pylint: disable=protected-access 700 input_feature_key=None, 701 use_deprecated_input_fn=True, 702 signature_fn=None, 703 prediction_key=None, 704 default_batch_size=1, 705 exports_to_keep=None, 706 checkpoint_path=None): 707 """Exports inference graph into given dir. 708 709 Args: 710 export_dir: A string containing a directory to write the exported graph 711 and checkpoints. 712 input_fn: If `use_deprecated_input_fn` is true, then a function that given 713 `Tensor` of `Example` strings, parses it into features that are then 714 passed to the model. Otherwise, a function that takes no argument and 715 returns a tuple of (features, labels), where features is a dict of 716 string key to `Tensor` and labels is a `Tensor` that's currently not 717 used (and so can be `None`). 718 input_feature_key: Only used if `use_deprecated_input_fn` is false. String 719 key into the features dict returned by `input_fn` that corresponds to a 720 the raw `Example` strings `Tensor` that the exported model will take as 721 input. Can only be `None` if you're using a custom `signature_fn` that 722 does not use the first arg (examples). 723 use_deprecated_input_fn: Determines the signature format of `input_fn`. 724 signature_fn: Function that returns a default signature and a named 725 signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s 726 for features and `Tensor` or `dict` of `Tensor`s for predictions. 727 prediction_key: The key for a tensor in the `predictions` dict (output 728 from the `model_fn`) to use as the `predictions` input to the 729 `signature_fn`. Optional. If `None`, predictions will pass to 730 `signature_fn` without filtering. 731 default_batch_size: Default batch size of the `Example` placeholder. 732 exports_to_keep: Number of exports to keep. 733 checkpoint_path: the checkpoint path of the model to be exported. If it is 734 `None` (which is default), will use the latest checkpoint in 735 export_dir. 736 737 Returns: 738 The string path to the exported directory. NB: this functionality was 739 added ca. 2016/09/25; clients that depend on the return value may need 740 to handle the case where this function returns None because subclasses 741 are not returning a value. 742 """ 743 # pylint: disable=protected-access 744 return export._export_estimator( 745 estimator=self, 746 export_dir=export_dir, 747 signature_fn=signature_fn, 748 prediction_key=prediction_key, 749 input_fn=input_fn, 750 input_feature_key=input_feature_key, 751 use_deprecated_input_fn=use_deprecated_input_fn, 752 default_batch_size=default_batch_size, 753 exports_to_keep=exports_to_keep, 754 checkpoint_path=checkpoint_path) 755 756 @abc.abstractproperty 757 def _get_train_ops(self, features, labels): 758 """Method that builds model graph and returns trainer ops. 759 760 Expected to be overridden by sub-classes that require custom support. 761 762 Args: 763 features: `Tensor` or `dict` of `Tensor` objects. 764 labels: `Tensor` or `dict` of `Tensor` objects. 765 766 Returns: 767 A `ModelFnOps` object. 768 """ 769 pass 770 771 @abc.abstractproperty 772 def _get_predict_ops(self, features): 773 """Method that builds model graph and returns prediction ops. 774 775 Args: 776 features: `Tensor` or `dict` of `Tensor` objects. 777 778 Returns: 779 A `ModelFnOps` object. 780 """ 781 pass 782 783 def _get_eval_ops(self, features, labels, metrics): 784 """Method that builds model graph and returns evaluation ops. 785 786 Expected to be overridden by sub-classes that require custom support. 787 788 Args: 789 features: `Tensor` or `dict` of `Tensor` objects. 790 labels: `Tensor` or `dict` of `Tensor` objects. 791 metrics: Dict of metrics to run. If None, the default metric functions 792 are used; if {}, no metrics are used. Otherwise, `metrics` should map 793 friendly names for the metric to a `MetricSpec` object defining which 794 model outputs to evaluate against which labels with which metric 795 function. Metric ops should support streaming, e.g., returning 796 update_op and value tensors. See more details in 797 `../../../../metrics/python/metrics/ops/streaming_metrics.py` and 798 `../metric_spec.py`. 799 800 Returns: 801 A `ModelFnOps` object. 802 """ 803 raise NotImplementedError('_get_eval_ops not implemented in BaseEstimator') 804 805 @deprecated( 806 '2016-09-23', 807 'The signature of the input_fn accepted by export is changing to be ' 808 'consistent with what\'s used by tf.Learn Estimator\'s train/evaluate, ' 809 'which makes this function useless. This will be removed after the ' 810 'deprecation date.') 811 def _get_feature_ops_from_example(self, examples_batch): 812 """Returns feature parser for given example batch using features info. 813 814 This function requires `fit()` has been called. 815 816 Args: 817 examples_batch: batch of tf.Example 818 819 Returns: 820 features: `Tensor` or `dict` of `Tensor` objects. 821 822 Raises: 823 ValueError: If `_features_info` attribute is not available (usually 824 because `fit()` has not been called). 825 """ 826 if self._features_info is None: 827 raise ValueError('Features information missing, was fit() ever called?') 828 return tensor_signature.create_example_parser_from_signatures( 829 self._features_info, examples_batch) 830 831 def _check_inputs(self, features, labels): 832 if self._features_info is not None: 833 logging.debug('Given features: %s, required signatures: %s.', 834 str(features), str(self._features_info)) 835 if not tensor_signature.tensors_compatible(features, self._features_info): 836 raise ValueError('Features are incompatible with given information. ' 837 'Given features: %s, required signatures: %s.' % 838 (str(features), str(self._features_info))) 839 else: 840 self._features_info = tensor_signature.create_signatures(features) 841 logging.debug('Setting feature info to %s.', str(self._features_info)) 842 if labels is not None: 843 if self._labels_info is not None: 844 logging.debug('Given labels: %s, required signatures: %s.', str(labels), 845 str(self._labels_info)) 846 if not tensor_signature.tensors_compatible(labels, self._labels_info): 847 raise ValueError('Labels are incompatible with given information. ' 848 'Given labels: %s, required signatures: %s.' % 849 (str(labels), str(self._labels_info))) 850 else: 851 self._labels_info = tensor_signature.create_signatures(labels) 852 logging.debug('Setting labels info to %s', str(self._labels_info)) 853 854 def _extract_metric_update_ops(self, eval_dict): 855 """Separate update operations from metric value operations.""" 856 update_ops = [] 857 value_ops = {} 858 for name, metric_ops in six.iteritems(eval_dict): 859 if isinstance(metric_ops, (list, tuple)): 860 if len(metric_ops) == 2: 861 value_ops[name] = metric_ops[0] 862 update_ops.append(metric_ops[1]) 863 else: 864 logging.warning( 865 'Ignoring metric {}. It returned a list|tuple with len {}, ' 866 'expected 2'.format(name, len(metric_ops))) 867 value_ops[name] = metric_ops 868 else: 869 value_ops[name] = metric_ops 870 871 if update_ops: 872 update_ops = control_flow_ops.group(*update_ops) 873 else: 874 update_ops = None 875 876 return update_ops, value_ops 877 878 def _evaluate_model(self, 879 input_fn, 880 steps, 881 feed_fn=None, 882 metrics=None, 883 name='', 884 checkpoint_path=None, 885 hooks=None, 886 log_progress=True): 887 # TODO(wicke): Remove this once Model and associated code are gone. 888 if (hasattr(self._config, 'execution_mode') and 889 self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset')): 890 return None, None 891 892 # Check that model has been trained (if nothing has been set explicitly). 893 if not checkpoint_path: 894 latest_path = checkpoint_management.latest_checkpoint(self._model_dir) 895 if not latest_path: 896 raise NotFittedError( 897 "Couldn't find trained model at %s." % self._model_dir) 898 checkpoint_path = latest_path 899 900 # Setup output directory. 901 eval_dir = os.path.join(self._model_dir, 'eval' 902 if not name else 'eval_' + name) 903 904 with ops.Graph().as_default() as g: 905 random_seed.set_random_seed(self._config.tf_random_seed) 906 global_step = training_util.create_global_step(g) 907 features, labels = input_fn() 908 self._check_inputs(features, labels) 909 910 model_fn_results = self._get_eval_ops(features, labels, metrics) 911 eval_dict = model_fn_results.eval_metric_ops 912 913 update_op, eval_dict = self._extract_metric_update_ops(eval_dict) 914 915 # We need to copy the hook array as we modify it, thus [:]. 916 hooks = hooks[:] if hooks else [] 917 if feed_fn: 918 hooks.append(basic_session_run_hooks.FeedFnHook(feed_fn)) 919 if steps == 0: 920 logging.warning('evaluation steps are 0. If `input_fn` does not raise ' 921 '`OutOfRangeError`, the evaluation will never stop. ' 922 'Use steps=None if intended.') 923 if steps: 924 hooks.append( 925 evaluation.StopAfterNEvalsHook(steps, log_progress=log_progress)) 926 927 global_step_key = 'global_step' 928 while global_step_key in eval_dict: 929 global_step_key = '_' + global_step_key 930 eval_dict[global_step_key] = global_step 931 932 eval_results = evaluation.evaluate_once( 933 checkpoint_path=checkpoint_path, 934 master=self._config.evaluation_master, 935 scaffold=model_fn_results.scaffold, 936 eval_ops=update_op, 937 final_ops=eval_dict, 938 hooks=hooks, 939 config=self._session_config) 940 current_global_step = eval_results[global_step_key] 941 942 _write_dict_to_summary(eval_dir, eval_results, current_global_step) 943 944 return eval_results, current_global_step 945 946 def _get_features_from_input_fn(self, input_fn): 947 result = input_fn() 948 if isinstance(result, (list, tuple)): 949 return result[0] 950 return result 951 952 def _infer_model(self, 953 input_fn, 954 feed_fn=None, 955 outputs=None, 956 as_iterable=True, 957 iterate_batches=False): 958 # Check that model has been trained. 959 checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) 960 if not checkpoint_path: 961 raise NotFittedError( 962 "Couldn't find trained model at %s." % self._model_dir) 963 964 with ops.Graph().as_default() as g: 965 random_seed.set_random_seed(self._config.tf_random_seed) 966 training_util.create_global_step(g) 967 features = self._get_features_from_input_fn(input_fn) 968 infer_ops = self._get_predict_ops(features) 969 predictions = self._filter_predictions(infer_ops.predictions, outputs) 970 mon_sess = monitored_session.MonitoredSession( 971 session_creator=monitored_session.ChiefSessionCreator( 972 checkpoint_filename_with_path=checkpoint_path, 973 scaffold=infer_ops.scaffold, 974 config=self._session_config)) 975 if not as_iterable: 976 with mon_sess: 977 if not mon_sess.should_stop(): 978 return mon_sess.run(predictions, feed_fn() if feed_fn else None) 979 else: 980 return self._predict_generator(mon_sess, predictions, feed_fn, 981 iterate_batches) 982 983 def _predict_generator(self, mon_sess, predictions, feed_fn, iterate_batches): 984 with mon_sess: 985 while not mon_sess.should_stop(): 986 preds = mon_sess.run(predictions, feed_fn() if feed_fn else None) 987 if iterate_batches: 988 yield preds 989 elif not isinstance(predictions, dict): 990 for pred in preds: 991 yield pred 992 else: 993 first_tensor = list(preds.values())[0] 994 if isinstance(first_tensor, sparse_tensor.SparseTensorValue): 995 batch_length = first_tensor.dense_shape[0] 996 else: 997 batch_length = first_tensor.shape[0] 998 for i in range(batch_length): 999 yield {key: value[i] for key, value in six.iteritems(preds)} 1000 if self._is_input_constant(feed_fn, mon_sess.graph): 1001 return 1002 1003 def _is_input_constant(self, feed_fn, graph): 1004 # If there are no queue_runners, the input `predictions` is a 1005 # constant, and we should stop after the first epoch. If, 1006 # instead, there are queue_runners, eventually they should throw 1007 # an `OutOfRangeError`. 1008 if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS): 1009 return False 1010 # data_feeder uses feed_fn to generate `OutOfRangeError`. 1011 if feed_fn is not None: 1012 return False 1013 return True 1014 1015 def _filter_predictions(self, predictions, outputs): 1016 if not outputs: 1017 return predictions 1018 if not isinstance(predictions, dict): 1019 raise ValueError( 1020 'outputs argument is not valid in case of non-dict predictions.') 1021 existing_keys = predictions.keys() 1022 predictions = { 1023 key: value 1024 for key, value in six.iteritems(predictions) 1025 if key in outputs 1026 } 1027 if not predictions: 1028 raise ValueError('Expected to run at least one output from %s, ' 1029 'provided %s.' % (existing_keys, outputs)) 1030 return predictions 1031 1032 def _train_model(self, input_fn, hooks): 1033 all_hooks = [] 1034 self._graph = ops.Graph() 1035 with self._graph.as_default() as g, g.device(self._device_fn): 1036 random_seed.set_random_seed(self._config.tf_random_seed) 1037 global_step = training_util.create_global_step(g) 1038 features, labels = input_fn() 1039 self._check_inputs(features, labels) 1040 training_util._get_or_create_global_step_read() # pylint: disable=protected-access 1041 model_fn_ops = self._get_train_ops(features, labels) 1042 ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss) 1043 all_hooks.extend(hooks) 1044 all_hooks.extend([ 1045 basic_session_run_hooks.NanTensorHook(model_fn_ops.loss), 1046 basic_session_run_hooks.LoggingTensorHook( 1047 { 1048 'loss': model_fn_ops.loss, 1049 'step': global_step 1050 }, 1051 every_n_iter=100) 1052 ]) 1053 1054 scaffold = model_fn_ops.scaffold or monitored_session.Scaffold() 1055 if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)): 1056 ops.add_to_collection( 1057 ops.GraphKeys.SAVERS, 1058 saver.Saver( 1059 sharded=True, 1060 max_to_keep=self._config.keep_checkpoint_max, 1061 keep_checkpoint_every_n_hours=( 1062 self._config.keep_checkpoint_every_n_hours), 1063 defer_build=True, 1064 save_relative_paths=True)) 1065 1066 chief_hooks = [] 1067 if (self._config.save_checkpoints_secs or 1068 self._config.save_checkpoints_steps): 1069 saver_hook_exists = any( 1070 isinstance(h, basic_session_run_hooks.CheckpointSaverHook) 1071 for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks + 1072 model_fn_ops.training_chief_hooks) 1073 ) 1074 if not saver_hook_exists: 1075 chief_hooks = [ 1076 basic_session_run_hooks.CheckpointSaverHook( 1077 self._model_dir, 1078 save_secs=self._config.save_checkpoints_secs, 1079 save_steps=self._config.save_checkpoints_steps, 1080 scaffold=scaffold) 1081 ] 1082 with monitored_session.MonitoredTrainingSession( 1083 master=self._config.master, 1084 is_chief=self._config.is_chief, 1085 checkpoint_dir=self._model_dir, 1086 scaffold=scaffold, 1087 hooks=all_hooks + model_fn_ops.training_hooks, 1088 chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, 1089 save_checkpoint_secs=0, # Saving is handled by a hook. 1090 save_summaries_steps=self._config.save_summary_steps, 1091 config=self._session_config) as mon_sess: 1092 loss = None 1093 while not mon_sess.should_stop(): 1094 _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss]) 1095 return loss 1096 1097 1098def _identity_feature_engineering_fn(features, labels): 1099 return features, labels 1100 1101 1102class Estimator(BaseEstimator): 1103 """Estimator class is the basic TensorFlow model trainer/evaluator. 1104 1105 THIS CLASS IS DEPRECATED. See 1106 [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 1107 for general migration instructions. 1108 """ 1109 1110 def __init__(self, 1111 model_fn=None, 1112 model_dir=None, 1113 config=None, 1114 params=None, 1115 feature_engineering_fn=None): 1116 """Constructs an `Estimator` instance. 1117 1118 Args: 1119 model_fn: Model function. Follows the signature: 1120 * Args: 1121 * `features`: single `Tensor` or `dict` of `Tensor`s 1122 (depending on data passed to `fit`), 1123 * `labels`: `Tensor` or `dict` of `Tensor`s (for multi-head 1124 models). If mode is `ModeKeys.INFER`, `labels=None` will be 1125 passed. If the `model_fn`'s signature does not accept 1126 `mode`, the `model_fn` must still be able to handle 1127 `labels=None`. 1128 * `mode`: Optional. Specifies if this training, evaluation or 1129 prediction. See `ModeKeys`. 1130 * `params`: Optional `dict` of hyperparameters. Will receive what 1131 is passed to Estimator in `params` parameter. This allows 1132 to configure Estimators from hyper parameter tuning. 1133 * `config`: Optional configuration object. Will receive what is passed 1134 to Estimator in `config` parameter, or the default `config`. 1135 Allows updating things in your model_fn based on configuration 1136 such as `num_ps_replicas`. 1137 * `model_dir`: Optional directory where model parameters, graph etc 1138 are saved. Will receive what is passed to Estimator in 1139 `model_dir` parameter, or the default `model_dir`. Allows 1140 updating things in your model_fn that expect model_dir, such as 1141 training hooks. 1142 1143 * Returns: 1144 `ModelFnOps` 1145 1146 Also supports a legacy signature which returns tuple of: 1147 1148 * predictions: `Tensor`, `SparseTensor` or dictionary of same. 1149 Can also be any type that is convertible to a `Tensor` or 1150 `SparseTensor`, or dictionary of same. 1151 * loss: Scalar loss `Tensor`. 1152 * train_op: Training update `Tensor` or `Operation`. 1153 1154 Supports next three signatures for the function: 1155 1156 * `(features, labels) -> (predictions, loss, train_op)` 1157 * `(features, labels, mode) -> (predictions, loss, train_op)` 1158 * `(features, labels, mode, params) -> (predictions, loss, train_op)` 1159 * `(features, labels, mode, params, config) -> 1160 (predictions, loss, train_op)` 1161 * `(features, labels, mode, params, config, model_dir) -> 1162 (predictions, loss, train_op)` 1163 1164 model_dir: Directory to save model parameters, graph and etc. This can 1165 also be used to load checkpoints from the directory into a estimator to 1166 continue training a previously saved model. 1167 config: Configuration object. 1168 params: `dict` of hyper parameters that will be passed into `model_fn`. 1169 Keys are names of parameters, values are basic python types. 1170 feature_engineering_fn: Feature engineering function. Takes features and 1171 labels which are the output of `input_fn` and 1172 returns features and labels which will be fed 1173 into `model_fn`. Please check `model_fn` for 1174 a definition of features and labels. 1175 1176 Raises: 1177 ValueError: parameters of `model_fn` don't match `params`. 1178 """ 1179 super(Estimator, self).__init__(model_dir=model_dir, config=config) 1180 if model_fn is not None: 1181 # Check number of arguments of the given function matches requirements. 1182 model_fn_args = _model_fn_args(model_fn) 1183 if params is not None and 'params' not in model_fn_args: 1184 raise ValueError('Estimator\'s model_fn (%s) does not have a params ' 1185 'argument, but params (%s) were passed to the ' 1186 'Estimator\'s constructor.' % (model_fn, params)) 1187 if params is None and 'params' in model_fn_args: 1188 logging.warning('Estimator\'s model_fn (%s) includes params ' 1189 'argument, but params are not passed to Estimator.', 1190 model_fn) 1191 self._model_fn = model_fn 1192 self.params = params 1193 self._feature_engineering_fn = ( 1194 feature_engineering_fn or _identity_feature_engineering_fn) 1195 1196 def _call_model_fn(self, features, labels, mode, metrics=None, config=None): 1197 """Calls model function with support of 2, 3 or 4 arguments. 1198 1199 Args: 1200 features: features dict. 1201 labels: labels dict. 1202 mode: ModeKeys 1203 metrics: Dict of metrics. 1204 config: RunConfig. 1205 1206 Returns: 1207 A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a 1208 `ModelFnOps` object. 1209 1210 Raises: 1211 ValueError: if model_fn returns invalid objects. 1212 """ 1213 features, labels = self._feature_engineering_fn(features, labels) 1214 model_fn_args = _model_fn_args(self._model_fn) 1215 kwargs = {} 1216 if 'mode' in model_fn_args: 1217 kwargs['mode'] = mode 1218 if 'params' in model_fn_args: 1219 kwargs['params'] = self.params 1220 if 'config' in model_fn_args: 1221 if config: 1222 kwargs['config'] = config 1223 else: 1224 kwargs['config'] = self.config 1225 if 'model_dir' in model_fn_args: 1226 kwargs['model_dir'] = self.model_dir 1227 model_fn_results = self._model_fn(features, labels, **kwargs) 1228 1229 if isinstance(model_fn_results, model_fn_lib.ModelFnOps): 1230 model_fn_ops = model_fn_results 1231 else: 1232 # Here model_fn_results should be a tuple with 3 elements. 1233 if len(model_fn_results) != 3: 1234 raise ValueError('Unrecognized value returned by model_fn, ' 1235 'please return ModelFnOps.') 1236 model_fn_ops = model_fn_lib.ModelFnOps( 1237 mode=mode, 1238 predictions=model_fn_results[0], 1239 loss=model_fn_results[1], 1240 train_op=model_fn_results[2]) 1241 1242 # Custom metrics should overwrite defaults. 1243 if metrics: 1244 model_fn_ops.eval_metric_ops.update( 1245 _make_metrics_ops(metrics, features, labels, 1246 model_fn_ops.predictions)) 1247 1248 return model_fn_ops 1249 1250 def _get_train_ops(self, features, labels): 1251 """Method that builds model graph and returns trainer ops. 1252 1253 Expected to be overridden by sub-classes that require custom support. 1254 This implementation uses `model_fn` passed as parameter to constructor to 1255 build model. 1256 1257 Args: 1258 features: `Tensor` or `dict` of `Tensor` objects. 1259 labels: `Tensor` or `dict` of `Tensor` objects. 1260 1261 Returns: 1262 `ModelFnOps` object. 1263 """ 1264 return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN) 1265 1266 def _get_eval_ops(self, features, labels, metrics): 1267 """Method that builds model graph and returns evaluation ops. 1268 1269 Expected to be overridden by sub-classes that require custom support. 1270 This implementation uses `model_fn` passed as parameter to constructor to 1271 build model. 1272 1273 Args: 1274 features: `Tensor` or `dict` of `Tensor` objects. 1275 labels: `Tensor` or `dict` of `Tensor` objects. 1276 metrics: Dict of metrics to run. If None, the default metric functions 1277 are used; if {}, no metrics are used. Otherwise, `metrics` should map 1278 friendly names for the metric to a `MetricSpec` object defining which 1279 model outputs to evaluate against which labels with which metric 1280 function. Metric ops should support streaming, e.g., returning 1281 update_op and value tensors. See more details in 1282 `../../../../metrics/python/metrics/ops/streaming_metrics.py` and 1283 `../metric_spec.py`. 1284 1285 Returns: 1286 `ModelFnOps` object. 1287 1288 Raises: 1289 ValueError: if `metrics` don't match `labels`. 1290 """ 1291 model_fn_ops = self._call_model_fn(features, labels, 1292 model_fn_lib.ModeKeys.EVAL, metrics) 1293 1294 if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops: 1295 model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = ( 1296 metrics_lib.mean(model_fn_ops.loss)) 1297 return model_fn_ops 1298 1299 def _get_predict_ops(self, features): 1300 """Method that builds model graph and returns prediction ops. 1301 1302 Expected to be overridden by sub-classes that require custom support. 1303 This implementation uses `model_fn` passed as parameter to constructor to 1304 build model. 1305 1306 Args: 1307 features: `Tensor` or `dict` of `Tensor` objects. 1308 1309 Returns: 1310 `ModelFnOps` object. 1311 """ 1312 labels = tensor_signature.create_placeholders_from_signatures( 1313 self._labels_info) 1314 return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER) 1315 1316 def export_savedmodel(self, 1317 export_dir_base, 1318 serving_input_fn, 1319 default_output_alternative_key=None, 1320 assets_extra=None, 1321 as_text=False, 1322 checkpoint_path=None, 1323 graph_rewrite_specs=(GraphRewriteSpec( 1324 (tag_constants.SERVING,), ()),), 1325 strip_default_attrs=False): 1326 # pylint: disable=line-too-long 1327 """Exports inference graph as a SavedModel into given dir. 1328 1329 Args: 1330 export_dir_base: A string containing a directory to write the exported 1331 graph and checkpoints. 1332 serving_input_fn: A function that takes no argument and 1333 returns an `InputFnOps`. 1334 default_output_alternative_key: the name of the head to serve when none is 1335 specified. Not needed for single-headed models. 1336 assets_extra: A dict specifying how to populate the assets.extra directory 1337 within the exported SavedModel. Each key should give the destination 1338 path (including the filename) relative to the assets.extra directory. 1339 The corresponding value gives the full path of the source file to be 1340 copied. For example, the simple case of copying a single file without 1341 renaming it is specified as 1342 `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. 1343 as_text: whether to write the SavedModel proto in text format. 1344 checkpoint_path: The checkpoint path to export. If None (the default), 1345 the most recent checkpoint found within the model directory is chosen. 1346 graph_rewrite_specs: an iterable of `GraphRewriteSpec`. Each element will 1347 produce a separate MetaGraphDef within the exported SavedModel, tagged 1348 and rewritten as specified. Defaults to a single entry using the 1349 default serving tag ("serve") and no rewriting. 1350 strip_default_attrs: Boolean. If `True`, default-valued attributes will be 1351 removed from the NodeDefs. For a detailed guide, see 1352 [Stripping Default-Valued 1353 Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). 1354 1355 Returns: 1356 The string path to the exported directory. 1357 1358 Raises: 1359 ValueError: if an unrecognized export_type is requested. 1360 """ 1361 # pylint: enable=line-too-long 1362 if serving_input_fn is None: 1363 raise ValueError('serving_input_fn must be defined.') 1364 1365 if not checkpoint_path: 1366 # Locate the latest checkpoint 1367 checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir) 1368 if not checkpoint_path: 1369 raise NotFittedError( 1370 "Couldn't find trained model at %s." % self._model_dir) 1371 1372 export_dir = saved_model_export_utils.get_timestamped_export_dir( 1373 export_dir_base) 1374 # We'll write the SavedModel to a temporary directory and then atomically 1375 # rename it at the end. This helps to avoid corrupt / incomplete outputs, 1376 # which could otherwise occur if the job is preempted or otherwise fails 1377 # in the middle of SavedModel creation. 1378 temp_export_dir = saved_model_export_utils.get_temp_export_dir(export_dir) 1379 builder = saved_model_builder.SavedModelBuilder(temp_export_dir) 1380 1381 # Build the base graph 1382 with ops.Graph().as_default() as g: 1383 training_util.create_global_step(g) 1384 1385 # Call the serving_input_fn and collect the input alternatives. 1386 input_ops = serving_input_fn() 1387 input_alternatives, features = ( 1388 saved_model_export_utils.get_input_alternatives(input_ops)) 1389 1390 # TODO(b/34388557) This is a stopgap, pending recording model provenance. 1391 # Record which features are expected at serving time. It is assumed that 1392 # these are the features that were used in training. 1393 for feature_key in input_ops.features.keys(): 1394 ops.add_to_collection( 1395 constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS, feature_key) 1396 1397 # Call the model_fn and collect the output alternatives. 1398 model_fn_ops = self._call_model_fn(features, None, 1399 model_fn_lib.ModeKeys.INFER) 1400 output_alternatives, actual_default_output_alternative_key = ( 1401 saved_model_export_utils.get_output_alternatives( 1402 model_fn_ops, default_output_alternative_key)) 1403 1404 init_op = control_flow_ops.group(variables.local_variables_initializer(), 1405 resources.initialize_resources( 1406 resources.shared_resources()), 1407 lookup_ops.tables_initializer()) 1408 1409 # Build the SignatureDefs from all pairs of input and output alternatives 1410 signature_def_map = saved_model_export_utils.build_all_signature_defs( 1411 input_alternatives, output_alternatives, 1412 actual_default_output_alternative_key) 1413 1414 # Export the first MetaGraphDef with variables, assets etc. 1415 with tf_session.Session('') as session: 1416 1417 # pylint: disable=protected-access 1418 saveables = variables._all_saveable_objects() 1419 # pylint: enable=protected-access 1420 1421 if (model_fn_ops.scaffold is not None and 1422 model_fn_ops.scaffold.saver is not None): 1423 saver_for_restore = model_fn_ops.scaffold.saver 1424 elif saveables: 1425 saver_for_restore = saver.Saver(saveables, sharded=True) 1426 1427 saver_for_restore.restore(session, checkpoint_path) 1428 1429 # Perform the export 1430 if not graph_rewrite_specs or graph_rewrite_specs[0].transforms: 1431 raise ValueError('The first element of graph_rewrite_specs ' 1432 'must specify no transforms.') 1433 untransformed_tags = graph_rewrite_specs[0].tags 1434 1435 builder.add_meta_graph_and_variables( 1436 session, 1437 untransformed_tags, 1438 signature_def_map=signature_def_map, 1439 assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), 1440 main_op=init_op, 1441 strip_default_attrs=strip_default_attrs) 1442 1443 # pylint: disable=protected-access 1444 base_meta_graph_def = builder._saved_model.meta_graphs[0] 1445 # pylint: enable=protected-access 1446 1447 if graph_rewrite_specs[1:]: 1448 # Prepare the input_names and output_names needed for the 1449 # meta_graph_transform call below. 1450 input_names = [ 1451 tensor.name 1452 for input_dict in input_alternatives.values() 1453 for tensor in input_dict.values() 1454 ] 1455 output_names = [ 1456 tensor.name 1457 for output_alternative in output_alternatives.values() 1458 for tensor in output_alternative[1].values() 1459 ] 1460 1461 # Write the additional MetaGraphDefs 1462 for graph_rewrite_spec in graph_rewrite_specs[1:]: 1463 1464 # TODO(soergel) consider moving most of this to saved_model.builder_impl 1465 # as e.g. builder.add_rewritten_meta_graph(rewritten_graph_def, tags) 1466 1467 transformed_meta_graph_def = meta_graph_transform.meta_graph_transform( 1468 base_meta_graph_def, input_names, output_names, 1469 graph_rewrite_spec.transforms, graph_rewrite_spec.tags) 1470 1471 # pylint: disable=protected-access 1472 meta_graph_def = builder._saved_model.meta_graphs.add() 1473 # pylint: enable=protected-access 1474 meta_graph_def.CopyFrom(transformed_meta_graph_def) 1475 1476 # Add the extra assets 1477 if assets_extra: 1478 assets_extra_path = os.path.join( 1479 compat.as_bytes(temp_export_dir), compat.as_bytes('assets.extra')) 1480 for dest_relative, source in assets_extra.items(): 1481 dest_absolute = os.path.join( 1482 compat.as_bytes(assets_extra_path), compat.as_bytes(dest_relative)) 1483 dest_path = os.path.dirname(dest_absolute) 1484 gfile.MakeDirs(dest_path) 1485 gfile.Copy(source, dest_absolute) 1486 1487 builder.save(as_text) 1488 gfile.Rename(temp_export_dir, export_dir) 1489 return export_dir 1490 1491 1492# For time of deprecation x,y from Estimator allow direct access. 1493# pylint: disable=protected-access 1494class SKCompat(sklearn.BaseEstimator): 1495 """Scikit learn wrapper for TensorFlow Learn Estimator. 1496 1497 THIS CLASS IS DEPRECATED. See 1498 [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 1499 for general migration instructions. 1500 """ 1501 1502 @deprecated(None, 'Please switch to the Estimator interface.') 1503 def __init__(self, estimator): 1504 self._estimator = estimator 1505 1506 def fit(self, x, y, batch_size=128, steps=None, max_steps=None, 1507 monitors=None): 1508 input_fn, feed_fn = _get_input_fn( 1509 x, 1510 y, 1511 input_fn=None, 1512 feed_fn=None, 1513 batch_size=batch_size, 1514 shuffle=True, 1515 epochs=None) 1516 all_monitors = [] 1517 if feed_fn: 1518 all_monitors = [basic_session_run_hooks.FeedFnHook(feed_fn)] 1519 if monitors: 1520 all_monitors.extend(monitors) 1521 1522 self._estimator.fit( 1523 input_fn=input_fn, 1524 steps=steps, 1525 max_steps=max_steps, 1526 monitors=all_monitors) 1527 return self 1528 1529 def score(self, x, y, batch_size=128, steps=None, metrics=None, name=None): 1530 input_fn, feed_fn = _get_input_fn( 1531 x, 1532 y, 1533 input_fn=None, 1534 feed_fn=None, 1535 batch_size=batch_size, 1536 shuffle=False, 1537 epochs=1) 1538 if metrics is not None and not isinstance(metrics, dict): 1539 raise ValueError('Metrics argument should be None or dict. ' 1540 'Got %s.' % metrics) 1541 eval_results, global_step = self._estimator._evaluate_model( 1542 input_fn=input_fn, 1543 feed_fn=feed_fn, 1544 steps=steps, 1545 metrics=metrics, 1546 name=name) 1547 if eval_results is not None: 1548 eval_results.update({'global_step': global_step}) 1549 return eval_results 1550 1551 def predict(self, x, batch_size=128, outputs=None): 1552 input_fn, feed_fn = _get_input_fn( 1553 x, 1554 None, 1555 input_fn=None, 1556 feed_fn=None, 1557 batch_size=batch_size, 1558 shuffle=False, 1559 epochs=1) 1560 results = list( 1561 self._estimator._infer_model( 1562 input_fn=input_fn, 1563 feed_fn=feed_fn, 1564 outputs=outputs, 1565 as_iterable=True, 1566 iterate_batches=True)) 1567 if not isinstance(results[0], dict): 1568 return np.concatenate([output for output in results], axis=0) 1569 return { 1570 key: np.concatenate([output[key] for output in results], axis=0) 1571 for key in results[0] 1572 } 1573