# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base Estimator class (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import copy
import os
import tempfile

import numpy as np
import six

from google.protobuf import message
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import _sklearn as sklearn
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.meta_graph_transform import meta_graph_transform
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary as core_summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import device_setter
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect

AS_ITERABLE_DATE = '2016-09-15'
AS_ITERABLE_INSTRUCTIONS = (
    'The default behavior of predict() is changing. The default value for\n'
    'as_iterable will change to True, and then the flag will be removed\n'
    'altogether. The behavior of this flag is described below.')
SCIKIT_DECOUPLE_DATE = '2016-12-01'
SCIKIT_DECOUPLE_INSTRUCTIONS = (
    'Estimator is decoupled from Scikit Learn interface by moving into\n'
    'separate class SKCompat. Arguments x, y and batch_size are only\n'
    'available in the SKCompat class, Estimator will only accept input_fn.\n'
    'Example conversion:\n'
    '  est = Estimator(...) -> est = SKCompat(Estimator(...))')


def _verify_input_args(x, y, input_fn, feed_fn, batch_size):
  """Verifies validity of co-existence of input arguments."""
  if input_fn is None:
    if x is None:
      raise ValueError('Either x or input_fn must be provided.')

    if tensor_util.is_tensor(x) or y is not None and tensor_util.is_tensor(y):
      raise ValueError('Inputs cannot be tensors. Please provide input_fn.')

    if feed_fn is not None:
      raise ValueError('Can not provide both feed_fn and x or y.')
  else:
    if (x is not None) or (y is not None):
      raise ValueError('Can not provide both input_fn and x or y.')
    if batch_size is not None:
      raise ValueError('Can not provide both input_fn and batch_size.')
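

# Illustrative sketch (hypothetical, not part of the original module): the
# contract enforced by `_verify_input_args` is that in-memory data (x, y) and
# an `input_fn` are mutually exclusive.
def _example_verify_input_args():
  """Hypothetical demo of the input-argument co-existence rules."""
  x = np.zeros((4, 2), dtype=np.float32)
  y = np.zeros(4, dtype=np.int64)
  # Valid: x/y with a batch_size and no input_fn.
  _verify_input_args(x, y, input_fn=None, feed_fn=None, batch_size=2)
  try:
    # Invalid: an input_fn together with x.
    _verify_input_args(x, None, input_fn=lambda: None, feed_fn=None,
                       batch_size=None)
  except ValueError:
    pass  # Expected: 'Can not provide both input_fn and x or y.'
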

def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1):
  """Make inputs into input and feed functions.

  Args:
    x: Numpy, Pandas or Dask matrix or iterable.
    y: Numpy, Pandas or Dask matrix or iterable.
    input_fn: Pre-defined input function for training data.
    feed_fn: Pre-defined data feeder function.
    batch_size: Size to split data into parts. Must be >= 1.
    shuffle: Whether to shuffle the inputs.
    epochs: Number of epochs to run.

  Returns:
    Data input and feeder function based on training data.

  Raises:
    ValueError: Unless exactly one of `(x, y)` or `input_fn` is provided.
  """
  _verify_input_args(x, y, input_fn, feed_fn, batch_size)
  if input_fn is not None:
    return input_fn, feed_fn
  df = data_feeder.setup_train_data_feeder(
      x,
      y,
      n_classes=None,
      batch_size=batch_size,
      shuffle=shuffle,
      epochs=epochs)
  return df.input_builder, df.get_feed_dict_fn()
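

# Illustrative sketch (hypothetical): turning in-memory numpy arrays into the
# (input_fn, feed_fn) pair that the training loop consumes.
def _example_get_input_fn():
  """Hypothetical demo of `_get_input_fn` with numpy inputs."""
  x = np.random.rand(8, 3).astype(np.float32)
  y = np.random.randint(0, 2, size=8)
  # `input_fn` builds the input tensors; `feed_fn` yields feed dicts.
  input_fn, feed_fn = _get_input_fn(
      x, y, input_fn=None, feed_fn=None, batch_size=4, shuffle=True, epochs=1)
  return input_fn, feed_fn
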

@deprecated(None, 'Please specify feature columns explicitly.')
def infer_real_valued_columns_from_input_fn(input_fn):
  """Creates `FeatureColumn` objects for inputs defined by `input_fn`.

  This interprets all inputs as dense, fixed-length float values. This creates
  a local graph in which it calls `input_fn` to build the tensors, then discards
  it.

  Args:
    input_fn: Input function returning a tuple of:
        features - `Tensor` or dictionary of string feature name to `Tensor`.
        labels - `Tensor` of label values.

  Returns:
    List of `FeatureColumn` objects.
  """
  with ops.Graph().as_default():
    features, _ = input_fn()
    return layers.infer_real_valued_columns(features)


@deprecated(None, 'Please specify feature columns explicitly.')
def infer_real_valued_columns_from_input(x):
  """Creates `FeatureColumn` objects for inputs defined by input `x`.

  This interprets all inputs as dense, fixed-length float values.

  Args:
    x: Real-valued matrix of shape [n_samples, n_features...]. Can be an
       iterator that returns arrays of features.

  Returns:
    List of `FeatureColumn` objects.
  """
  input_fn, _ = _get_input_fn(
      x=x, y=None, input_fn=None, feed_fn=None, batch_size=None)
  return infer_real_valued_columns_from_input_fn(input_fn)
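

# Illustrative sketch (hypothetical): inferring dense float feature columns
# directly from an in-memory matrix.
def _example_infer_columns():
  """Hypothetical demo of `infer_real_valued_columns_from_input`."""
  x = np.random.rand(10, 4).astype(np.float32)
  # Returns a list of real-valued `FeatureColumn`s describing the inputs.
  return infer_real_valued_columns_from_input(x)
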

def _model_fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments.
  """
  _, fn = tf_decorator.unwrap(fn)
  if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
    # Handle functools.partial and similar objects.
    return tuple([
        arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
        if arg not in set(fn.keywords.keys())
    ])
  # Handle function.
  return tuple(tf_inspect.getargspec(fn).args)
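

# Illustrative sketch (hypothetical): `_model_fn_args` sees through
# `functools.partial` and reports only the still-unbound argument names.
def _example_model_fn_args():
  """Hypothetical demo of `_model_fn_args` on a partial function."""
  import functools

  def my_model_fn(features, labels, mode, params):
    del features, labels, mode, params  # Unused in this sketch.

  bound = functools.partial(my_model_fn, params={'learning_rate': 0.1})
  # Only the arguments not bound by the partial are reported.
  assert _model_fn_args(bound) == ('features', 'labels', 'mode')
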

def _get_replica_device_setter(config):
  """Creates a replica device setter if required.

  Args:
    config: A RunConfig instance.

  Returns:
    A replica device setter, or None.
  """
  ps_ops = [
      'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
      'MutableHashTableV2', 'MutableHashTableOfTensors',
      'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
      'MutableDenseHashTableV2', 'VarHandleOp'
  ]

  if config.task_type:
    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return device_setter.replica_device_setter(
        ps_tasks=config.num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        ps_ops=ps_ops,
        cluster=config.cluster_spec)
  else:
    return None
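

# Illustrative sketch (hypothetical): in the default single-process setup
# there are no parameter servers, so no device setter is needed.
def _example_replica_device_setter():
  """Hypothetical demo of `_get_replica_device_setter`."""
  config = run_config.RunConfig()
  # With config.num_ps_replicas == 0 this returns None, so variables are
  # placed on the local device; with ps replicas it returns a setter that
  # round-robins variables across '/job:ps' tasks.
  return _get_replica_device_setter(config)
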

def _make_metrics_ops(metrics, features, labels, predictions):
  """Add metrics based on `features`, `labels`, and `predictions`.

  `metrics` contains a specification for how to run metrics. It is a dict
  mapping friendly names to either `MetricSpec` objects, or directly to a metric
  function (assuming that `predictions` and `labels` are single tensors), or to
  `(pred_name, metric)` `tuple`, which passes `predictions[pred_name]` and
  `labels` to `metric` (assuming `labels` is a single tensor).

  Users are encouraged to use `MetricSpec` objects, which are more flexible and
  cleaner. They also lead to clearer errors.

  Args:
    metrics: A dict mapping names to metrics specification, for example
      `MetricSpec` objects.
    features: A dict of tensors returned from an input_fn as features/inputs.
    labels: A single tensor or a dict of tensors returned from an input_fn as
      labels.
    predictions: A single tensor or a dict of tensors output from a model as
      predictions.

  Returns:
    A dict mapping the friendly names given in `metrics` to the result of
    calling the given metric function.

  Raises:
    ValueError: If metrics specifications do not work with the type of
      `features`, `labels`, or `predictions` provided. Most commonly, a dict is
      given but no pred_name is specified.
  """
  metrics = metrics or {}

  # If labels is a dict with a single key, unpack into a single tensor.
  labels_tensor_or_dict = labels
  if isinstance(labels, dict) and len(labels) == 1:
    labels_tensor_or_dict = labels[list(labels.keys())[0]]

  result = {}
  # Iterate in lexicographic order, so the graph is identical among runs.
  for name, metric in sorted(six.iteritems(metrics)):
    if isinstance(metric, metric_spec.MetricSpec):
      result[name] = metric.create_metric_ops(features, labels, predictions)
      continue

    # TODO(b/31229024): Remove the rest of this loop
    logging.warning('Please specify metrics using MetricSpec. Using bare '
                    'functions or (key, fn) tuples is deprecated and support '
                    'for it will be removed on Oct 1, 2016.')

    if isinstance(name, tuple):
      # Multi-head metrics.
      if len(name) != 2:
        raise ValueError('Invalid metric for {}. It returned a tuple with '
                         'len {}, expected 2.'.format(name, len(name)))
      if not isinstance(predictions, dict):
        raise ValueError('Metrics passed provide (name, prediction), '
                         'but predictions are not dict. '
                         'Metrics: %s, Predictions: %s.' % (metrics,
                                                            predictions))
      # Here are two options: labels are single Tensor or a dict.
      if isinstance(labels, dict) and name[1] in labels:
        # If labels are dict and the prediction name is in it, apply metric.
        result[name[0]] = metric(predictions[name[1]], labels[name[1]])
      else:
        # Otherwise pass the labels to the metric.
        result[name[0]] = metric(predictions[name[1]], labels_tensor_or_dict)
    else:
      # Single head metrics.
      if isinstance(predictions, dict):
        raise ValueError('Metrics passed provide only name, no prediction, '
                         'but predictions are dict. '
                         'Metrics: %s, Labels: %s.' % (metrics,
                                                       labels_tensor_or_dict))
      result[name] = metric(predictions, labels_tensor_or_dict)
  return result
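

# Illustrative sketch (hypothetical): the preferred `metrics` dict for
# `_make_metrics_ops`, using `MetricSpec` to name the prediction tensor
# explicitly. The prediction key 'classes' is a placeholder assumption.
def _example_metrics_dict():
  """Hypothetical demo of a `metrics` dict built from `MetricSpec`s."""
  return {
      'my_accuracy':
          metric_spec.MetricSpec(
              metric_fn=metrics_lib.accuracy, prediction_key='classes'),
  }
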

def _dict_to_str(dictionary):
  """Get a `str` representation of a `dict`.

  Args:
    dictionary: The `dict` to be represented as `str`.

  Returns:
    A `str` representing the `dictionary`.
  """
  results = []
  for k, v in sorted(dictionary.items()):
    if isinstance(v, (float, np.float32, int, np.int64, np.int32)):
      results.append('%s = %s' % (k, v))
    else:
      results.append('Type of %s = %s' % (k, type(v)))

  return ', '.join(results)


def _write_dict_to_summary(output_dir, dictionary, current_global_step):
  """Writes a `dict` into summary file in given output directory.

  Args:
    output_dir: `str`, directory to write the summary file in.
    dictionary: the `dict` to be written to summary file.
    current_global_step: `int`, the current global step.
  """
  logging.info('Saving dict for global step %d: %s', current_global_step,
               _dict_to_str(dictionary))
  summary_writer = core_summary.FileWriterCache.get(output_dir)
  summary_proto = summary_pb2.Summary()
  for key in dictionary:
    if dictionary[key] is None:
      continue
    if key == 'global_step':
      continue
    if isinstance(dictionary[key], (np.float32, float)):
      summary_proto.value.add(tag=key, simple_value=float(dictionary[key]))
    elif isinstance(dictionary[key], (np.int64, np.int32, int)):
      summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
    elif isinstance(dictionary[key], six.string_types):
      try:
        summ = summary_pb2.Summary.FromString(dictionary[key])
        for i, _ in enumerate(summ.value):
          summ.value[i].tag = key
        summary_proto.value.extend(summ.value)
      except message.DecodeError:
        logging.warn('Skipping summary for %s, cannot parse string to Summary.',
                     key)
        continue
    elif isinstance(dictionary[key], np.ndarray):
      value = summary_proto.value.add()
      value.tag = key
      value.node_name = key
      tensor_proto = tensor_util.make_tensor_proto(dictionary[key])
      value.tensor.CopyFrom(tensor_proto)
      logging.info(
          'Summary for np.ndarray is not visible in TensorBoard by default. '
          'Consider using a TensorBoard plugin for visualization (see '
          'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'
          ' for more information).')
    else:
      logging.warn(
          'Skipping summary for %s, must be a float, np.float32, int, '
          'np.int64, np.int32, an np.ndarray, or a serialized Summary string.',
          key)
  summary_writer.add_summary(summary_proto, current_global_step)
  summary_writer.flush()
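

# Illustrative sketch (hypothetical): persisting a dict of evaluation results
# as a TensorBoard summary for a given global step.
def _example_write_summary(output_dir):
  """Hypothetical demo of `_write_dict_to_summary`."""
  eval_results = {'loss': 0.25, 'accuracy': np.float32(0.9)}
  _write_dict_to_summary(output_dir, eval_results, current_global_step=100)
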

GraphRewriteSpec = collections.namedtuple('GraphRewriteSpec',
                                          ['tags', 'transforms'])
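

# Illustrative sketch (hypothetical): a rewrite spec that tags the exported
# MetaGraphDef for serving and applies no graph transforms; this mirrors the
# default used by `Estimator.export_savedmodel` below.
_EXAMPLE_SERVING_SPEC = GraphRewriteSpec(
    tags=(tag_constants.SERVING,), transforms=())
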

class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable,
                    trainable.Trainable):
  """Abstract BaseEstimator class to train and evaluate TensorFlow models.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Users should not instantiate or subclass this class. Instead, use an
  `Estimator`.
  """

  # Note that for Google users, this is overridden with
  # learn_runner.EstimatorConfig.
  # TODO(wicke): Remove this once launcher takes over config functionality
  _Config = run_config.RunConfig  # pylint: disable=invalid-name

  @deprecated(None, 'Please replace uses of any Estimator from tf.contrib.learn'
              ' with an Estimator from tf.estimator.*')
  def __init__(self, model_dir=None, config=None):
    """Initializes a BaseEstimator instance.

    Args:
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator
        to continue training a previously saved model. If `None`, the
        model_dir in `config` will be used if set. If both are set, they must
        be the same.
      config: A RunConfig instance.
    """
    # Create a run configuration.
    if config is None:
      self._config = BaseEstimator._Config()
      logging.info('Using default config.')
    else:
      self._config = config

    if self._config.session_config is None:
      self._session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      self._session_config = self._config.session_config

    # Model directory.
    if (model_dir is not None) and (self._config.model_dir is not None):
      if model_dir != self._config.model_dir:
        # TODO(b/9965722): remove this suppression after it is no longer
        #                  necessary.
        # pylint: disable=g-doc-exception
        raise ValueError(
            'model_dir is set in both the constructor and RunConfig, but '
            "with different values. In constructor: '{}', in RunConfig: "
            "'{}' ".format(model_dir, self._config.model_dir))
        # pylint: enable=g-doc-exception

    self._model_dir = model_dir or self._config.model_dir
    if self._model_dir is None:
      self._model_dir = tempfile.mkdtemp()
      logging.warning('Using temporary folder as model directory: %s',
                      self._model_dir)
    if self._config.model_dir is None:
      self._config = self._config.replace(model_dir=self._model_dir)
    logging.info('Using config: %s', str(vars(self._config)))

    # Set the device function depending on whether there are replicas or not.
    self._device_fn = _get_replica_device_setter(self._config)

    # Features and labels TensorSignature objects.
    # TODO(wicke): Rename these to something more descriptive
    self._features_info = None
    self._labels_info = None

    self._graph = None

  @property
  def config(self):
    # TODO(wicke): make RunConfig immutable, and then return it without a copy.
    return copy.deepcopy(self._config)

  @property
  def model_fn(self):
    """Returns the model_fn which is bound to self.params.

    Returns:
      The model_fn with the following signature:
        `def model_fn(features, labels, mode, config)`
    """

    def public_model_fn(features, labels, mode, config):
      return self._call_model_fn(features, labels, mode, config=config)

    return public_model_fn

  @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
                   ('x', None), ('y', None), ('batch_size', None))
  def fit(self,
          x=None,
          y=None,
          input_fn=None,
          steps=None,
          batch_size=None,
          monitors=None,
          max_steps=None):
    # pylint: disable=g-doc-args,g-doc-return-or-yield
    """See `Trainable`.

    Raises:
      ValueError: If `x` or `y` are not `None` while `input_fn` is not `None`.
      ValueError: If both `steps` and `max_steps` are not `None`.
    """
    if (steps is not None) and (max_steps is not None):
      raise ValueError('Can not provide both steps and max_steps.')
    _verify_input_args(x, y, input_fn, None, batch_size)
    if x is not None:
      SKCompat(self).fit(x, y, batch_size, steps, max_steps, monitors)
      return self

    if max_steps is not None:
      try:
        start_step = load_variable(self._model_dir, ops.GraphKeys.GLOBAL_STEP)
        if max_steps <= start_step:
          logging.info('Skipping training since max_steps has already been '
                       'reached.')
          return self
      except:  # pylint: disable=bare-except
        pass

    hooks = monitor_lib.replace_monitors_with_hooks(monitors, self)
    if steps is not None or max_steps is not None:
      hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))

    loss = self._train_model(input_fn=input_fn, hooks=hooks)
    logging.info('Loss for final step: %s.', loss)
    return self
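
  # Illustrative usage sketch (hypothetical; estimator and input_fn names are
  # placeholders), assuming a concrete subclass such as `Estimator`:
  #
  #   est = Estimator(model_fn=my_model_fn, model_dir='/tmp/my_model')
  #   est.fit(input_fn=my_train_input_fn, max_steps=1000)
  #
  # Passing x/y/batch_size instead of input_fn routes through `SKCompat` and
  # is deprecated.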

  @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
                   ('x', None), ('y', None), ('batch_size', None))
  def partial_fit(self,
                  x=None,
                  y=None,
                  input_fn=None,
                  steps=1,
                  batch_size=None,
                  monitors=None):
    """Incremental fit on a batch of samples.

    This method is expected to be called several times consecutively
    on different or the same chunks of the dataset. This can implement
    either iterative training or out-of-core/online training.

    This is especially useful when the whole dataset is too big to
    fit in memory at the same time, or when the model is taking a long
    time to converge and you want to split up training into subparts.

    Args:
      x: Matrix of shape [n_samples, n_features...]. Can be an iterator that
         returns arrays of features. The training input samples for fitting the
         model. If set, `input_fn` must be `None`.
      y: Vector or matrix [n_samples] or [n_samples, n_outputs]. Can be an
         iterator that returns arrays of labels. The training label values
         (class labels in classification, real numbers in regression). If set,
         `input_fn` must be `None`.
      input_fn: Input function. If set, `x`, `y`, and `batch_size` must be
        `None`.
      steps: Number of steps for which to train model. If `None`, train forever.
      batch_size: minibatch size to use on the input, defaults to first
        dimension of `x`. Must be `None` if `input_fn` is provided.
      monitors: List of `BaseMonitor` subclass instances. Used for callbacks
        inside the training loop.

    Returns:
      `self`, for chaining.

    Raises:
      ValueError: If at least one of `x` and `y` is provided, and `input_fn` is
          provided.
    """
    logging.warning('The current implementation of partial_fit is not optimized'
                    ' for use in a loop. Consider using fit() instead.')
    return self.fit(
        x=x,
        y=y,
        input_fn=input_fn,
        steps=steps,
        batch_size=batch_size,
        monitors=monitors)

  @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
                   ('x', None), ('y', None), ('batch_size', None))
  def evaluate(self,
               x=None,
               y=None,
               input_fn=None,
               feed_fn=None,
               batch_size=None,
               steps=None,
               metrics=None,
               name=None,
               checkpoint_path=None,
               hooks=None,
               log_progress=True):
    # pylint: disable=g-doc-args,g-doc-return-or-yield
    """See `Evaluable`.

    Raises:
      ValueError: If at least one of `x` or `y` is provided, and at least one
          of `input_fn` or `feed_fn` is provided; or if `metrics` is neither
          `None` nor a `dict`.
    """
    _verify_input_args(x, y, input_fn, feed_fn, batch_size)
    if x is not None:
      return SKCompat(self).score(x, y, batch_size, steps, metrics, name)

    if metrics is not None and not isinstance(metrics, dict):
      raise ValueError('Metrics argument should be None or dict. '
                       'Got %s.' % metrics)
    eval_results, global_step = self._evaluate_model(
        input_fn=input_fn,
        feed_fn=feed_fn,
        steps=steps,
        metrics=metrics,
        name=name,
        checkpoint_path=checkpoint_path,
        hooks=hooks,
        log_progress=log_progress)

    if eval_results is not None:
      eval_results.update({'global_step': global_step})
    return eval_results
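
  # Illustrative usage sketch (hypothetical names), mirroring the fit sketch
  # above:
  #
  #   metrics = {'my_accuracy': metric_spec.MetricSpec(
  #       metric_fn=metrics_lib.accuracy, prediction_key='classes')}
  #   results = est.evaluate(input_fn=my_eval_input_fn, steps=100,
  #                          metrics=metrics)
  #   # `results` is a dict of metric values plus 'global_step'.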

  @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS,
                   ('x', None), ('batch_size', None), ('as_iterable', True))
  def predict(self,
              x=None,
              input_fn=None,
              batch_size=None,
              outputs=None,
              as_iterable=True,
              iterate_batches=False):
    """Returns predictions for given features.

    Args:
      x: Matrix of shape [n_samples, n_features...]. Can be an iterator that
         returns arrays of features. The input samples on which to run
         prediction. If set, `input_fn` must be `None`.
      input_fn: Input function. If set, `x` and `batch_size` must be `None`.
      batch_size: Override default batch size. If set, `input_fn` must be
        `None`.
      outputs: list of `str`, name of the output to predict.
        If `None`, returns all.
      as_iterable: If True, return an iterable which keeps yielding predictions
        for each example until inputs are exhausted. Note: The inputs must
        terminate if you want the iterable to terminate (e.g. be sure to pass
        num_epochs=1 if you are using something like read_batch_features).
      iterate_batches: If True, yield the whole batch at once instead of
        decomposing the batch into individual samples. Only relevant when
        as_iterable is True.

    Returns:
      A numpy array of predicted classes or regression values if the
      constructor's `model_fn` returns a `Tensor` for `predictions` or a `dict`
      of numpy arrays if `model_fn` returns a `dict`. Returns an iterable of
      predictions if as_iterable is True.

    Raises:
      ValueError: If x and input_fn are both provided or both `None`.
    """
    _verify_input_args(x, None, input_fn, None, batch_size)
    if x is not None and not as_iterable:
      return SKCompat(self).predict(x, batch_size)

    input_fn, feed_fn = _get_input_fn(x, None, input_fn, None, batch_size)
    return self._infer_model(
        input_fn=input_fn,
        feed_fn=feed_fn,
        outputs=outputs,
        as_iterable=as_iterable,
        iterate_batches=iterate_batches)
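
  # Illustrative usage sketch (hypothetical names): with as_iterable=True the
  # result is a generator, so bound the input to keep it finite:
  #
  #   for prediction in est.predict(input_fn=my_predict_input_fn):
  #     print(prediction)
  #
  # Here `my_predict_input_fn` is assumed to raise `OutOfRangeError` after one
  # epoch (e.g. built with num_epochs=1).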

  def get_variable_value(self, name):
    """Returns value of the variable given by name.

    Args:
      name: string, name of the tensor.

    Returns:
      Numpy array - value of the tensor.
    """
    return load_variable(self.model_dir, name)

  def get_variable_names(self):
    """Returns list of all variable names in this model.

    Returns:
      List of names.
    """
    return [name for name, _ in list_variables(self.model_dir)]

  @property
  def model_dir(self):
    return self._model_dir

  @deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
  def export(
      self,
      export_dir,
      input_fn=export._default_input_fn,  # pylint: disable=protected-access
      input_feature_key=None,
      use_deprecated_input_fn=True,
      signature_fn=None,
      prediction_key=None,
      default_batch_size=1,
      exports_to_keep=None,
      checkpoint_path=None):
    """Exports inference graph into given dir.

    Args:
      export_dir: A string containing a directory to write the exported graph
        and checkpoints.
      input_fn: If `use_deprecated_input_fn` is true, then a function that given
        `Tensor` of `Example` strings, parses it into features that are then
        passed to the model. Otherwise, a function that takes no argument and
        returns a tuple of (features, labels), where features is a dict of
        string key to `Tensor` and labels is a `Tensor` that's currently not
        used (and so can be `None`).
      input_feature_key: Only used if `use_deprecated_input_fn` is false. String
        key into the features dict returned by `input_fn` that corresponds to
        the raw `Example` strings `Tensor` that the exported model will take as
        input. Can only be `None` if you're using a custom `signature_fn` that
        does not use the first arg (examples).
      use_deprecated_input_fn: Determines the signature format of `input_fn`.
      signature_fn: Function that returns a default signature and a named
        signature map, given `Tensor` of `Example` strings, `dict` of `Tensor`s
        for features and `Tensor` or `dict` of `Tensor`s for predictions.
      prediction_key: The key for a tensor in the `predictions` dict (output
        from the `model_fn`) to use as the `predictions` input to the
        `signature_fn`. Optional. If `None`, predictions will pass to
        `signature_fn` without filtering.
      default_batch_size: Default batch size of the `Example` placeholder.
      exports_to_keep: Number of exports to keep.
      checkpoint_path: the checkpoint path of the model to be exported. If it is
          `None` (which is default), will use the latest checkpoint in
          export_dir.

    Returns:
      The string path to the exported directory. NB: this functionality was
      added ca. 2016/09/25; clients that depend on the return value may need
      to handle the case where this function returns None because subclasses
      are not returning a value.
    """
    # pylint: disable=protected-access
    return export._export_estimator(
        estimator=self,
        export_dir=export_dir,
        signature_fn=signature_fn,
        prediction_key=prediction_key,
        input_fn=input_fn,
        input_feature_key=input_feature_key,
        use_deprecated_input_fn=use_deprecated_input_fn,
        default_batch_size=default_batch_size,
        exports_to_keep=exports_to_keep,
        checkpoint_path=checkpoint_path)

  # These are methods, not properties, so `abc.abstractmethod` is the correct
  # abstract marker.
  @abc.abstractmethod
  def _get_train_ops(self, features, labels):
    """Method that builds model graph and returns trainer ops.

    Expected to be overridden by sub-classes that require custom support.

    Args:
      features: `Tensor` or `dict` of `Tensor` objects.
      labels: `Tensor` or `dict` of `Tensor` objects.

    Returns:
      A `ModelFnOps` object.
    """
    pass

  @abc.abstractmethod
  def _get_predict_ops(self, features):
    """Method that builds model graph and returns prediction ops.

    Args:
      features: `Tensor` or `dict` of `Tensor` objects.

    Returns:
      A `ModelFnOps` object.
    """
    pass

  def _get_eval_ops(self, features, labels, metrics):
    """Method that builds model graph and returns evaluation ops.

    Expected to be overridden by sub-classes that require custom support.

    Args:
      features: `Tensor` or `dict` of `Tensor` objects.
      labels: `Tensor` or `dict` of `Tensor` objects.
      metrics: Dict of metrics to run. If None, the default metric functions
        are used; if {}, no metrics are used. Otherwise, `metrics` should map
        friendly names for the metric to a `MetricSpec` object defining which
        model outputs to evaluate against which labels with which metric
        function. Metric ops should support streaming, e.g., returning
        update_op and value tensors. See more details in
        `../../../../metrics/python/metrics/ops/streaming_metrics.py` and
        `../metric_spec.py`.

    Returns:
      A `ModelFnOps` object.
    """
    raise NotImplementedError('_get_eval_ops not implemented in BaseEstimator')

  @deprecated(
      '2016-09-23',
      'The signature of the input_fn accepted by export is changing to be '
      'consistent with what\'s used by tf.Learn Estimator\'s train/evaluate, '
      'which makes this function useless. This will be removed after the '
      'deprecation date.')
  def _get_feature_ops_from_example(self, examples_batch):
    """Returns feature parser for given example batch using features info.

    This function requires that `fit()` has been called.

    Args:
      examples_batch: batch of tf.Example

    Returns:
      features: `Tensor` or `dict` of `Tensor` objects.

    Raises:
      ValueError: If `_features_info` attribute is not available (usually
      because `fit()` has not been called).
    """
    if self._features_info is None:
      raise ValueError('Features information missing, was fit() ever called?')
    return tensor_signature.create_example_parser_from_signatures(
        self._features_info, examples_batch)

  def _check_inputs(self, features, labels):
    if self._features_info is not None:
      logging.debug('Given features: %s, required signatures: %s.',
                    str(features), str(self._features_info))
      if not tensor_signature.tensors_compatible(features, self._features_info):
        raise ValueError('Features are incompatible with given information. '
                         'Given features: %s, required signatures: %s.' %
                         (str(features), str(self._features_info)))
    else:
      self._features_info = tensor_signature.create_signatures(features)
      logging.debug('Setting feature info to %s.', str(self._features_info))
    if labels is not None:
      if self._labels_info is not None:
        logging.debug('Given labels: %s, required signatures: %s.', str(labels),
                      str(self._labels_info))
        if not tensor_signature.tensors_compatible(labels, self._labels_info):
          raise ValueError('Labels are incompatible with given information. '
                           'Given labels: %s, required signatures: %s.' %
                           (str(labels), str(self._labels_info)))
      else:
        self._labels_info = tensor_signature.create_signatures(labels)
        logging.debug('Setting labels info to %s', str(self._labels_info))

  def _extract_metric_update_ops(self, eval_dict):
    """Separate update operations from metric value operations."""
    update_ops = []
    value_ops = {}
    for name, metric_ops in six.iteritems(eval_dict):
      if isinstance(metric_ops, (list, tuple)):
        if len(metric_ops) == 2:
          value_ops[name] = metric_ops[0]
          update_ops.append(metric_ops[1])
        else:
          logging.warning(
              'Ignoring metric {}. It returned a list|tuple with len {}, '
              'expected 2'.format(name, len(metric_ops)))
          value_ops[name] = metric_ops
      else:
        value_ops[name] = metric_ops

    if update_ops:
      update_ops = control_flow_ops.group(*update_ops)
    else:
      update_ops = None

    return update_ops, value_ops

  def _evaluate_model(self,
                      input_fn,
                      steps,
                      feed_fn=None,
                      metrics=None,
                      name='',
                      checkpoint_path=None,
                      hooks=None,
                      log_progress=True):
    # TODO(wicke): Remove this once Model and associated code are gone.
    if (hasattr(self._config, 'execution_mode') and
        self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset')):
      return None, None

    # Check that model has been trained (if nothing has been set explicitly).
    if not checkpoint_path:
      latest_path = checkpoint_management.latest_checkpoint(self._model_dir)
      if not latest_path:
        raise NotFittedError(
            "Couldn't find trained model at %s." % self._model_dir)
      checkpoint_path = latest_path

    # Setup output directory.
    eval_dir = os.path.join(self._model_dir,
                            'eval' if not name else 'eval_' + name)

    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = training_util.create_global_step(g)
      features, labels = input_fn()
      self._check_inputs(features, labels)

      model_fn_results = self._get_eval_ops(features, labels, metrics)
      eval_dict = model_fn_results.eval_metric_ops

      update_op, eval_dict = self._extract_metric_update_ops(eval_dict)

      # We need to copy the hook array as we modify it, thus [:].
      hooks = hooks[:] if hooks else []
      if feed_fn:
        hooks.append(basic_session_run_hooks.FeedFnHook(feed_fn))
      if steps == 0:
        logging.warning('evaluation steps are 0. If `input_fn` does not raise '
                        '`OutOfRangeError`, the evaluation will never stop. '
                        'Use steps=None if intended.')
      if steps:
        hooks.append(
            evaluation.StopAfterNEvalsHook(steps, log_progress=log_progress))

      global_step_key = 'global_step'
      while global_step_key in eval_dict:
        global_step_key = '_' + global_step_key
      eval_dict[global_step_key] = global_step

      eval_results = evaluation.evaluate_once(
          checkpoint_path=checkpoint_path,
          master=self._config.evaluation_master,
          scaffold=model_fn_results.scaffold,
          eval_ops=update_op,
          final_ops=eval_dict,
          hooks=hooks,
          config=self._session_config)
      current_global_step = eval_results[global_step_key]

      _write_dict_to_summary(eval_dir, eval_results, current_global_step)

    return eval_results, current_global_step

  def _get_features_from_input_fn(self, input_fn):
    result = input_fn()
    if isinstance(result, (list, tuple)):
      return result[0]
    return result

  def _infer_model(self,
                   input_fn,
                   feed_fn=None,
                   outputs=None,
                   as_iterable=True,
                   iterate_batches=False):
    # Check that model has been trained.
    checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise NotFittedError(
          "Couldn't find trained model at %s." % self._model_dir)

    with ops.Graph().as_default() as g:
      random_seed.set_random_seed(self._config.tf_random_seed)
      training_util.create_global_step(g)
      features = self._get_features_from_input_fn(input_fn)
      infer_ops = self._get_predict_ops(features)
      predictions = self._filter_predictions(infer_ops.predictions, outputs)
      mon_sess = monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              checkpoint_filename_with_path=checkpoint_path,
              scaffold=infer_ops.scaffold,
              config=self._session_config))
      if not as_iterable:
        with mon_sess:
          if not mon_sess.should_stop():
            return mon_sess.run(predictions, feed_fn() if feed_fn else None)
      else:
        return self._predict_generator(mon_sess, predictions, feed_fn,
                                       iterate_batches)

  def _predict_generator(self, mon_sess, predictions, feed_fn, iterate_batches):
    with mon_sess:
      while not mon_sess.should_stop():
        preds = mon_sess.run(predictions, feed_fn() if feed_fn else None)
        if iterate_batches:
          yield preds
        elif not isinstance(predictions, dict):
          for pred in preds:
            yield pred
        else:
          first_tensor = list(preds.values())[0]
          if isinstance(first_tensor, sparse_tensor.SparseTensorValue):
            batch_length = first_tensor.dense_shape[0]
          else:
            batch_length = first_tensor.shape[0]
          for i in range(batch_length):
            yield {key: value[i] for key, value in six.iteritems(preds)}
        if self._is_input_constant(feed_fn, mon_sess.graph):
          return

  def _is_input_constant(self, feed_fn, graph):
    # If there are no queue_runners, the input `predictions` is a
    # constant, and we should stop after the first epoch.  If,
    # instead, there are queue_runners, eventually they should throw
    # an `OutOfRangeError`.
    if graph.get_collection(ops.GraphKeys.QUEUE_RUNNERS):
      return False
    # data_feeder uses feed_fn to generate `OutOfRangeError`.
    if feed_fn is not None:
      return False
    return True

  def _filter_predictions(self, predictions, outputs):
    if not outputs:
      return predictions
    if not isinstance(predictions, dict):
      raise ValueError(
          'outputs argument is not valid in case of non-dict predictions.')
    existing_keys = predictions.keys()
    predictions = {
        key: value
        for key, value in six.iteritems(predictions)
        if key in outputs
    }
    if not predictions:
      raise ValueError('Expected to run at least one output from %s, '
                       'provided %s.' % (existing_keys, outputs))
    return predictions

  def _train_model(self, input_fn, hooks):
    all_hooks = []
    self._graph = ops.Graph()
    with self._graph.as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step = training_util.create_global_step(g)
      features, labels = input_fn()
      self._check_inputs(features, labels)
      training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
      model_fn_ops = self._get_train_ops(features, labels)
      ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss)
      all_hooks.extend(hooks)
      all_hooks.extend([
          basic_session_run_hooks.NanTensorHook(model_fn_ops.loss),
          basic_session_run_hooks.LoggingTensorHook(
              {
                  'loss': model_fn_ops.loss,
                  'step': global_step
              },
              every_n_iter=100)
      ])

      scaffold = model_fn_ops.scaffold or monitored_session.Scaffold()
      if not (scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            saver.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      chief_hooks = []
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        saver_hook_exists = any(
            isinstance(h, basic_session_run_hooks.CheckpointSaverHook)
            for h in (all_hooks + model_fn_ops.training_hooks + chief_hooks +
                      model_fn_ops.training_chief_hooks)
        )
        if not saver_hook_exists:
          chief_hooks = [
              basic_session_run_hooks.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=scaffold)
          ]
      with monitored_session.MonitoredTrainingSession(
          master=self._config.master,
          is_chief=self._config.is_chief,
          checkpoint_dir=self._model_dir,
          scaffold=scaffold,
          hooks=all_hooks + model_fn_ops.training_hooks,
          chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks,
          save_checkpoint_secs=0,  # Saving is handled by a hook.
          save_summaries_steps=self._config.save_summary_steps,
          config=self._session_config) as mon_sess:
        loss = None
        while not mon_sess.should_stop():
          _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss])
      return loss


def _identity_feature_engineering_fn(features, labels):
  return features, labels
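

# Illustrative sketch (hypothetical, not part of the original module): a
# minimal `model_fn` satisfying the contract documented on `Estimator.__init__`
# below. The feature key 'x', the tensor shapes, and the optimizer choice are
# placeholder assumptions.
def _example_model_fn(features, labels, mode, params):
  """Hypothetical linear-regression model_fn returning `ModelFnOps`."""
  from tensorflow.python.ops import math_ops
  from tensorflow.python.training import gradient_descent

  w = variables.Variable([[0.0]], name='example_w')
  predictions = math_ops.matmul(features['x'], w)
  loss = None
  train_op = None
  if mode != model_fn_lib.ModeKeys.INFER:
    loss = math_ops.reduce_mean(math_ops.square(predictions - labels))
  if mode == model_fn_lib.ModeKeys.TRAIN:
    optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=params['learning_rate'])
    train_op = optimizer.minimize(
        loss, global_step=training_util.get_global_step())
  return model_fn_lib.ModelFnOps(
      mode=mode,
      predictions={'scores': predictions},
      loss=loss,
      train_op=train_op)


# Hypothetical usage:
#   est = Estimator(model_fn=_example_model_fn,
#                   params={'learning_rate': 0.1})
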
1102class Estimator(BaseEstimator):
1103  """Estimator class is the basic TensorFlow model trainer/evaluator.
1104
1105  THIS CLASS IS DEPRECATED. See
1106  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
1107  for general migration instructions.
1108  """
1109
1110  def __init__(self,
1111               model_fn=None,
1112               model_dir=None,
1113               config=None,
1114               params=None,
1115               feature_engineering_fn=None):
1116    """Constructs an `Estimator` instance.
1117
1118    Args:
1119      model_fn: Model function. Follows the signature:
1120        * Args:
1121          * `features`: single `Tensor` or `dict` of `Tensor`s
1122                 (depending on data passed to `fit`),
1123          * `labels`: `Tensor` or `dict` of `Tensor`s (for multi-head
1124                 models). If mode is `ModeKeys.INFER`, `labels=None` will be
1125                 passed. If the `model_fn`'s signature does not accept
1126                 `mode`, the `model_fn` must still be able to handle
1127                 `labels=None`.
1128          * `mode`: Optional. Specifies if this training, evaluation or
1129                 prediction. See `ModeKeys`.
1130          * `params`: Optional `dict` of hyperparameters.  Will receive what
1131                 is passed to Estimator in `params` parameter. This allows
1132                 to configure Estimators from hyper parameter tuning.
1133          * `config`: Optional configuration object. Will receive what is passed
1134                 to Estimator in `config` parameter, or the default `config`.
1135                 Allows updating things in your model_fn based on configuration
1136                 such as `num_ps_replicas`.
1137          * `model_dir`: Optional directory where model parameters, graph etc
1138                 are saved. Will receive what is passed to Estimator in
1139                 `model_dir` parameter, or the default `model_dir`. Allows
1140                 updating things in your model_fn that expect model_dir, such as
1141                 training hooks.
1142
1143        * Returns:
1144          `ModelFnOps`
1145
1146        Also supports a legacy signature which returns tuple of:
1147
1148          * predictions: `Tensor`, `SparseTensor` or dictionary of same.
1149              Can also be any type that is convertible to a `Tensor` or
1150              `SparseTensor`, or dictionary of same.
1151          * loss: Scalar loss `Tensor`.
1152          * train_op: Training update `Tensor` or `Operation`.
1153
1154        Supports next three signatures for the function:
1155
1156          * `(features, labels) -> (predictions, loss, train_op)`
1157          * `(features, labels, mode) -> (predictions, loss, train_op)`
1158          * `(features, labels, mode, params) -> (predictions, loss, train_op)`
1159          * `(features, labels, mode, params, config) ->
1160             (predictions, loss, train_op)`
1161          * `(features, labels, mode, params, config, model_dir) ->
1162             (predictions, loss, train_op)`
1163
1164      model_dir: Directory to save model parameters, graph and etc. This can
1165        also be used to load checkpoints from the directory into a estimator to
1166        continue training a previously saved model.
1167      config: Configuration object.
1168      params: `dict` of hyper parameters that will be passed into `model_fn`.
1169              Keys are names of parameters, values are basic python types.
1170      feature_engineering_fn: Feature engineering function. Takes features and
1171                              labels which are the output of `input_fn` and
1172                              returns features and labels which will be fed
1173                              into `model_fn`. Please check `model_fn` for
1174                              a definition of features and labels.
1175
1176    Raises:
1177      ValueError: parameters of `model_fn` don't match `params`.
1178    """
1179    super(Estimator, self).__init__(model_dir=model_dir, config=config)
1180    if model_fn is not None:
1181      # Check number of arguments of the given function matches requirements.
1182      model_fn_args = _model_fn_args(model_fn)
1183      if params is not None and 'params' not in model_fn_args:
1184        raise ValueError('Estimator\'s model_fn (%s) does not have a params '
1185                         'argument, but params (%s) were passed to the '
1186                         'Estimator\'s constructor.' % (model_fn, params))
1187      if params is None and 'params' in model_fn_args:
1188        logging.warning('Estimator\'s model_fn (%s) includes params '
1189                        'argument, but params are not passed to Estimator.',
1190                        model_fn)
1191    self._model_fn = model_fn
1192    self.params = params
1193    self._feature_engineering_fn = (
1194        feature_engineering_fn or _identity_feature_engineering_fn)
1195
1196  def _call_model_fn(self, features, labels, mode, metrics=None, config=None):
1197    """Calls model function with support of 2, 3 or 4 arguments.
1198
1199    Args:
1200      features: features dict.
1201      labels: labels dict.
1202      mode: ModeKeys
1203      metrics: Dict of metrics.
1204      config: RunConfig.
1205
1206    Returns:
1207      A `ModelFnOps` object. If model_fn returns a tuple, wraps them up in a
1208      `ModelFnOps` object.
1209
1210    Raises:
1211      ValueError: if model_fn returns invalid objects.
1212    """
1213    features, labels = self._feature_engineering_fn(features, labels)
1214    model_fn_args = _model_fn_args(self._model_fn)
1215    kwargs = {}
1216    if 'mode' in model_fn_args:
1217      kwargs['mode'] = mode
1218    if 'params' in model_fn_args:
1219      kwargs['params'] = self.params
1220    if 'config' in model_fn_args:
1221      if config:
1222        kwargs['config'] = config
1223      else:
1224        kwargs['config'] = self.config
1225    if 'model_dir' in model_fn_args:
1226      kwargs['model_dir'] = self.model_dir
1227    model_fn_results = self._model_fn(features, labels, **kwargs)
1228
1229    if isinstance(model_fn_results, model_fn_lib.ModelFnOps):
1230      model_fn_ops = model_fn_results
1231    else:
1232      # Here model_fn_results should be a tuple with 3 elements.
1233      if len(model_fn_results) != 3:
1234        raise ValueError('Unrecognized value returned by model_fn, '
1235                         'please return ModelFnOps.')
1236      model_fn_ops = model_fn_lib.ModelFnOps(
1237          mode=mode,
1238          predictions=model_fn_results[0],
1239          loss=model_fn_results[1],
1240          train_op=model_fn_results[2])
1241
1242    # Custom metrics should overwrite defaults.
1243    if metrics:
1244      model_fn_ops.eval_metric_ops.update(
1245          _make_metrics_ops(metrics, features, labels,
1246                            model_fn_ops.predictions))
1247
1248    return model_fn_ops
1249
1250  def _get_train_ops(self, features, labels):
1251    """Method that builds model graph and returns trainer ops.
1252
1253    Expected to be overridden by sub-classes that require custom support.
1254    This implementation uses `model_fn` passed as parameter to constructor to
1255    build model.
1256
1257    Args:
1258      features: `Tensor` or `dict` of `Tensor` objects.
1259      labels: `Tensor` or `dict` of `Tensor` objects.
1260
1261    Returns:
1262      `ModelFnOps` object.
1263    """
1264    return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
1265
1266  def _get_eval_ops(self, features, labels, metrics):
1267    """Method that builds model graph and returns evaluation ops.
1268
1269    Expected to be overridden by sub-classes that require custom support.
1270    This implementation uses `model_fn` passed as parameter to constructor to
1271    build model.
1272
1273    Args:
1274      features: `Tensor` or `dict` of `Tensor` objects.
1275      labels: `Tensor` or `dict` of `Tensor` objects.
1276      metrics: Dict of metrics to run. If None, the default metric functions
1277        are used; if {}, no metrics are used. Otherwise, `metrics` should map
1278        friendly names for the metric to a `MetricSpec` object defining which
1279        model outputs to evaluate against which labels with which metric
1280        function. Metric ops should support streaming, e.g., returning
1281        update_op and value tensors. See more details in
1282        `../../../../metrics/python/metrics/ops/streaming_metrics.py` and
1283        `../metric_spec.py`.
1284
1285    Returns:
1286      `ModelFnOps` object.
1287
1288    Raises:
1289      ValueError: if `metrics` don't match `labels`.
    """
    model_fn_ops = self._call_model_fn(features, labels,
                                       model_fn_lib.ModeKeys.EVAL, metrics)

    if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops:
      model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = (
          metrics_lib.mean(model_fn_ops.loss))
    return model_fn_ops

  def _get_predict_ops(self, features):
    """Method that builds model graph and returns prediction ops.

    Expected to be overridden by subclasses that require custom support.
    This implementation uses the `model_fn` passed to the constructor to
    build the model.

    Args:
      features: `Tensor` or `dict` of `Tensor` objects.

    Returns:
      `ModelFnOps` object.
    """
    labels = tensor_signature.create_placeholders_from_signatures(
        self._labels_info)
    return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER)

  def export_savedmodel(self,
                        export_dir_base,
                        serving_input_fn,
                        default_output_alternative_key=None,
                        assets_extra=None,
                        as_text=False,
                        checkpoint_path=None,
                        graph_rewrite_specs=(GraphRewriteSpec(
                            (tag_constants.SERVING,), ()),),
                        strip_default_attrs=False):
    # pylint: disable=line-too-long
    """Exports the inference graph as a SavedModel into the given directory.

    Args:
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      serving_input_fn: A function that takes no arguments and
        returns an `InputFnOps`.
      default_output_alternative_key: The name of the head to serve when none
        is specified.  Not needed for single-headed models.
      assets_extra: A dict specifying how to populate the assets.extra directory
        within the exported SavedModel.  Each key should give the destination
        path (including the filename) relative to the assets.extra directory.
        The corresponding value gives the full path of the source file to be
        copied.  For example, the simple case of copying a single file without
        renaming it is specified as
        `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
      as_text: Whether to write the SavedModel proto in text format.
      checkpoint_path: The checkpoint path to export.  If None (the default),
        the most recent checkpoint found within the model directory is chosen.
      graph_rewrite_specs: An iterable of `GraphRewriteSpec`.  Each element
        will produce a separate MetaGraphDef within the exported SavedModel,
        tagged and rewritten as specified.  Defaults to a single entry using
        the default serving tag ("serve") and no rewriting.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued
          Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).

    Returns:
      The string path to the exported directory.

    Raises:
      ValueError: If `serving_input_fn` is not provided, or if the first
        element of `graph_rewrite_specs` specifies any transforms.
      NotFittedError: If no trained model is found in the model directory.
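
    Example:
      A minimal usage sketch, assuming `input_fn_utils` refers to
      `tf.contrib.learn.utils.input_fn_utils`; the estimator instance,
      feature name, and shape below are hypothetical:

        def serving_input_fn():
          features = {'x': tf.placeholder(tf.float32, shape=[None, 4])}
          return input_fn_utils.InputFnOps(
              features=features, labels=None, default_inputs=features)

        export_dir = my_estimator.export_savedmodel(
            export_dir_base='/tmp/my_model_exports',
            serving_input_fn=serving_input_fn)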
    """
    # pylint: enable=line-too-long
    if serving_input_fn is None:
      raise ValueError('serving_input_fn must be defined.')

    if not checkpoint_path:
      # Locate the latest checkpoint
      checkpoint_path = checkpoint_management.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise NotFittedError(
          "Couldn't find trained model at %s." % self._model_dir)

    export_dir = saved_model_export_utils.get_timestamped_export_dir(
        export_dir_base)
    # We'll write the SavedModel to a temporary directory and then atomically
    # rename it at the end.  This helps to avoid corrupt / incomplete outputs,
    # which could otherwise occur if the job is preempted or otherwise fails
    # in the middle of SavedModel creation.
    temp_export_dir = saved_model_export_utils.get_temp_export_dir(export_dir)
    builder = saved_model_builder.SavedModelBuilder(temp_export_dir)

    # Build the base graph
    with ops.Graph().as_default() as g:
      training_util.create_global_step(g)

      # Call the serving_input_fn and collect the input alternatives.
      input_ops = serving_input_fn()
      input_alternatives, features = (
          saved_model_export_utils.get_input_alternatives(input_ops))

      # TODO(b/34388557) This is a stopgap, pending recording model provenance.
      # Record which features are expected at serving time.  It is assumed that
      # these are the features that were used in training.
      for feature_key in input_ops.features.keys():
        ops.add_to_collection(
            constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS, feature_key)

      # Call the model_fn and collect the output alternatives.
      model_fn_ops = self._call_model_fn(features, None,
                                         model_fn_lib.ModeKeys.INFER)
      output_alternatives, actual_default_output_alternative_key = (
          saved_model_export_utils.get_output_alternatives(
              model_fn_ops, default_output_alternative_key))

      init_op = control_flow_ops.group(variables.local_variables_initializer(),
                                       resources.initialize_resources(
                                           resources.shared_resources()),
                                       lookup_ops.tables_initializer())

      # Build the SignatureDefs from all pairs of input and output
      # alternatives.
      signature_def_map = saved_model_export_utils.build_all_signature_defs(
          input_alternatives, output_alternatives,
          actual_default_output_alternative_key)

      # Export the first MetaGraphDef with variables, assets, etc.
      with tf_session.Session('') as session:

        # pylint: disable=protected-access
        saveables = variables._all_saveable_objects()
        # pylint: enable=protected-access

        if (model_fn_ops.scaffold is not None and
            model_fn_ops.scaffold.saver is not None):
          saver_for_restore = model_fn_ops.scaffold.saver
        elif saveables:
          saver_for_restore = saver.Saver(saveables, sharded=True)
        else:
          raise ValueError('Cannot restore the checkpoint for export: the '
                           'graph contains no saveable objects and the '
                           'scaffold does not provide a Saver.')

        saver_for_restore.restore(session, checkpoint_path)

        # Perform the export
        if not graph_rewrite_specs or graph_rewrite_specs[0].transforms:
          raise ValueError('The first element of graph_rewrite_specs '
                           'must specify no transforms.')
        untransformed_tags = graph_rewrite_specs[0].tags

        builder.add_meta_graph_and_variables(
            session,
            untransformed_tags,
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS),
            main_op=init_op,
            strip_default_attrs=strip_default_attrs)

    # pylint: disable=protected-access
    base_meta_graph_def = builder._saved_model.meta_graphs[0]
    # pylint: enable=protected-access

    if graph_rewrite_specs[1:]:
      # Prepare the input_names and output_names needed for the
      # meta_graph_transform call below.
      input_names = [
          tensor.name
          for input_dict in input_alternatives.values()
          for tensor in input_dict.values()
      ]
      output_names = [
          tensor.name
          for output_alternative in output_alternatives.values()
          for tensor in output_alternative[1].values()
      ]

    # Write the additional MetaGraphDefs
    for graph_rewrite_spec in graph_rewrite_specs[1:]:

      # TODO(soergel) consider moving most of this to saved_model.builder_impl
      # as e.g. builder.add_rewritten_meta_graph(rewritten_graph_def, tags)

      transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
          base_meta_graph_def, input_names, output_names,
          graph_rewrite_spec.transforms, graph_rewrite_spec.tags)

      # pylint: disable=protected-access
      meta_graph_def = builder._saved_model.meta_graphs.add()
      # pylint: enable=protected-access
      meta_graph_def.CopyFrom(transformed_meta_graph_def)

    # Add the extra assets
    if assets_extra:
      assets_extra_path = os.path.join(
          compat.as_bytes(temp_export_dir), compat.as_bytes('assets.extra'))
      for dest_relative, source in assets_extra.items():
        dest_absolute = os.path.join(
            compat.as_bytes(assets_extra_path), compat.as_bytes(dest_relative))
        dest_path = os.path.dirname(dest_absolute)
        gfile.MakeDirs(dest_path)
        gfile.Copy(source, dest_absolute)

    builder.save(as_text)
    gfile.Rename(temp_export_dir, export_dir)
    return export_dir


# For the duration of the deprecation period, allow direct access to the
# x, y interface of Estimator through SKCompat.
# pylint: disable=protected-access
class SKCompat(sklearn.BaseEstimator):
  """Scikit-learn wrapper for a TensorFlow Learn Estimator.

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.
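
  Example:
    A minimal sketch; `my_tf_learn_estimator` and the numpy arrays are
    placeholders for a real estimator and data:

      est = SKCompat(my_tf_learn_estimator)
      est.fit(x_train, y_train, steps=100)
      scores = est.score(x_test, y_test)
      predictions = est.predict(x_test)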
  """

  @deprecated(None, 'Please switch to the Estimator interface.')
  def __init__(self, estimator):
    self._estimator = estimator

  def fit(self, x, y, batch_size=128, steps=None, max_steps=None,
          monitors=None):
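    """Trains the wrapped estimator on numpy arrays `x` and `y`.

    Builds a shuffled, indefinitely repeating input function over the arrays
    and delegates to the wrapped estimator's `fit`.

    Returns:
      `self`, for chaining.
    """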
    input_fn, feed_fn = _get_input_fn(
        x,
        y,
        input_fn=None,
        feed_fn=None,
        batch_size=batch_size,
        shuffle=True,
        epochs=None)
    all_monitors = []
    if feed_fn:
      all_monitors = [basic_session_run_hooks.FeedFnHook(feed_fn)]
    if monitors:
      all_monitors.extend(monitors)

    self._estimator.fit(
        input_fn=input_fn,
        steps=steps,
        max_steps=max_steps,
        monitors=all_monitors)
    return self

  def score(self, x, y, batch_size=128, steps=None, metrics=None, name=None):
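    """Evaluates the wrapped estimator on numpy arrays `x` and `y`.

    Runs a single unshuffled pass over the data and returns the dict of
    evaluation results, augmented with the corresponding `global_step`.
    """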
    input_fn, feed_fn = _get_input_fn(
        x,
        y,
        input_fn=None,
        feed_fn=None,
        batch_size=batch_size,
        shuffle=False,
        epochs=1)
    if metrics is not None and not isinstance(metrics, dict):
      raise ValueError('Metrics argument should be None or dict. '
                       'Got %s.' % metrics)
    eval_results, global_step = self._estimator._evaluate_model(
        input_fn=input_fn,
        feed_fn=feed_fn,
        steps=steps,
        metrics=metrics,
        name=name)
    if eval_results is not None:
      eval_results.update({'global_step': global_step})
    return eval_results

  def predict(self, x, batch_size=128, outputs=None):
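    """Computes predictions for numpy array `x` using the wrapped estimator.

    Runs a single unshuffled pass over the data and returns either a single
    numpy array of predictions or a dict of arrays, one per model output.
    """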
    input_fn, feed_fn = _get_input_fn(
        x,
        None,
        input_fn=None,
        feed_fn=None,
        batch_size=batch_size,
        shuffle=False,
        epochs=1)
    results = list(
        self._estimator._infer_model(
            input_fn=input_fn,
            feed_fn=feed_fn,
            outputs=outputs,
            as_iterable=True,
            iterate_batches=True))
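    # Merge the per-batch outputs back into single arrays (key-wise when the
    # model returns a dict) so callers receive one array per output.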
    if not isinstance(results[0], dict):
      return np.concatenate(results, axis=0)
    return {
        key: np.concatenate([output[key] for output in results], axis=0)
        for key in results[0]
    }