# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loader implementation for SavedModel with hermetic, language-neutral exports.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from google.protobuf import message
from google.protobuf import text_format

from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

# API label for SavedModel metrics.
_LOADER_LABEL = "loader"


def parse_saved_model_with_debug_info(export_dir):
  """Reads the SavedModel as well as the graph debug info.

  Args:
    export_dir: Directory containing the SavedModel and GraphDebugInfo files.

  Returns:
    `SavedModel` and `GraphDebugInfo` protocol buffers.

  Raises:
    IOError: If the saved model file does not exist, or cannot be successfully
      parsed. A missing graph debug info file is fine.
  """
  saved_model = _parse_saved_model(export_dir)

  debug_info_path = os.path.join(
      saved_model_utils.get_debug_dir(export_dir),
      constants.DEBUG_INFO_FILENAME_PB)
  debug_info = graph_debug_info_pb2.GraphDebugInfo()
  if file_io.file_exists(debug_info_path):
    with file_io.FileIO(debug_info_path, "rb") as debug_file:
      try:
        debug_info.ParseFromString(debug_file.read())
      except message.DecodeError as e:
        raise IOError("Cannot parse file %s: %s." % (debug_info_path, str(e)))

  return (saved_model, debug_info)


@tf_export("__internal__.saved_model.parse_saved_model", v1=[])
def parse_saved_model(export_dir):
  """Reads the saved_model.pb or saved_model.pbtxt file containing `SavedModel`.

  Args:
    export_dir: String or path-like object; path to the directory containing
      the SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If the file does not exist, or cannot be successfully parsed.
  """
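  # Note: a SavedModel directory may hold either a binary saved_model.pb or a
  # text-format saved_model.pbtxt file; when both are present, the binary
  # protocol buffer takes precedence because it is checked first below.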
  # Build the path to the SavedModel in pbtxt format.
  path_to_pbtxt = os.path.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  # Build the path to the SavedModel in pb format.
  path_to_pb = os.path.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  # Parse the SavedModel protocol buffer.
  saved_model = saved_model_pb2.SavedModel()
  if file_io.file_exists(path_to_pb):
    with file_io.FileIO(path_to_pb, "rb") as f:
      file_content = f.read()
    try:
      saved_model.ParseFromString(file_content)
      return saved_model
    except message.DecodeError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
  elif file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as f:
      file_content = f.read()
    try:
      text_format.Merge(file_content.decode("utf-8"), saved_model)
      return saved_model
    except text_format.ParseError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
  else:
    raise IOError(
        "SavedModel file does not exist at: %s%s{%s|%s}" %
        (export_dir, os.path.sep, constants.SAVED_MODEL_FILENAME_PBTXT,
         constants.SAVED_MODEL_FILENAME_PB))


# TODO(b/120594573): Make this symbol also available as private, so that
# tensorflow_transform and tensorflow_estimator do not break.
_parse_saved_model = parse_saved_model


def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Args:
    export_dir: Directory where the SavedModel is located.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
      '/' to all returned asset tensor names.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. The
    value in the map corresponds to the absolute path of the asset file.
  """
  # Collection-def that may contain the assets key.
  collection_def = meta_graph_def_to_load.collection_def

  asset_tensor_dict = {}
  asset_protos = []

  if meta_graph_def_to_load.asset_file_def:
    asset_protos = meta_graph_def_to_load.asset_file_def
  elif constants.ASSETS_KEY in collection_def:
    assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value
    for asset_any_proto in assets_any_proto:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_protos.append(asset_proto)

  # Location of the assets for SavedModel.
  assets_directory = os.path.join(
      compat.as_bytes(export_dir), compat.as_bytes(constants.ASSETS_DIRECTORY))
  # Process each asset and add it to the asset tensor dictionary.
  for asset_proto in asset_protos:
    tensor_name = asset_proto.tensor_info.name
    if import_scope:
      tensor_name = "%s/%s" % (import_scope, tensor_name)
    asset_tensor_dict[tensor_name] = os.path.join(
        compat.as_bytes(assets_directory),
        compat.as_bytes(asset_proto.filename))

  return asset_tensor_dict


def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Gets the main op tensor, if one exists.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: name of the collection to check; should be one of MAIN_OP_KEY
      or the deprecated LEGACY_INIT_OP_KEY.

  Returns:
    The main op tensor, if it exists, and `None` otherwise.

  Raises:
    RuntimeError: If the collection def corresponding to the main op key
      contains anything other than exactly one tensor.
  """
  # TODO(kathywu): Rename this method to _get_op_from_collection when
  # dependency from SavedModelEstimator is removed.
  collection_def = meta_graph_def_to_load.collection_def
  init_op = None
  if init_op_key in collection_def:
    init_op_list = collection_def[init_op_key].node_list.value
    if len(init_op_list) != 1:
      raise RuntimeError("Expected exactly one SavedModel init op. "
                         "Found: {}".format(init_op_list))
    init_op = ops.get_collection(init_op_key)[0]
  return init_op


def _get_op_from_collection(meta_graph_def, op_key):
  return _get_main_op_tensor(meta_graph_def, op_key)


def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
  """Retrieve op stored in the imported meta graph's signature def."""
  if op_signature_key in meta_graph_def.signature_def:
    return signature_def_utils.load_op_from_signature_def(
        meta_graph_def.signature_def[op_signature_key], op_signature_key,
        import_scope)
  else:
    return None


def get_init_op(meta_graph_def, import_scope=None):
  return (_get_op_from_signature_def(
      meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope) or
          _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY) or
          _get_op_from_collection(meta_graph_def, constants.LEGACY_INIT_OP_KEY))


def get_train_op(meta_graph_def, import_scope=None):
  train_op = _get_op_from_signature_def(
      meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
  if train_op is None:
    train_op = _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
  return train_op


@tf_export(v1=[
    "saved_model.contains_saved_model",
    "saved_model.maybe_saved_model_directory",
    "saved_model.loader.maybe_saved_model_directory"
])
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Note that the method does not load any data by itself. If the method returns
  `false`, the export directory definitely does not contain a SavedModel. If
  the method returns `true`, the export directory may contain a SavedModel but
  provides no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
      '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  txt_path = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
  pb_path = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB)
  return file_io.file_exists(txt_path) or file_io.file_exists(pb_path)


@tf_export("saved_model.contains_saved_model", v1=[])
def contains_saved_model(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Note that the method does not load any data by itself. If the method returns
  `false`, the export directory definitely does not contain a SavedModel. If
  the method returns `true`, the export directory may contain a SavedModel but
  provides no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
      '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  return maybe_saved_model_directory(export_dir)


@tf_export(v1=["saved_model.load", "saved_model.loader.load"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.loader.load or "
    "tf.compat.v1.saved_model.load. There will be a new function for "
    "importing SavedModels in TensorFlow 2.0.")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These
      should correspond to the tags used when saving the variables using the
      SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
      to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
      followed by '/' to all loaded tensor names. This scope is applied to
      tensor instances loaded into the passed session, but it is *not* written
      through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: If the MetaGraphDef associated with the tags cannot be found.

  @compatibility(TF2)

  `tf.compat.v1.saved_model.load` or `tf.compat.v1.saved_model.loader.load` is
  not compatible with eager execution. Please use `tf.saved_model.load` instead
  to load your model. You can refer to the [SavedModel guide]
  (https://www.tensorflow.org/guide/saved_model) for more information as well
  as "Importing SavedModels from TensorFlow 1.x" in the [`tf.saved_model.load`]
  (https://www.tensorflow.org/api_docs/python/tf/saved_model/load) docstring.

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                        |
  | :-------------------- | :-------------- | :-------------------------- |
  | `sess`                | Not supported   | -                           |
  | `tags`                | `tags`          | -                           |
  | `export_dir`          | `export_dir`    | -                           |
  | `import_scope`        | Not supported   | Name scopes are not needed.
  :                       :                 : By default, variables are   :
  :                       :                 : associated with the loaded  :
  :                       :                 : object and function names   :
  :                       :                 : are deduped.                :
  | `saver_kwargs`        | Not supported   | -                           |

  #### Before & After Usage Example

  Before:

  ```
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(sess, ["foo-tag"], export_dir)
  ```

  After:

  ```
  model = tf.saved_model.load(export_dir, tags=["foo-tag"])
  ```
  @end_compatibility
  """
  loader = SavedModelLoader(export_dir)
  return loader.load(sess, tags, import_scope, **saver_kwargs)


class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`."""

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.
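
    A minimal usage sketch (the export directory path here is hypothetical):

    ```
    loader = SavedModelLoader("/tmp/exported_model")  # hypothetical path
    print(loader.saved_model.saved_model_schema_version)
    ```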

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.
    """
    self._export_dir = export_dir
    self._variables_path = saved_model_utils.get_variables_path(export_dir)
    self._saved_model = parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Args:
      tags: A list or set of string tags that identify the MetaGraphDef.

    Returns:
      MetaGraphDef with the same tags.

    Raises:
      RuntimeError: if no MetaGraphDef was found with the associated tags.
    """
    found_match = False
    available_tags = []
    for meta_graph_def in self._saved_model.meta_graphs:
      available_tags.append(set(meta_graph_def.meta_info_def.tags))
      if set(meta_graph_def.meta_info_def.tags) == set(tags):
        meta_graph_def_to_load = meta_graph_def
        found_match = True
        break

    if not found_match:
      raise RuntimeError(
          "MetaGraphDef associated with tags " + str(tags).strip("[]") +
          " could not be found in SavedModel. To inspect available tag-sets in"
          " the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
          "\navailable_tags: " + str(available_tags))
    return meta_graph_def_to_load

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not*
        written through to the static `MetaGraphDef` protocol buffer that is
        returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      saver: a tf.compat.v1.train.Saver object. Can be None if there are no
        variables in the graph. This may be the saver returned by the
        load_graph() function, or a default `tf.compat.v1.train.Saver()`.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not*
        written through to the static `MetaGraphDef` protocol buffer that is
        returned.

    Raises:
      ValueError: if no saver was passed to the saver argument, and there are
        variables in the graph.
    """
    with sess.graph.as_default():
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "SavedModelLoader.restore_variables. Since there are variables in "
            "the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not*
        written through to the static `MetaGraphDef` protocol buffer that is
        returned.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any.
      asset_tensors_dictionary = get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      init_op = get_init_op(meta_graph_def, import_scope)
      if init_op is not None:
        sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not*
        written through to the static `MetaGraphDef` protocol buffer that is
        returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetaGraphDef` proto of the graph that was loaded.
    """
    saved_model_proto = parse_saved_model(self._export_dir)
    metrics.IncrementReadApi(_LOADER_LABEL)

    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
      meta_graph_def = self.get_meta_graph_def_from_tags(tags)

    if (len(saved_model_proto.meta_graphs) == 1 and
        saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
      metrics.IncrementRead(write_version="2")
    else:
      metrics.IncrementRead(write_version="1")

    return meta_graph_def
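
# An illustrative sketch of driving the three loading phases above manually
# (assumes `import tensorflow as tf` in the calling code; the export path and
# the "serve" tag below are hypothetical):
#
#   loader = SavedModelLoader("/tmp/exported_model")
#   with tf.compat.v1.Session(graph=tf.Graph()) as sess:
#     saver, _ = loader.load_graph(sess.graph, ["serve"])
#     loader.restore_variables(sess, saver)
#     loader.run_init_ops(sess, ["serve"])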