# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loader implementation for SavedModel with hermetic, language-neutral exports.
"""

import os
import sys

from google.protobuf import message
from google.protobuf import text_format

from tensorflow.core.protobuf import graph_debug_info_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

# API label for SavedModel metrics (passed to `metrics.IncrementReadApi`).
_LOADER_LABEL = "loader"


def parse_saved_model_with_debug_info(export_dir):
  """Reads the savedmodel as well as the graph debug info.

  Args:
    export_dir: Directory containing the SavedModel and GraphDebugInfo files.

  Returns:
    A `(SavedModel, GraphDebugInfo)` tuple of protocol buffers. When no debug
    info file is present, the returned `GraphDebugInfo` is empty.

  Raises:
    IOError: If the saved model file does not exist, or cannot be successfully
      parsed, or if the debug info file exists but cannot be parsed. A missing
      graph debug info file is fine.
  """
  saved_model = parse_saved_model(export_dir)

  debug_info_path = file_io.join(
      saved_model_utils.get_debug_dir(export_dir),
      constants.DEBUG_INFO_FILENAME_PB)
  debug_info = graph_debug_info_pb2.GraphDebugInfo()
  if file_io.file_exists(debug_info_path):
    with file_io.FileIO(debug_info_path, "rb") as debug_file:
      file_content = debug_file.read()
    try:
      debug_info.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError(f"Cannot parse file {debug_info_path}: {e}.") from e

  return (saved_model, debug_info)


@tf_export("__internal__.saved_model.parse_saved_model", v1=[])
def parse_saved_model(export_dir):
  """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`.

  The binary file (`constants.SAVED_MODEL_FILENAME_PB`) takes precedence; the
  text-format file (`constants.SAVED_MODEL_FILENAME_PBTXT`) is only read when
  the binary file does not exist.

  Args:
    export_dir: String or Pathlike, path to the directory containing the
      SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If the file does not exist, or cannot be successfully parsed
      (including a text-format file that is not valid UTF-8).
  """
  # Normalize once: `export_dir` may be str, bytes or an os.PathLike object.
  export_dir_bytes = compat.as_bytes(compat.path_to_str(export_dir))
  path_to_pbtxt = file_io.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  path_to_pb = file_io.join(
      export_dir_bytes, compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  saved_model = saved_model_pb2.SavedModel()
  if file_io.file_exists(path_to_pb):
    with file_io.FileIO(path_to_pb, "rb") as f:
      file_content = f.read()
    try:
      saved_model.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError(f"Cannot parse file {path_to_pb}: {str(e)}.") from e
    return saved_model

  if file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as f:
      file_content = f.read()
    try:
      # Decoding is inside the `try` so that a non-UTF-8 file surfaces as the
      # documented IOError instead of a bare UnicodeDecodeError.
      text_format.Merge(file_content.decode("utf-8"), saved_model)
    except (UnicodeDecodeError, text_format.ParseError) as e:
      raise IOError(f"Cannot parse file {path_to_pbtxt}: {str(e)}.") from e
    return saved_model

  raise IOError(
      f"SavedModel file does not exist at: {export_dir}{os.path.sep}"
      f"{{{constants.SAVED_MODEL_FILENAME_PBTXT}|"
      f"{constants.SAVED_MODEL_FILENAME_PB}}}")


def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Args:
    export_dir: Directory where the SavedModel is located. May be a str, bytes
      or os.PathLike object.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
      '/' to all returned asset tensor names.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. The
    value is the path of the asset file inside the SavedModel's assets
    directory (built from byte strings).
  """
  collection_def = meta_graph_def_to_load.collection_def

  # Prefer the `asset_file_def` field; otherwise fall back to the ASSETS_KEY
  # collection, whose entries are `Any`-packed `AssetFileDef` protos.
  if meta_graph_def_to_load.asset_file_def:
    asset_protos = meta_graph_def_to_load.asset_file_def
  elif constants.ASSETS_KEY in collection_def:
    asset_protos = []
    for asset_any_proto in collection_def[constants.ASSETS_KEY].any_list.value:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_protos.append(asset_proto)
  else:
    asset_protos = []

  # Location of the assets for SavedModel. `path_to_str` makes PathLike
  # inputs work, consistent with `parse_saved_model`.
  assets_directory = file_io.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.ASSETS_DIRECTORY))

  asset_tensor_dict = {}
  for asset_proto in asset_protos:
    tensor_name = asset_proto.tensor_info.name
    if import_scope:
      tensor_name = f"{import_scope}/{tensor_name}"
    asset_tensor_dict[tensor_name] = file_io.join(
        compat.as_bytes(assets_directory),
        compat.as_bytes(asset_proto.filename))

  return asset_tensor_dict


def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Gets the main op tensor, if one exists.

  The returned object is fetched with `ops.get_collection`, i.e. from the
  current default graph, so this must be called while the graph into which
  the meta graph was imported is the default graph.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: name of the collection to check; should be one of MAIN_OP_KEY
      or the deprecated LEGACY_INIT_OP_KEY

  Returns:
    The main op tensor, if it exists and `None` otherwise.

  Raises:
    RuntimeError: If the collection def corresponding to the main op key has
      other than exactly one tensor.
  """
  # TODO(kathywu): Rename this method to _get_op_from_collection when
  # dependency from SavedModelEstimator is removed.
  collection_def = meta_graph_def_to_load.collection_def
  if init_op_key not in collection_def:
    return None

  init_op_list = collection_def[init_op_key].node_list.value
  if len(init_op_list) != 1:
    raise RuntimeError("Expected exactly one SavedModel init op. "
                       f"Found {len(init_op_list)}: {init_op_list}.")
  return ops.get_collection(init_op_key)[0]


def _get_op_from_collection(meta_graph_def, op_key):
  """Returns the single op stored under collection `op_key`, or `None`.

  Thin alias of `_get_main_op_tensor`; raises RuntimeError if the collection
  exists but does not hold exactly one node.
  """
  return _get_main_op_tensor(meta_graph_def, op_key)


def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
  """Retrieve op stored in the imported meta graph's signature def.

  Returns `None` when `op_signature_key` is not one of the signature defs.
  """
  if op_signature_key not in meta_graph_def.signature_def:
    return None
  return signature_def_utils.load_op_from_signature_def(
      meta_graph_def.signature_def[op_signature_key], op_signature_key,
      import_scope)


def get_init_op(meta_graph_def, import_scope=None):
  """Returns the initialization op of an imported meta graph, or `None`.

  Sources are tried in order, stopping at the first one that yields an op:
  the INIT_OP_SIGNATURE_KEY signature def, the MAIN_OP_KEY collection, and
  finally the deprecated LEGACY_INIT_OP_KEY collection.

  Args:
    meta_graph_def: The `MetaGraphDef` that was imported into the default graph.
    import_scope: Optional import scope; only applied to the signature-def
      lookup.

  Returns:
    The init op, or `None` if none of the sources define one.
  """
  # Explicit `is None` checks (rather than `a or b or c`) avoid calling
  # `__bool__` on the result, which a graph-mode `tf.Tensor` does not allow.
  # This mirrors `get_train_op` and keeps the same short-circuit order.
  init_op = _get_op_from_signature_def(
      meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope)
  if init_op is None:
    init_op = _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY)
  if init_op is None:
    init_op = _get_op_from_collection(
        meta_graph_def, constants.LEGACY_INIT_OP_KEY)
  return init_op


def get_train_op(meta_graph_def, import_scope=None):
  """Returns the train op of an imported meta graph, or `None`.

  The TRAIN_OP_SIGNATURE_KEY signature def is consulted first, then the
  TRAIN_OP_KEY collection.

  Args:
    meta_graph_def: The `MetaGraphDef` that was imported into the default graph.
    import_scope: Optional import scope; only applied to the signature-def
      lookup.

  Returns:
    The train op, or `None` if neither source defines one.
  """
  train_op = _get_op_from_signature_def(
      meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
  if train_op is None:
    train_op = _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
  return train_op


@tf_export(v1=[
    "saved_model.contains_saved_model",
    "saved_model.maybe_saved_model_directory",
    "saved_model.loader.maybe_saved_model_directory"
])
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Note that the method does not load any data by itself. If the method returns
  `False`, the export directory definitely does not contain a SavedModel. If the
  method returns `True`, the export directory may contain a SavedModel but
  provides no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
      '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  txt_path = file_io.join(export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
  pb_path = file_io.join(export_dir, constants.SAVED_MODEL_FILENAME_PB)
  return file_io.file_exists(txt_path) or file_io.file_exists(pb_path)


@tf_export("saved_model.contains_saved_model", v1=[])
def contains_saved_model(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Note that the method does not load any data by itself. If the method returns
  `False`, the export directory definitely does not contain a SavedModel. If the
  method returns `True`, the export directory may contain a SavedModel but
  provides no guarantee that it can be loaded.

  Args:
    export_dir: Absolute path to possible export location. For example,
      '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  if isinstance(export_dir, os.PathLike):
    export_dir = os.fspath(export_dir)
  return maybe_saved_model_directory(export_dir)


@tf_export(v1=["saved_model.load", "saved_model.loader.load"])
@deprecation.deprecated(
    None,
    "Use `tf.saved_model.load` instead.")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
      correspond to the tags used when saving the variables using the
      SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
      to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
      followed by '/' to all loaded tensor names. This scope is applied to
      tensor instances loaded into the passed session, but it is *not* written
      through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.

  @compatibility(TF2)

  `tf.compat.v1.saved_model.load` or `tf.compat.v1.saved_model.loader.load` is
  not compatible with eager execution. Please use `tf.saved_model.load` instead
  to load your model. You can refer to the [SavedModel guide]
  (https://www.tensorflow.org/guide/saved_model) for more information as well as
  "Importing SavedModels from TensorFlow 1.x" in the [`tf.saved_model.load`]
  (https://www.tensorflow.org/api_docs/python/tf/saved_model/load) docstring.

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `sess`                | Not supported   | -                          |
  | `tags`                | `tags`          | -                          |
  | `export_dir`          | `export_dir`    | -                          |
  | `import_scope`        | Not supported   | Name scopes are not needed.
  :                       :                 : By default, variables are  :
  :                       :                 : associated with the loaded :
  :                       :                 : object and function names  :
  :                       :                 : are deduped.               :
  | `saver_kwargs`        | Not supported   | -                          |

  #### Before & After Usage Example

  Before:

  ```
  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    tf.compat.v1.saved_model.loader.load(sess, ["foo-tag"], export_dir)
  ```

  After:

  ```
  model = tf.saved_model.load(export_dir, tags=["foo-tag"])
  ```
  @end_compatibility
  """
  loader = SavedModelLoader(export_dir)
  return loader.load(sess, tags, import_scope, **saver_kwargs)


class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`."""

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.

    The SavedModel proto is parsed once here and reused by all methods.

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.
    """
    self._export_dir = export_dir
    self._variables_path = saved_model_utils.get_variables_path(export_dir)
    self._saved_model = parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Args:
      tags: A list or set of string tags that identify the MetaGraphDef.

    Returns:
      MetaGraphDef with the same tags.

    Raises:
      RuntimeError: if no metagraphs were found with the associated tags.
    """
    requested_tags = set(tags)
    available_tags = []
    for meta_graph_def in self._saved_model.meta_graphs:
      meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
      available_tags.append(meta_graph_tags)
      if meta_graph_tags == requested_tags:
        return meta_graph_def

    raise RuntimeError(
        f"MetaGraphDef associated with tags {str(tags).strip('[]')} "
        "could not be found in SavedModel, with available tags "
        f"'{available_tags}'. To inspect available tag-sets in"
        " the SavedModel, please use the SavedModel CLI: `saved_model_cli`.")

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    if sys.byteorder == "big":
      # NOTE(review): this swaps tensor content in place on the cached proto,
      # so repeated calls on the same loader swap it again. Kept as-is.
      saved_model_utils.swap_function_tensor_content(meta_graph_def, "little",
                                                     "big")
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      saver: a tf.compat.v1.train.Saver object. Can be None if there are no
        variables in graph. This may be the saver returned by the load_graph()
        function, or a default `tf.compat.v1.train.Saver()`.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.

    Raises:
      ValueError: if no saver was passed to the saver argument, and there are
        variables in the graph.
    """
    with sess.graph.as_default():
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "`SavedModelLoader.restore_variables`. Since there are variables in"
            " the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    Asset file paths are fed to their placeholder tensors while the init op
    runs.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any.
      asset_tensors_dictionary = get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      init_op = get_init_op(meta_graph_def, import_scope)
      if init_op is not None:
        sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Args:
      sess: tf.compat.v1.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetaGraphDef` proto of the graph that was loaded.
    """
    metrics.IncrementReadApi(_LOADER_LABEL)

    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)

    # Reuse the proto parsed in __init__ (the one actually loaded above)
    # instead of re-reading and re-parsing the file from disk.
    saved_model_proto = self._saved_model
    if (len(saved_model_proto.meta_graphs) == 1 and
        saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
      metrics.IncrementRead(write_version="2")
    else:
      metrics.IncrementRead(write_version="1")

    return meta_graph_def