• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Export utilities (deprecated).
17
18This module and all its submodules are deprecated. See
19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
20for migration instructions.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27from tensorflow.contrib.framework import deprecated
28from tensorflow.contrib.session_bundle import exporter
29from tensorflow.contrib.session_bundle import gc
30from tensorflow.python.client import session as tf_session
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import lookup_ops
36from tensorflow.python.ops import variables
37from tensorflow.python.platform import tf_logging as logging
38from tensorflow.python.training import checkpoint_management
39from tensorflow.python.training import saver as tf_saver
40from tensorflow.python.training import training_util
41
42
@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _get_first_op_from_collection(collection_name):
  """Returns the first element of a graph collection, or `None`.

  Args:
    collection_name: Key of the collection to read via `ops.get_collection`.

  Returns:
    The first element stored under `collection_name`, or `None` when the
    lookup yields `None` or an empty collection.
  """
  collection = ops.get_collection(collection_name)
  # `None` and an empty list are both falsy, so one test covers both cases.
  return collection[0] if collection else None
51
52
@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _get_saver():
  """Lazily returns the graph's saver, creating one if needed.

  Reads the first entry of the `SAVERS` collection. If there is none and the
  graph has global variables, a new `Saver` is created and added to the
  `SAVERS` collection so later calls reuse it.

  Returns:
    A `Saver`, or `None` when no saver is registered and there are no global
    variables to save.
  """
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  # The helper already returns the first collection element, which is normally
  # a `Saver` itself. A `Saver` is not a sequence and must not be indexed; only
  # unwrap entries that were registered as a list/tuple of savers.
  if isinstance(saver, (list, tuple)):
    saver = saver[0] if saver else None
  if saver is None and variables.global_variables():
    saver = tf_saver.Saver()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver
66
67
@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
  """Exports graph via session_bundle, by creating a Session.

  Args:
    graph: The `Graph` to export; it is made the default graph while exporting.
    saver: `Saver` used to restore `checkpoint_path` and handed to the
      `Exporter`.
    checkpoint_path: Checkpoint whose variable values are restored before
      exporting.
    export_dir: Base directory passed to `Exporter.export`.
    default_graph_signature: Default signature for the exported graph.
    named_graph_signatures: Additional named signatures.
    exports_to_keep: Export garbage-collection filter forwarded to
      `Exporter.export`; may be `None`.

  Returns:
    Whatever `Exporter.export` returns (presumably the export path; confirm).
  """
  with graph.as_default():
    with tf_session.Session('') as session:
      # Variable values come from the checkpoint. Local-variable and table
      # initializers are not run here; they are bundled into the export's
      # `init_op` below (presumably run by the consumer when loading).
      saver.restore(session, checkpoint_path)

      export = exporter.Exporter(saver)
      export.init(
          init_op=control_flow_ops.group(
              variables.local_variables_initializer(),
              lookup_ops.tables_initializer()),
          default_graph_signature=default_graph_signature,
          named_graph_signatures=named_graph_signatures,
          assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))
      return export.export(export_dir, training_util.get_global_step(),
                           session, exports_to_keep=exports_to_keep)
89
90
@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def generic_signature_fn(examples, unused_features, predictions):
  """Builds a generic signature mapping `inputs` and outputs to tensors.

  Kept for backward compatibility with the default behavior of
  export_estimator.

  Args:
    examples: `Tensor` exposed under the `inputs` key.
    unused_features: `dict` of `Tensor`s; ignored.
    predictions: Either a single `Tensor` (exposed under `outputs`) or a
      `dict` of `Tensor`s whose entries are merged in after `inputs`.

  Returns:
    Tuple of the default generic signature and an empty dict of named
    signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  signature_tensors = {'inputs': examples}
  if isinstance(predictions, dict):
    # Later keys win, so a prediction named 'inputs' replaces the examples.
    signature_tensors.update(predictions)
  else:
    signature_tensors['outputs'] = predictions
  return exporter.generic_signature(signature_tensors), {}
121
122
@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def classification_signature_fn(examples, unused_features, predictions):
  """Builds a classification signature over predicted classes.

  Args:
    examples: `Tensor` used as the signature input.
    unused_features: `dict` of `Tensor`s; ignored.
    predictions: The classes `Tensor`, or a dict holding it under the
      'classes' key, as in {'classes': `Tensor`}.

  Returns:
    Tuple of the default classification signature and an empty dict of named
    signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  classes = (predictions['classes'] if isinstance(predictions, dict)
             else predictions)
  signature = exporter.classification_signature(examples,
                                                classes_tensor=classes)
  return signature, {}
