• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for export tools."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import random
23import tempfile
24import numpy as np
25import six
26
27from tensorflow.contrib import learn
28from tensorflow.contrib.layers.python.layers import feature_column
29from tensorflow.contrib.learn.python.learn.utils import export
30from tensorflow.contrib.session_bundle import exporter
31from tensorflow.contrib.session_bundle import manifest_pb2
32from tensorflow.python.client import session
33from tensorflow.python.framework import dtypes
34from tensorflow.python.framework import errors
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import random_ops
37from tensorflow.python.platform import gfile
38from tensorflow.python.platform import test
39from tensorflow.python.training import saver
40
# Feature key under which the synthetic input functions expose `x`.
_X_KEY = 'my_x_key'

# Real-valued column of dimension 1 that reads the `_X_KEY` feature.
_X_COLUMN = feature_column.real_valued_column(_X_KEY, dimension=1)
44
45
def _training_input_fn():
  """Input function producing a random example labelled with y = 2 * x + 3.

  Returns:
    A `(features, labels)` tuple. `features` maps `_X_KEY` to a shape-(1,)
    tensor drawn uniformly from [0, 1000); `labels` is `2 * x + 3`.
  """
  features_x = random_ops.random_uniform(
      shape=(1,), minval=0.0, maxval=1000.0)
  labels = 2 * features_x + 3
  features = {_X_KEY: features_x}
  return features, labels
