# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Export utilities (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.session_bundle import exporter
from tensorflow.contrib.session_bundle import gc
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util


@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _get_first_op_from_collection(collection_name):
  """Get first element from the collection.

  Args:
    collection_name: Key of the collection to read via `ops.get_collection`.

  Returns:
    The first element of the collection, or `None` if the collection is empty.
  """
  elements = ops.get_collection(collection_name)
  if elements:
    return elements[0]
  return None


@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _get_saver():
  """Lazy init and return saver.

  Returns the first saver registered in the `SAVERS` collection. If there is
  none and global variables exist, a new `Saver` is created and added to that
  collection so later calls reuse it.

  Returns:
    A `Saver`, or `None` when the collection is empty and there are no global
    variables.
  """
  # `_get_first_op_from_collection` already returns the first element of the
  # collection (the saver itself), so it must not be indexed a second time; a
  # `Saver` is not a sequence and `saver[0]` would raise.
  saver = _get_first_op_from_collection(ops.GraphKeys.SAVERS)
  if saver is None and variables.global_variables():
    saver = tf_saver.Saver()
    ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
  return saver


@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _export_graph(graph, saver, checkpoint_path, export_dir,
                  default_graph_signature, named_graph_signatures,
                  exports_to_keep):
  """Exports graph via session_bundle, by creating a Session.

  Args:
    graph: `Graph` to export; it is made the default graph during export.
    saver: `Saver` used to restore `checkpoint_path` and handed to the
      `Exporter`.
    checkpoint_path: Checkpoint restored into the session before exporting.
    export_dir: Base directory passed to `Exporter.export`.
    default_graph_signature: Default signature for the exported graph.
    named_graph_signatures: Named signatures for the exported graph.
    exports_to_keep: Garbage-collection argument forwarded to
      `Exporter.export` (may be `None`).

  Returns:
    The result of `Exporter.export`.
  """
  with graph.as_default():
    with tf_session.Session('') as session:
      # NOTE(review): these two calls only build initializer ops; they are not
      # run here. Kept as-is so the exported graph is unchanged.
      variables.local_variables_initializer()
      lookup_ops.tables_initializer()
      saver.restore(session, checkpoint_path)

      export = exporter.Exporter(saver)
      export.init(
          init_op=control_flow_ops.group(
              variables.local_variables_initializer(),
              lookup_ops.tables_initializer()),
          default_graph_signature=default_graph_signature,
          named_graph_signatures=named_graph_signatures,
          assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))
      # The export must happen while the session is still open.
      return export.export(export_dir, training_util.get_global_step(),
                           session, exports_to_keep=exports_to_keep)


@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def generic_signature_fn(examples, unused_features, predictions):
  """Creates generic signature from given examples and predictions.

  This is needed for backward compatibility with default behavior of
  export_estimator.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` or `dict` of `Tensor`s.

  Returns:
    Tuple of default signature and empty named signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  tensors = {'inputs': examples}
  if not isinstance(predictions, dict):
    predictions = {'outputs': predictions}
  tensors.update(predictions)
  default_signature = exporter.generic_signature(tensors)
  return default_signature, {}


@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def classification_signature_fn(examples, unused_features, predictions):
  """Creates classification signature from given examples and predictions.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` or dict of tensors that contains the classes tensor
      as in {'classes': `Tensor`}.

  Returns:
    Tuple of default classification signature and empty named signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  if isinstance(predictions, dict):
    default_signature = exporter.classification_signature(
        examples, classes_tensor=predictions['classes'])
  else:
    default_signature = exporter.classification_signature(
        examples, classes_tensor=predictions)
  return default_signature, {}


@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def classification_signature_fn_with_prob(
    examples, unused_features, predictions):
  """Classification signature from given examples and predicted probabilities.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` of predicted probabilities or dict that contains the
      probabilities tensor as in {'probabilities': `Tensor`}.

  Returns:
    Tuple of default classification signature and empty named signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  if isinstance(predictions, dict):
    default_signature = exporter.classification_signature(
        examples, scores_tensor=predictions['probabilities'])
  else:
    default_signature = exporter.classification_signature(
        examples, scores_tensor=predictions)
  return default_signature, {}


