• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""ExportStrategy class represents different flavors of model export (deprecated).
16
17This module and all its submodules are deprecated. See
18[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
19for migration instructions.
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25
26import collections
27
28from tensorflow.python.util import tf_inspect
29from tensorflow.python.util.deprecation import deprecated
30
# Public API of this module: `from ... import *` exposes only ExportStrategy.
__all__ = ['ExportStrategy']
32
33
class ExportStrategy(
    collections.namedtuple('ExportStrategy',
                           ['name', 'export_fn', 'strip_default_attrs'])):
  """Describes one flavor of model export (deprecated).

  THIS CLASS IS DEPRECATED. See
  [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
  for general migration instructions.

  Instances are normally produced by an exporter-specific factory, for
  example `saved_model_export_utils.make_export_strategy()`.

  Attributes:
    name: Name of the directory, under the export base directory, into which
      exports of this kind are written.
    export_fn: Callable that writes one export. It is always called with an
      estimator and a destination path, and may additionally declare any of
      `checkpoint_path`, `eval_result` and `strip_default_attrs` as named
      parameters, which are then passed as keyword arguments (see `export`).
      It may run many times during continuous training or only once at the
      end of fixed-length training, and may itself decide to skip an export
      (based on the eval result, an internal timer, or any other criterion).
      Common signatures:

        * `(estimator, export_path) -> export_path`
        * `(estimator, export_path, checkpoint_path) -> export_path`
        * `(estimator, export_path, checkpoint_path, eval_result) ->
          export_path`
        * `(estimator, export_path, checkpoint_path, eval_result,
          strip_default_attrs) -> export_path`

      Declaring `eval_result` without also declaring `checkpoint_path` is
      rejected with a `ValueError` when `export` is called.
    strip_default_attrs: (Optional, default None) Boolean. When True, default
      attrs in the `GraphDef` are stripped on write, which is recommended for
      better forward compatibility of the resulting `SavedModel`.
  """

  @deprecated(None, 'Please switch to tf.estimator.train_and_evaluate, and use '
              'tf.estimator.Exporter.')
  def __new__(cls, name, export_fn, strip_default_attrs=None):
    # Overridden so that `strip_default_attrs` may be omitted (it becomes
    # None) and so that construction goes through the `deprecated` decorator.
    parent = super(ExportStrategy, cls)
    return parent.__new__(cls, name, export_fn, strip_default_attrs)

  def export(self,
             estimator,
             export_path,
             checkpoint_path=None,
             eval_result=None):
    """Runs this strategy's `export_fn` on `estimator`.

    Optional arguments are forwarded only when `export_fn` names them as
    parameters, so export functions written before `checkpoint_path`,
    `eval_result` or `strip_default_attrs` existed keep working unchanged.

    Args:
      estimator: The Estimator to export.
      export_path: String naming the directory to write the export into.
      checkpoint_path: Checkpoint to export. If None (the default), the
        strategy may locate a checkpoint on its own (e.g. the most recent).
      eval_result: Output of Estimator.evaluate on this checkpoint. Should be
        set only together with `checkpoint_path`; otherwise it is unclear
        which checkpoint the evaluation refers to.

    Returns:
      The value returned by `export_fn`, expected to be the string path of
      the exported directory.

    Raises:
      ValueError: If `export_fn` accepts `eval_result` but not
        `checkpoint_path`.
    """
    # Parameter names declared by export_fn. Only these names are matched, so
    # a catch-all `**kwargs` parameter does not receive the optional values.
    param_names = tf_inspect.getargspec(self.export_fn).args

    if 'eval_result' in param_names and 'checkpoint_path' not in param_names:
      raise ValueError('An export_fn accepting eval_result must also accept '
                       'checkpoint_path.')

    optional_args = (
        ('checkpoint_path', checkpoint_path),
        ('eval_result', eval_result),
        ('strip_default_attrs', self.strip_default_attrs),
    )
    forwarded = {arg_name: arg_value
                 for arg_name, arg_value in optional_args
                 if arg_name in param_names}
    return self.export_fn(estimator, export_path, **forwarded)
111