1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Loader implementation for SavedModel with hermetic, language-neutral exports. 16""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os 23 24from google.protobuf import message 25from google.protobuf import text_format 26 27from tensorflow.core.protobuf import graph_debug_info_pb2 28from tensorflow.core.protobuf import meta_graph_pb2 29from tensorflow.core.protobuf import saved_model_pb2 30from tensorflow.python.framework import ops 31from tensorflow.python.lib.io import file_io 32from tensorflow.python.ops import variables 33from tensorflow.python.platform import tf_logging 34from tensorflow.python.saved_model import constants 35from tensorflow.python.saved_model import signature_def_utils 36from tensorflow.python.saved_model import utils_impl as saved_model_utils 37from tensorflow.python.training import saver as tf_saver 38from tensorflow.python.util import compat 39from tensorflow.python.util import deprecation 40from tensorflow.python.util.tf_export import tf_export 41 42 43def parse_saved_model_with_debug_info(export_dir): 44 """Reads the savedmodel as well as the graph debug info. 45 46 Args: 47 export_dir: Directory containing the SavedModel and GraphDebugInfo files. 48 49 Returns: 50 `SavedModel` and `GraphDebugInfo` protocol buffers. 51 52 Raises: 53 IOError: If the saved model file does not exist, or cannot be successfully 54 parsed. Missing graph debug info file is fine. 55 """ 56 saved_model = _parse_saved_model(export_dir) 57 58 debug_info_path = os.path.join( 59 saved_model_utils.get_debug_dir(export_dir), 60 constants.DEBUG_INFO_FILENAME_PB) 61 debug_info = graph_debug_info_pb2.GraphDebugInfo() 62 if file_io.file_exists(debug_info_path): 63 with file_io.FileIO(debug_info_path, "rb") as debug_file: 64 try: 65 debug_info.ParseFromString(debug_file.read()) 66 except message.DecodeError as e: 67 raise IOError("Cannot parse file %s: %s." % (debug_info_path, str(e))) 68 69 return (saved_model, debug_info) 70 71 72def parse_saved_model(export_dir): 73 """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`. 74 75 Args: 76 export_dir: String or Pathlike, path to the directory containing the 77 SavedModel file. 78 79 Returns: 80 A `SavedModel` protocol buffer. 81 82 Raises: 83 IOError: If the file does not exist, or cannot be successfully parsed. 84 """ 85 # Build the path to the SavedModel in pbtxt format. 86 path_to_pbtxt = os.path.join( 87 compat.as_bytes(compat.path_to_str(export_dir)), 88 compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)) 89 # Build the path to the SavedModel in pb format. 90 path_to_pb = os.path.join( 91 compat.as_bytes(compat.path_to_str(export_dir)), 92 compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) 93 94 # Parse the SavedModel protocol buffer. 95 saved_model = saved_model_pb2.SavedModel() 96 if file_io.file_exists(path_to_pb): 97 try: 98 file_content = file_io.FileIO(path_to_pb, "rb").read() 99 saved_model.ParseFromString(file_content) 100 return saved_model 101 except message.DecodeError as e: 102 raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e))) 103 elif file_io.file_exists(path_to_pbtxt): 104 try: 105 file_content = file_io.FileIO(path_to_pbtxt, "rb").read() 106 text_format.Merge(file_content.decode("utf-8"), saved_model) 107 return saved_model 108 except text_format.ParseError as e: 109 raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e))) 110 else: 111 raise IOError("SavedModel file does not exist at: %s/{%s|%s}" % 112 (export_dir, 113 constants.SAVED_MODEL_FILENAME_PBTXT, 114 constants.SAVED_MODEL_FILENAME_PB)) 115 116 117# TODO(b/120594573): Make this symbol also available as private, so that 118# tensorflow_transform and tensorflow_estimator do not break. 119_parse_saved_model = parse_saved_model 120 121 122def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None): 123 """Gets the asset tensors, if defined in the meta graph def to load. 124 125 Args: 126 export_dir: Directory where the SavedModel is located. 127 meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. 128 import_scope: Optional `string` -- if specified, prepend this followed by 129 '/' to all returned asset tensor names. 130 131 Returns: 132 A dictionary of asset tensors, keyed by the name of the asset tensor. The 133 value in the map corresponds to the absolute path of the asset file. 134 """ 135 # Collection-def that may contain the assets key. 136 collection_def = meta_graph_def_to_load.collection_def 137 138 asset_tensor_dict = {} 139 asset_protos = [] 140 141 if meta_graph_def_to_load.asset_file_def: 142 asset_protos = meta_graph_def_to_load.asset_file_def 143 elif constants.ASSETS_KEY in collection_def: 144 assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value 145 for asset_any_proto in assets_any_proto: 146 asset_proto = meta_graph_pb2.AssetFileDef() 147 asset_any_proto.Unpack(asset_proto) 148 asset_protos.append(asset_proto) 149 150 # Location of the assets for SavedModel. 151 assets_directory = os.path.join( 152 compat.as_bytes(export_dir), compat.as_bytes(constants.ASSETS_DIRECTORY)) 153 # Process each asset and add it to the asset tensor dictionary. 154 for asset_proto in asset_protos: 155 tensor_name = asset_proto.tensor_info.name 156 if import_scope: 157 tensor_name = "%s/%s" % (import_scope, tensor_name) 158 asset_tensor_dict[tensor_name] = os.path.join( 159 compat.as_bytes(assets_directory), 160 compat.as_bytes(asset_proto.filename)) 161 162 return asset_tensor_dict 163 164 165def _get_main_op_tensor( 166 meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY): 167 """Gets the main op tensor, if one exists. 168 169 Args: 170 meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. 171 init_op_key: name of the collection to check; should be one of MAIN_OP_KEY 172 or the deprecated LEGACY_INIT_OP_KEY 173 174 Returns: 175 The main op tensor, if it exists and `None` otherwise. 176 177 Raises: 178 RuntimeError: If the collection def corresponding to the main op key has 179 other than exactly one tensor. 180 """ 181 # TODO(kathywu): Rename this method to _get_op_from_collection when 182 # dependency from SavedModelEstimator is removed. 183 collection_def = meta_graph_def_to_load.collection_def 184 init_op = None 185 if init_op_key in collection_def: 186 init_op_list = collection_def[init_op_key].node_list.value 187 if len(init_op_list) != 1: 188 raise RuntimeError("Expected exactly one SavedModel init op. " 189 "Found: {}".format(init_op_list)) 190 init_op = ops.get_collection(init_op_key)[0] 191 return init_op 192 193 194def _get_op_from_collection(meta_graph_def, op_key): 195 return _get_main_op_tensor(meta_graph_def, op_key) 196 197 198def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope): 199 """Retrieve op stored in the imported meta graph's signature def.""" 200 if op_signature_key in meta_graph_def.signature_def: 201 return signature_def_utils.load_op_from_signature_def( 202 meta_graph_def.signature_def[op_signature_key], op_signature_key, 203 import_scope) 204 else: 205 return None 206 207 208def get_init_op(meta_graph_def, import_scope=None): 209 return (_get_op_from_signature_def( 210 meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope) or 211 _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY) or 212 _get_op_from_collection(meta_graph_def, constants.LEGACY_INIT_OP_KEY)) 213 214 215def get_train_op(meta_graph_def, import_scope=None): 216 train_op = _get_op_from_signature_def( 217 meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope) 218 if train_op is None: 219 train_op = _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY) 220 return train_op 221 222 223@tf_export(v1=[ 224 "saved_model.contains_saved_model", 225 "saved_model.maybe_saved_model_directory", 226 "saved_model.loader.maybe_saved_model_directory" 227]) 228@deprecation.deprecated_endpoints( 229 "saved_model.loader.maybe_saved_model_directory") 230def maybe_saved_model_directory(export_dir): 231 """Checks whether the provided export directory could contain a SavedModel. 232 233 Note that the method does not load any data by itself. If the method returns 234 `false`, the export directory definitely does not contain a SavedModel. If the 235 method returns `true`, the export directory may contain a SavedModel but 236 provides no guarantee that it can be loaded. 237 238 Args: 239 export_dir: Absolute string path to possible export location. For example, 240 '/my/foo/model'. 241 242 Returns: 243 True if the export directory contains SavedModel files, False otherwise. 244 """ 245 txt_path = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PBTXT) 246 pb_path = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB) 247 return file_io.file_exists(txt_path) or file_io.file_exists(pb_path) 248 249 250@tf_export("saved_model.contains_saved_model", v1=[]) 251def contains_saved_model(export_dir): 252 """Checks whether the provided export directory could contain a SavedModel. 253 254 Note that the method does not load any data by itself. If the method returns 255 `false`, the export directory definitely does not contain a SavedModel. If the 256 method returns `true`, the export directory may contain a SavedModel but 257 provides no guarantee that it can be loaded. 258 259 Args: 260 export_dir: Absolute string path to possible export location. For example, 261 '/my/foo/model'. 262 263 Returns: 264 True if the export directory contains SavedModel files, False otherwise. 265 """ 266 return maybe_saved_model_directory(export_dir) 267 268 269@tf_export(v1=["saved_model.load", "saved_model.loader.load"]) 270@deprecation.deprecated( 271 None, 272 "This function will only be available through the v1 compatibility " 273 "library as tf.compat.v1.saved_model.loader.load or " 274 "tf.compat.v1.saved_model.load. There will be a new function for importing " 275 "SavedModels in Tensorflow 2.0.") 276def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): 277 """Loads the model from a SavedModel as specified by tags. 278 279 Args: 280 sess: The TensorFlow session to restore the variables. 281 tags: Set of string tags to identify the required MetaGraphDef. These should 282 correspond to the tags used when saving the variables using the 283 SavedModel `save()` API. 284 export_dir: Directory in which the SavedModel protocol buffer and variables 285 to be loaded are located. 286 import_scope: Optional `string` -- if specified, prepend this string 287 followed by '/' to all loaded tensor names. This scope is applied to 288 tensor instances loaded into the passed session, but it is *not* written 289 through to the static `MetaGraphDef` protocol buffer that is returned. 290 **saver_kwargs: Optional keyword arguments passed through to Saver. 291 292 Returns: 293 The `MetaGraphDef` protocol buffer loaded in the provided session. This 294 can be used to further extract signature-defs, collection-defs, etc. 295 296 Raises: 297 RuntimeError: MetaGraphDef associated with the tags cannot be found. 298 """ 299 loader = SavedModelLoader(export_dir) 300 return loader.load(sess, tags, import_scope, **saver_kwargs) 301 302 303class SavedModelLoader(object): 304 """Load graphs and restore variable values from a `SavedModel`.""" 305 306 def __init__(self, export_dir): 307 """Creates a `SavedModelLoader`. 308 309 Args: 310 export_dir: Directory in which the SavedModel protocol buffer and 311 variables to be loaded are located. 312 """ 313 self._export_dir = export_dir 314 self._variables_path = saved_model_utils.get_variables_path(export_dir) 315 self._saved_model = parse_saved_model(export_dir) 316 317 @property 318 def export_dir(self): 319 """Directory containing the SavedModel.""" 320 return self._export_dir 321 322 @property 323 def variables_path(self): 324 """Path to variable checkpoint files.""" 325 return self._variables_path 326 327 @property 328 def saved_model(self): 329 """SavedModel object parsed from the export directory.""" 330 return self._saved_model 331 332 def get_meta_graph_def_from_tags(self, tags): 333 """Return MetaGraphDef with the exact specified tags. 334 335 Args: 336 tags: A list or set of string tags that identify the MetaGraphDef. 337 338 Returns: 339 MetaGraphDef with the same tags. 340 341 Raises: 342 RuntimeError: if no metagraphs were found with the associated tags. 343 """ 344 found_match = False 345 available_tags = [] 346 for meta_graph_def in self._saved_model.meta_graphs: 347 available_tags.append(set(meta_graph_def.meta_info_def.tags)) 348 if set(meta_graph_def.meta_info_def.tags) == set(tags): 349 meta_graph_def_to_load = meta_graph_def 350 found_match = True 351 break 352 353 if not found_match: 354 raise RuntimeError( 355 "MetaGraphDef associated with tags " + str(tags).strip("[]") + 356 " could not be found in SavedModel. To inspect available tag-sets in" 357 " the SavedModel, please use the SavedModel CLI: `saved_model_cli`" 358 "\navailable_tags: " + str(available_tags)) 359 return meta_graph_def_to_load 360 361 def load_graph(self, graph, tags, import_scope=None, **saver_kwargs): 362 """Load ops and nodes from SavedModel MetaGraph into graph. 363 364 Args: 365 graph: tf.Graph object. 366 tags: a set of string tags identifying a MetaGraphDef. 367 import_scope: Optional `string` -- if specified, prepend this string 368 followed by '/' to all loaded tensor names. This scope is applied to 369 tensor instances loaded into the passed session, but it is *not* written 370 through to the static `MetaGraphDef` protocol buffer that is returned. 371 **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. 372 373 Returns: 374 A tuple of 375 * Saver defined by the MetaGraph, which can be used to restore the 376 variable values. 377 * List of `Operation`/`Tensor` objects returned from 378 `tf.import_graph_def` (may be `None`). 379 """ 380 meta_graph_def = self.get_meta_graph_def_from_tags(tags) 381 with graph.as_default(): 382 return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access 383 meta_graph_def, import_scope=import_scope, **saver_kwargs) 384 385 def restore_variables(self, sess, saver, import_scope=None): 386 """Restore SavedModel variable values into the session. 387 388 Args: 389 sess: tf.compat.v1.Session to restore variable values. 390 saver: a tf.compat.v1.train.Saver object. Can be None if there are no 391 variables in graph. This may be the saver returned by the load_graph() 392 function, or a default `tf.compat.v1.train.Saver()`. 393 import_scope: Optional `string` -- if specified, prepend this string 394 followed by '/' to all loaded tensor names. This scope is applied to 395 tensor instances loaded into the passed session, but it is *not* written 396 through to the static `MetaGraphDef` protocol buffer that is returned. 397 398 Raises: 399 ValueError: if no saver was passed to the saver argument, and there are 400 variables in the graph. 401 """ 402 with sess.graph.as_default(): 403 if (saver is None and 404 not variables._all_saveable_objects(scope=import_scope)): # pylint: disable=protected-access 405 tf_logging.info("The specified SavedModel has no variables; no " 406 "checkpoints were restored.") 407 elif isinstance(saver, tf_saver.Saver): 408 saver.restore(sess, self._variables_path) 409 else: 410 raise ValueError( 411 "No tf.train.Saver object was passed to the function " 412 "SavedModelLoader.restore_variables. Since there are variables in " 413 "the graph, a saver is required.") 414 415 def run_init_ops(self, sess, tags, import_scope=None): 416 """Run initialization ops defined in the `MetaGraphDef`. 417 418 Args: 419 sess: tf.compat.v1.Session to restore variable values. 420 tags: a set of string tags identifying a MetaGraphDef. 421 import_scope: Optional `string` -- if specified, prepend this string 422 followed by '/' to all loaded tensor names. This scope is applied to 423 tensor instances loaded into the passed session, but it is *not* written 424 through to the static `MetaGraphDef` protocol buffer that is returned. 425 """ 426 meta_graph_def = self.get_meta_graph_def_from_tags(tags) 427 with sess.graph.as_default(): 428 # Get asset tensors, if any. 429 asset_tensors_dictionary = get_asset_tensors( 430 self._export_dir, meta_graph_def, import_scope=import_scope) 431 432 init_op = get_init_op(meta_graph_def, import_scope) 433 if init_op is not None: 434 sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary) 435 436 def load(self, sess, tags, import_scope=None, **saver_kwargs): 437 """Load the MetaGraphDef graph and restore variable values into the session. 438 439 Args: 440 sess: tf.compat.v1.Session to restore variable values. 441 tags: a set of string tags identifying a MetaGraphDef. 442 import_scope: Optional `string` -- if specified, prepend this string 443 followed by '/' to all loaded tensor names. This scope is applied to 444 tensor instances loaded into the passed session, but it is *not* written 445 through to the static `MetaGraphDef` protocol buffer that is returned. 446 **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. 447 448 Returns: 449 `MetagraphDef` proto of the graph that was loaded. 450 """ 451 with sess.graph.as_default(): 452 saver, _ = self.load_graph(sess.graph, tags, import_scope, 453 **saver_kwargs) 454 self.restore_variables(sess, saver, import_scope) 455 self.run_init_ops(sess, tags, import_scope) 456 return self.get_meta_graph_def_from_tags(tags) 457