152
153
@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def classification_signature_fn_with_prob(
    examples, unused_features, predictions):
  """Builds a classification signature over predicted probabilities.

  Args:
    examples: `Tensor` used as the signature input.
    unused_features: `dict` of `Tensor`s; ignored.
    predictions: `Tensor` of predicted probabilities, or a dict holding it
      under the 'probabilities' key, as in {'probabilities': `Tensor`}.

  Returns:
    Tuple of the default classification signature (with the probabilities as
    scores) and an empty dict of named signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  probabilities = predictions
  if isinstance(predictions, dict):
    probabilities = predictions['probabilities']
  signature = exporter.classification_signature(
      examples, scores_tensor=probabilities)
  return signature, {}
184
185
@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def regression_signature_fn(examples, unused_features, predictions):
  """Builds a regression signature from examples to predictions.

  Args:
    examples: `Tensor` used as the signature input.
    unused_features: `dict` of `Tensor`s; ignored.
    predictions: `Tensor` used as the signature output.

  Returns:
    Tuple of the default regression signature and an empty dict of named
    signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is not None:
    return (exporter.regression_signature(input_tensor=examples,
                                          output_tensor=predictions), {})
  raise ValueError('examples cannot be None when using this signature fn.')
210
211
@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def logistic_regression_signature_fn(examples, unused_features, predictions):
  """Creates logistic regression signature from given examples and predictions.

  Only the second column of the predicted probabilities is exported as the
  regression output.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` of shape [batch_size, 2] of predicted probabilities or
      dict that contains the probabilities tensor as in
      {'probabilities': `Tensor`}.

  Returns:
    Tuple of default regression signature and empty named signatures.

  Raises:
    ValueError: If examples is `None`, if the predictions do not have rank 2
      (including unknown rank), or if their statically known second dimension
      is not 2.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  if isinstance(predictions, dict):
    predictions_tensor = predictions['probabilities']
  else:
    predictions_tensor = predictions
  # predictions should have shape [batch_size, 2] where first column is P(Y=0|x)
  # while second column is P(Y=1|x). We are only interested in the second
  # column for inference.
  predictions_shape = predictions_tensor.get_shape()
  # `logging.fatal` only logs and does not stop execution, so invalid shapes
  # used to fall through to the slice below. Raise instead.
  predictions_rank = predictions_shape.ndims  # None when the rank is unknown.
  if predictions_rank != 2:
    raise ValueError(
        'Expected predictions to have rank 2, but received predictions with '
        'rank: {} and shape: {}'.format(predictions_rank, predictions_shape))
  # None when the second dimension is not statically known; skip the check.
  num_columns = predictions_shape.as_list()[1]
  if num_columns is not None and num_columns != 2:
    raise ValueError(
        'Expected predictions to have 2nd dimension: 2, but received '
        'predictions with 2nd dimension: {} and shape: {}. Did you mean to use '
        'regression_signature_fn or classification_signature_fn_with_prob '
        'instead?'.format(num_columns, predictions_shape))

  positive_predictions = predictions_tensor[:, 1]
  default_signature = exporter.regression_signature(
      input_tensor=examples, output_tensor=positive_predictions)
  return default_signature, {}
259
260
261# pylint: disable=protected-access
@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _default_input_fn(estimator, examples):
  """Parses `examples` using the estimator's own feature signatures.

  Args:
    estimator: Estimator exposing `_get_feature_ops_from_example`.
    examples: Input forwarded unchanged to the estimator (in this module, a
      string placeholder of serialized examples).

  Returns:
    Whatever `estimator._get_feature_ops_from_example(examples)` returns.
  """
  parse_examples = estimator._get_feature_ops_from_example
  return parse_examples(examples)
266
267
@deprecated('2016-09-23', 'Please use Estimator.export_savedmodel() instead.')
def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
  """Deprecated, please use Estimator.export_savedmodel().

  Exports the latest checkpoint of `estimator` using the deprecated
  `input_fn(estimator, examples)` convention.

  Args:
    estimator: Estimator to export.
    export_dir: Base directory the export is written under.
    signature_fn: Optional callable `(examples, features, predictions)`
      returning `(default_signature, named_graph_signatures)`.
    input_fn: Callable `(estimator, examples)` returning features.
    default_batch_size: Static batch size of the serialized-example
      placeholder.
    exports_to_keep: Optional number of export versions to keep.

  Returns:
    The value returned by the underlying export (presumably the export path).
  """
  # Propagate the export result; it used to be silently discarded.
  return _export_estimator(estimator=estimator,
                           export_dir=export_dir,
                           signature_fn=signature_fn,
                           input_fn=input_fn,
                           default_batch_size=default_batch_size,
                           exports_to_keep=exports_to_keep)
282
283
@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
  """Builds a serving graph for `estimator` and exports it via session_bundle.

  Args:
    estimator: Estimator to export. Its `_model_dir` and `_get_predict_ops`
      are used, plus `_create_signature_fn` when no `signature_fn` is given
      and the estimator provides one.
    export_dir: Base directory the export is written under.
    signature_fn: Optional callable `(examples, features, predictions)`
      returning `(default_signature, named_graph_signatures)`.
    input_fn: With `use_deprecated_input_fn=True`, called as
      `input_fn(estimator, examples)` on a string placeholder and must return
      features; falls back to `_default_input_fn` when falsy. Otherwise
      called with no arguments and must return a `(features, labels)` pair
      whose labels are ignored.
    default_batch_size: Static batch size of the `input_example_tensor`
      placeholder (deprecated input path only).
    exports_to_keep: If not `None`, passed to `gc.largest_export_versions` to
      build the export garbage-collection filter.
    input_feature_key: Optional key popped from the features dict and used as
      `examples` (non-deprecated input path only).
    use_deprecated_input_fn: Selects which `input_fn` convention is used.
    prediction_key: Optional key selecting one entry of the predictions dict.
    checkpoint_path: Checkpoint to export; defaults to the latest checkpoint in
      `estimator._model_dir`.

  Returns:
    The value returned by `_export_graph` (presumably the export path).

  Raises:
    ValueError: If `input_fn` is `None` while `use_deprecated_input_fn` is
      False, or if the input function yields neither features nor examples.
  """
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  # If checkpoint_path is specified, use the specified checkpoint path.
  checkpoint_path = (checkpoint_path or
                     checkpoint_management.latest_checkpoint(
                         estimator._model_dir))
  with ops.Graph().as_default() as g:
    training_util.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    predictions = estimator._get_predict_ops(features).predictions

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority.
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(
          examples, features, predictions)
    else:
      try:
        # Some estimators provide a signature function; an AttributeError
        # here is treated as "this estimator does not provide one".
        signature_fn = estimator._create_signature_fn()
      except AttributeError:
        logging.warning(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now.  To maintain this behavior, '
            'pass:\n'
            '  signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
      else:
        # Invoked outside the try so an AttributeError raised by a buggy
        # signature function surfaces instead of silently falling back.
        default_signature, named_graph_signatures = signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
361# pylint: enable=protected-access
362