1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for export tools.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import random 23import tempfile 24import numpy as np 25import six 26 27from tensorflow.contrib import learn 28from tensorflow.contrib.layers.python.layers import feature_column 29from tensorflow.contrib.learn.python.learn.utils import export 30from tensorflow.contrib.session_bundle import exporter 31from tensorflow.contrib.session_bundle import manifest_pb2 32from tensorflow.python.client import session 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import errors 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import random_ops 37from tensorflow.python.platform import gfile 38from tensorflow.python.platform import test 39from tensorflow.python.training import saver 40 41_X_KEY = 'my_x_key' 42 43_X_COLUMN = feature_column.real_valued_column(_X_KEY, dimension=1) 44 45 46def _training_input_fn(): 47 x = random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) 48 y = 2 * x + 3 49 return {_X_KEY: x}, y 50 51 52class ExportTest(test.TestCase): 53 54 def _get_default_signature(self, export_meta_filename): 55 """ Gets the default signature from the export.meta file. """ 56 with session.Session(): 57 save = saver.import_meta_graph(export_meta_filename) 58 meta_graph_def = save.export_meta_graph() 59 collection_def = meta_graph_def.collection_def 60 61 signatures_any = collection_def['serving_signatures'].any_list.value 62 self.assertEquals(len(signatures_any), 1) 63 signatures = manifest_pb2.Signatures() 64 signatures_any[0].Unpack(signatures) 65 default_signature = signatures.default_signature 66 return default_signature 67 68 def _assert_export(self, export_monitor, export_dir, expected_signature): 69 self.assertTrue(gfile.Exists(export_dir)) 70 # Only the written checkpoints are exported. 71 self.assertTrue( 72 saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')), 73 'Exported checkpoint expected but not found: %s' % os.path.join( 74 export_dir, '00000001', 'export')) 75 self.assertTrue( 76 saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')), 77 'Exported checkpoint expected but not found: %s' % os.path.join( 78 export_dir, '00000010', 'export')) 79 self.assertEquals( 80 six.b(os.path.join(export_dir, '00000010')), 81 export_monitor.last_export_dir) 82 # Validate the signature 83 signature = self._get_default_signature( 84 os.path.join(export_dir, '00000010', 'export.meta')) 85 self.assertTrue(signature.HasField(expected_signature)) 86 87 def testExportMonitor_EstimatorProvidesSignature(self): 88 random.seed(42) 89 x = np.random.rand(1000) 90 y = 2 * x + 3 91 cont_features = [feature_column.real_valued_column('', dimension=1)] 92 regressor = learn.LinearRegressor(feature_columns=cont_features) 93 export_dir = os.path.join(tempfile.mkdtemp(), 'export') 94 export_monitor = learn.monitors.ExportMonitor( 95 every_n_steps=1, export_dir=export_dir, exports_to_keep=2) 96 regressor.fit(x, y, steps=10, monitors=[export_monitor]) 97 self._assert_export(export_monitor, export_dir, 'regression_signature') 98 99 def testExportMonitor(self): 100 random.seed(42) 101 x = np.random.rand(1000) 102 y = 2 * x + 3 103 cont_features = [feature_column.real_valued_column('', dimension=1)] 104 export_dir = os.path.join(tempfile.mkdtemp(), 'export') 105 export_monitor = learn.monitors.ExportMonitor( 106 every_n_steps=1, 107 export_dir=export_dir, 108 exports_to_keep=2, 109 signature_fn=export.generic_signature_fn) 110 regressor = learn.LinearRegressor(feature_columns=cont_features) 111 regressor.fit(x, y, steps=10, monitors=[export_monitor]) 112 self._assert_export(export_monitor, export_dir, 'generic_signature') 113 114 def testExportMonitorInputFeatureKeyMissing(self): 115 random.seed(42) 116 117 def _serving_input_fn(): 118 return { 119 _X_KEY: 120 random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) 121 }, None 122 123 input_feature_key = 'my_example_key' 124 monitor = learn.monitors.ExportMonitor( 125 every_n_steps=1, 126 export_dir=os.path.join(tempfile.mkdtemp(), 'export'), 127 input_fn=_serving_input_fn, 128 input_feature_key=input_feature_key, 129 exports_to_keep=2, 130 signature_fn=export.generic_signature_fn) 131 regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN]) 132 with self.assertRaisesRegexp(KeyError, input_feature_key): 133 regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor]) 134 135 def testExportMonitorInputFeatureKeyNoneNoFeatures(self): 136 random.seed(42) 137 input_feature_key = 'my_example_key' 138 139 def _serving_input_fn(): 140 return {input_feature_key: None}, None 141 142 monitor = learn.monitors.ExportMonitor( 143 every_n_steps=1, 144 export_dir=os.path.join(tempfile.mkdtemp(), 'export'), 145 input_fn=_serving_input_fn, 146 input_feature_key=input_feature_key, 147 exports_to_keep=2, 148 signature_fn=export.generic_signature_fn) 149 regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN]) 150 with self.assertRaisesRegexp(ValueError, 151 'features or examples must be defined'): 152 regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor]) 153 154 def testExportMonitorInputFeatureKeyNone(self): 155 random.seed(42) 156 input_feature_key = 'my_example_key' 157 158 def _serving_input_fn(): 159 return { 160 input_feature_key: 161 None, 162 _X_KEY: 163 random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) 164 }, None 165 166 monitor = learn.monitors.ExportMonitor( 167 every_n_steps=1, 168 export_dir=os.path.join(tempfile.mkdtemp(), 'export'), 169 input_fn=_serving_input_fn, 170 input_feature_key=input_feature_key, 171 exports_to_keep=2, 172 signature_fn=export.generic_signature_fn) 173 regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN]) 174 with self.assertRaisesRegexp(ValueError, 'examples cannot be None'): 175 regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor]) 176 177 def testExportMonitorInputFeatureKeyNoFeatures(self): 178 random.seed(42) 179 input_feature_key = 'my_example_key' 180 181 def _serving_input_fn(): 182 return { 183 input_feature_key: 184 array_ops.placeholder(dtype=dtypes.string, shape=(1,)) 185 }, None 186 187 monitor = learn.monitors.ExportMonitor( 188 every_n_steps=1, 189 export_dir=os.path.join(tempfile.mkdtemp(), 'export'), 190 input_fn=_serving_input_fn, 191 input_feature_key=input_feature_key, 192 exports_to_keep=2, 193 signature_fn=export.generic_signature_fn) 194 regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN]) 195 with self.assertRaisesRegexp(KeyError, _X_KEY): 196 regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor]) 197 198 def testExportMonitorInputFeature(self): 199 random.seed(42) 200 input_feature_key = 'my_example_key' 201 202 def _serving_input_fn(): 203 return { 204 input_feature_key: 205 array_ops.placeholder(dtype=dtypes.string, shape=(1,)), 206 _X_KEY: 207 random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) 208 }, None 209 210 export_dir = os.path.join(tempfile.mkdtemp(), 'export') 211 monitor = learn.monitors.ExportMonitor( 212 every_n_steps=1, 213 export_dir=export_dir, 214 input_fn=_serving_input_fn, 215 input_feature_key=input_feature_key, 216 exports_to_keep=2, 217 signature_fn=export.generic_signature_fn) 218 regressor = learn.LinearRegressor(feature_columns=[_X_COLUMN]) 219 regressor.fit(input_fn=_training_input_fn, steps=10, monitors=[monitor]) 220 self._assert_export(monitor, export_dir, 'generic_signature') 221 222 def testExportMonitorRegressionSignature(self): 223 224 def _regression_signature(examples, unused_features, predictions): 225 signatures = {} 226 signatures['regression'] = ( 227 exporter.regression_signature(examples, predictions)) 228 return signatures['regression'], signatures 229 230 random.seed(42) 231 x = np.random.rand(1000) 232 y = 2 * x + 3 233 cont_features = [feature_column.real_valued_column('', dimension=1)] 234 regressor = learn.LinearRegressor(feature_columns=cont_features) 235 export_dir = os.path.join(tempfile.mkdtemp(), 'export') 236 export_monitor = learn.monitors.ExportMonitor( 237 every_n_steps=1, 238 export_dir=export_dir, 239 exports_to_keep=1, 240 signature_fn=_regression_signature) 241 regressor.fit(x, y, steps=10, monitors=[export_monitor]) 242 243 self.assertTrue(gfile.Exists(export_dir)) 244 with self.assertRaises(errors.NotFoundError): 245 saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export')) 246 self.assertTrue( 247 saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export'))) 248 # Validate the signature 249 signature = self._get_default_signature( 250 os.path.join(export_dir, '00000010', 'export.meta')) 251 self.assertTrue(signature.HasField('regression_signature')) 252 253 254if __name__ == '__main__': 255 test.main() 256