50
51
class ExportTest(test.TestCase):
  """Tests for exports written by `learn.monitors.ExportMonitor`."""

  @staticmethod
  def _make_regression_data():
    """Returns reproducible `(x, y)` NumPy arrays with y = 2 * x + 3."""
    # The data is drawn from NumPy's RNG, so that is the generator to seed;
    # seeding Python's `random` module has no effect on `np.random.rand`.
    np.random.seed(42)
    x = np.random.rand(1000)
    return x, 2 * x + 3

  def _get_default_signature(self, export_meta_filename):
    """Returns the default signature stored in an `export.meta` file.

    Args:
      export_meta_filename: Path to the exported meta graph file.

    Returns:
      The `default_signature` of the single `manifest_pb2.Signatures` proto
      packed into the meta graph's 'serving_signatures' collection.
    """
    with session.Session():
      save = saver.import_meta_graph(export_meta_filename)
      meta_graph_def = save.export_meta_graph()
      collection_def = meta_graph_def.collection_def

      signatures_any = collection_def['serving_signatures'].any_list.value
      # Exactly one packed `Signatures` proto is expected in the collection.
      self.assertEqual(len(signatures_any), 1)
      signatures = manifest_pb2.Signatures()
      signatures_any[0].Unpack(signatures)
      return signatures.default_signature

  def _assert_export(self, export_monitor, export_dir, expected_signature):
    """Checks the exports that `export_monitor` wrote under `export_dir`.

    Verifies that export versions 00000001 and 00000010 contain a checkpoint,
    that the monitor reports 00000010 as its last export, and that the
    default signature of that export has `expected_signature` set.

    Args:
      export_monitor: The `ExportMonitor` that was passed to `fit`.
      export_dir: Base export directory given to the monitor.
      expected_signature: Name of the signature field expected to be set,
        e.g. 'regression_signature' or 'generic_signature'.
    """
    self.assertTrue(gfile.Exists(export_dir))
    # Only the written checkpoints are exported.
    for version in ('00000001', '00000010'):
      export_prefix = os.path.join(export_dir, version, 'export')
      self.assertTrue(
          saver.checkpoint_exists(export_prefix),
          'Exported checkpoint expected but not found: %s' % export_prefix)
    last_export_dir = os.path.join(export_dir, '00000010')
    self.assertEqual(six.b(last_export_dir), export_monitor.last_export_dir)
    # Validate the signature
    signature = self._get_default_signature(
        os.path.join(last_export_dir, 'export.meta'))
    self.assertTrue(signature.HasField(expected_signature))

  def testExportMonitor_EstimatorProvidesSignature(self):
    x, y = self._make_regression_data()
    cont_features = [feature_column.real_valued_column('', dimension=1)]
    regressor = learn.LinearRegressor(feature_columns=cont_features)
    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
    export_monitor = learn.monitors.ExportMonitor(
        every_n_steps=1, export_dir=export_dir, exports_to_keep=2)
    regressor.fit(x, y, steps=10, monitors=[export_monitor])
    self._assert_export(export_monitor, export_dir, 'regression_signature')

  def testExportMonitor(self):
    x, y = self._make_regression_data()
    cont_features = [feature_column.real_valued_column('', dimension=1)]
    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
    export_monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=export_dir,
        exports_to_keep=2,
        signature_fn=export.generic_signature_fn)
    regressor = learn.LinearRegressor(feature_columns=cont_features)
    regressor.fit(x, y, steps=10, monitors=[export_monitor])
    self._assert_export(export_monitor, export_dir, 'generic_signature')

  def testExportMonitorInputFeatureKeyMissing(self):
    random.seed(42)

    def _serving_input_fn():
      return {
          _X_KEY:
              random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0)
      }, None

    input_feature_key = 'my_example_key'
    monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
        input_fn=_serving_input_fn,
        input_feature_key=input_feature_key,
        exports_to_keep=2,
        signature_fn=export.generic_signature_fn)
    regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN])
    with six.assertRaisesRegex(self, KeyError, input_feature_key):
      regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor])

  def testExportMonitorInputFeatureKeyNoneNoFeatures(self):
    random.seed(42)
    input_feature_key = 'my_example_key'

    def _serving_input_fn():
      return {input_feature_key: None}, None

    monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
        input_fn=_serving_input_fn,
        input_feature_key=input_feature_key,
        exports_to_keep=2,
        signature_fn=export.generic_signature_fn)
    regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN])
    with six.assertRaisesRegex(self, ValueError,
                               'features or examples must be defined'):
      regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor])

  def testExportMonitorInputFeatureKeyNone(self):
    random.seed(42)
    input_feature_key = 'my_example_key'

    def _serving_input_fn():
      return {
          input_feature_key:
              None,
          _X_KEY:
              random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0)
      }, None

    monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
        input_fn=_serving_input_fn,
        input_feature_key=input_feature_key,
        exports_to_keep=2,
        signature_fn=export.generic_signature_fn)
    regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN])
    with six.assertRaisesRegex(self, ValueError, 'examples cannot be None'):
      regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor])

  def testExportMonitorInputFeatureKeyNoFeatures(self):
    random.seed(42)
    input_feature_key = 'my_example_key'

    def _serving_input_fn():
      return {
          input_feature_key:
              array_ops.placeholder(dtype=dtypes.string, shape=(1,))
      }, None

    monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
        input_fn=_serving_input_fn,
        input_feature_key=input_feature_key,
        exports_to_keep=2,
        signature_fn=export.generic_signature_fn)
    regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN])
    with six.assertRaisesRegex(self, KeyError, _X_KEY):
      regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor])

  def testExportMonitorInputFeature(self):
    random.seed(42)
    input_feature_key = 'my_example_key'

    def _serving_input_fn():
      return {
          input_feature_key:
              array_ops.placeholder(dtype=dtypes.string, shape=(1,)),
          _X_KEY:
              random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0)
      }, None

    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
    monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=export_dir,
        input_fn=_serving_input_fn,
        input_feature_key=input_feature_key,
        exports_to_keep=2,
        signature_fn=export.generic_signature_fn)
    regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN])
    regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor])
    self._assert_export(monitor, export_dir, 'generic_signature')

  def testExportMonitorRegressionSignature(self):

    def _regression_signature(examples, unused_features, predictions):
      # Use the regression signature both as the default and as the only
      # named signature ('regression').
      default_signature = exporter.regression_signature(examples, predictions)
      return default_signature, {'regression': default_signature}

    x, y = self._make_regression_data()
    cont_features = [feature_column.real_valued_column('', dimension=1)]
    regressor = learn.LinearRegressor(feature_columns=cont_features)
    export_dir = os.path.join(tempfile.mkdtemp(), 'export')
    export_monitor = learn.monitors.ExportMonitor(
        every_n_steps=1,
        export_dir=export_dir,
        exports_to_keep=1,
        signature_fn=_regression_signature)
    regressor.fit(x, y, steps=10, monitors=[export_monitor])

    self.assertTrue(gfile.Exists(export_dir))
    # NOTE(review): version 00000000 appears never to be written (the other
    # tests expect the first export at 00000001), so this only checks that a
    # missing version raises. Asserting that exports_to_keep=1 removed
    # 00000001 may be what was intended; confirm before changing.
    with self.assertRaises(errors.NotFoundError):
      saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export'))
    self.assertTrue(
        saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
    # Validate the signature
    signature = self._get_default_signature(
        os.path.join(export_dir, '00000010', 'export.meta'))
    self.assertTrue(signature.HasField('regression_signature'))
252
253
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner for the cases in this module.
  test.main()
256