@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def regression_signature_fn(examples, unused_features, predictions):
  """Creates regression signature from given examples and predictions.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor`.

  Returns:
    Tuple of default regression signature and empty named signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  default_signature = exporter.regression_signature(
      input_tensor=examples, output_tensor=predictions)
  return default_signature, {}


@deprecated('2017-03-25',
            'signature_fns are deprecated. For canned Estimators they are no '
            'longer needed. For custom Estimators, please return '
            'output_alternatives from your model_fn via ModelFnOps.')
def logistic_regression_signature_fn(examples, unused_features, predictions):
  """Creates logistic regression signature from given examples and predictions.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` of shape [batch_size, 2] of predicted probabilities or
      dict that contains the probabilities tensor as in
      {'probabilities': `Tensor`}.

  Returns:
    Tuple of default regression signature and empty named signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  if isinstance(predictions, dict):
    predictions_tensor = predictions['probabilities']
  else:
    predictions_tensor = predictions
  # predictions should have shape [batch_size, 2] where first column is P(Y=0|x)
  # while second column is P(Y=1|x). We are only interested in the second
  # column for inference.
  predictions_shape = predictions_tensor.get_shape()
  predictions_rank = len(predictions_shape)
  # NOTE(review): the checks below only log via `logging.fatal`; nothing here
  # raises, so execution continues to the slice below. Confirm whether
  # `logging.fatal` is expected to abort.
  if predictions_rank != 2:
    logging.fatal(
        'Expected predictions to have rank 2, but received predictions with '
        'rank: {} and shape: {}'.format(predictions_rank, predictions_shape))
  if predictions_shape[1] != 2:
    logging.fatal(
        'Expected predictions to have 2nd dimension: 2, but received '
        'predictions with 2nd dimension: {} and shape: {}. Did you mean to use '
        'regression_signature_fn or classification_signature_fn_with_prob '
        'instead?'.format(predictions_shape[1], predictions_shape))

  positive_predictions = predictions_tensor[:, 1]
  default_signature = exporter.regression_signature(
      input_tensor=examples, output_tensor=positive_predictions)
  return default_signature, {}


# pylint: disable=protected-access
@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _default_input_fn(estimator, examples):
  """Creates default input parsing using Estimator's feature signatures."""
  return estimator._get_feature_ops_from_example(examples)


@deprecated('2016-09-23', 'Please use Estimator.export_savedmodel() instead.')
def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
  """Deprecated, please use Estimator.export_savedmodel().

  Returns:
    The result of the underlying export (previously discarded; callers that
    ignored the return value are unaffected).
  """
  return _export_estimator(estimator=estimator,
                           export_dir=export_dir,
                           signature_fn=signature_fn,
                           input_fn=input_fn,
                           default_batch_size=default_batch_size,
                           exports_to_keep=exports_to_keep)


@deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.')
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
  """Builds an inference graph for `estimator` and exports it.

  Args:
    estimator: Estimator whose `_get_predict_ops` provide the predictions.
    export_dir: Base directory passed on to `_export_graph`.
    signature_fn: Optional function
      `(examples, features, predictions) -> (default, named signatures)`.
      When falsy, the estimator's `_create_signature_fn()` is tried, falling
      back to `generic_signature_fn` if the estimator does not provide one.
    input_fn: With `use_deprecated_input_fn`, a function
      `(estimator, examples) -> features` (defaults to `_default_input_fn`);
      otherwise a no-argument function returning `(features, _)`.
    default_batch_size: Batch size of the string example placeholder used when
      `use_deprecated_input_fn` is True.
    exports_to_keep: Passed to `gc.largest_export_versions` (presumably the
      number of exports to keep), or `None`.
    input_feature_key: Optional key popped from `features` and used as the
      examples tensor (only with the non-deprecated `input_fn`).
    use_deprecated_input_fn: Selects which `input_fn` contract is used.
    prediction_key: Optional key selecting one entry of the predictions.
    checkpoint_path: Checkpoint to export; defaults to the latest checkpoint
      in the estimator's model directory.

  Returns:
    The result of `_export_graph`.

  Raises:
    ValueError: If `input_fn` is `None` while `use_deprecated_input_fn` is
      False, or if neither features nor examples are available.
  """
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  # If checkpoint_path is specified, use the specified checkpoint path.
  checkpoint_path = (checkpoint_path or
                     checkpoint_management.latest_checkpoint(
                         estimator._model_dir))
  with ops.Graph().as_default() as g:
    training_util.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    predictions = estimator._get_predict_ops(features).predictions

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      # Some estimators provide a signature function. Only the lookup/creation
      # of that function is guarded: an AttributeError raised *inside* the
      # signature function is a real bug and must not be masked by silently
      # exporting a generic signature instead.
      try:
        signature_fn = estimator._create_signature_fn()
      except AttributeError:
        logging.warning(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now. To maintain this behavior, '
            'pass:\n'
            ' signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
      else:
        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    # Still inside `g` as default graph: `_get_saver` inspects that graph's
    # collections and global variables.
    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
# pylint: enable=protected-access