1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Loader implementation for SavedModel with hermetic, language-neutral exports. 16""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os 23 24from google.protobuf import message 25from google.protobuf import text_format 26 27from tensorflow.core.protobuf import meta_graph_pb2 28from tensorflow.core.protobuf import saved_model_pb2 29from tensorflow.python.framework import ops 30from tensorflow.python.lib.io import file_io 31from tensorflow.python.ops import variables 32from tensorflow.python.platform import tf_logging 33from tensorflow.python.saved_model import constants 34from tensorflow.python.saved_model import signature_def_utils 35from tensorflow.python.saved_model import utils_impl as saved_model_utils 36from tensorflow.python.training import saver as tf_saver 37from tensorflow.python.util import compat 38from tensorflow.python.util import deprecation 39from tensorflow.python.util.tf_export import tf_export 40 41 42def parse_saved_model(export_dir): 43 """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`. 44 45 Args: 46 export_dir: Directory containing the SavedModel file. 47 48 Returns: 49 A `SavedModel` protocol buffer. 50 51 Raises: 52 IOError: If the file does not exist, or cannot be successfully parsed. 53 """ 54 # Build the path to the SavedModel in pbtxt format. 55 path_to_pbtxt = os.path.join( 56 compat.as_bytes(export_dir), 57 compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)) 58 # Build the path to the SavedModel in pb format. 59 path_to_pb = os.path.join( 60 compat.as_bytes(export_dir), 61 compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) 62 63 # Parse the SavedModel protocol buffer. 64 saved_model = saved_model_pb2.SavedModel() 65 if file_io.file_exists(path_to_pb): 66 try: 67 file_content = file_io.FileIO(path_to_pb, "rb").read() 68 saved_model.ParseFromString(file_content) 69 return saved_model 70 except message.DecodeError as e: 71 raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e))) 72 elif file_io.file_exists(path_to_pbtxt): 73 try: 74 file_content = file_io.FileIO(path_to_pbtxt, "rb").read() 75 text_format.Merge(file_content.decode("utf-8"), saved_model) 76 return saved_model 77 except text_format.ParseError as e: 78 raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e))) 79 else: 80 raise IOError("SavedModel file does not exist at: %s/{%s|%s}" % 81 (export_dir, 82 constants.SAVED_MODEL_FILENAME_PBTXT, 83 constants.SAVED_MODEL_FILENAME_PB)) 84 85 86# TODO(b/120594573): Make this symbol also available as private, so that 87# tensorflow_transform and tensorflow_estimator do not break. 88_parse_saved_model = parse_saved_model 89 90 91def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None): 92 """Gets the asset tensors, if defined in the meta graph def to load. 93 94 Args: 95 export_dir: Directory where the SavedModel is located. 96 meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. 97 import_scope: Optional `string` -- if specified, prepend this followed by 98 '/' to all returned asset tensor names. 99 100 Returns: 101 A dictionary of asset tensors, keyed by the name of the asset tensor. The 102 value in the map corresponds to the absolute path of the asset file. 103 """ 104 # Collection-def that may contain the assets key. 105 collection_def = meta_graph_def_to_load.collection_def 106 107 asset_tensor_dict = {} 108 asset_protos = [] 109 110 if meta_graph_def_to_load.asset_file_def: 111 asset_protos = meta_graph_def_to_load.asset_file_def 112 elif constants.ASSETS_KEY in collection_def: 113 assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value 114 for asset_any_proto in assets_any_proto: 115 asset_proto = meta_graph_pb2.AssetFileDef() 116 asset_any_proto.Unpack(asset_proto) 117 asset_protos.append(asset_proto) 118 119 # Location of the assets for SavedModel. 120 assets_directory = os.path.join( 121 compat.as_bytes(export_dir), compat.as_bytes(constants.ASSETS_DIRECTORY)) 122 # Process each asset and add it to the asset tensor dictionary. 123 for asset_proto in asset_protos: 124 tensor_name = asset_proto.tensor_info.name 125 if import_scope: 126 tensor_name = "%s/%s" % (import_scope, tensor_name) 127 asset_tensor_dict[tensor_name] = os.path.join( 128 compat.as_bytes(assets_directory), 129 compat.as_bytes(asset_proto.filename)) 130 131 return asset_tensor_dict 132 133 134def _get_main_op_tensor( 135 meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY): 136 """Gets the main op tensor, if one exists. 137 138 Args: 139 meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. 140 init_op_key: name of collection to check; should be one of MAIN_OP_KEY 141 or the deprecated LEGACY_INIT_OP_KEY 142 143 Returns: 144 The main op tensor, if it exists and `None` otherwise. 145 146 Raises: 147 RuntimeError: If the collection def corresponding to the main op key has 148 other than exactly one tensor. 149 """ 150 # TODO(kathywu): Rename this method to _get_op_from_collection when 151 # dependency from SavedModelEstimator is removed. 152 collection_def = meta_graph_def_to_load.collection_def 153 init_op = None 154 if init_op_key in collection_def: 155 init_op_list = collection_def[init_op_key].node_list.value 156 if len(init_op_list) != 1: 157 raise RuntimeError("Expected exactly one SavedModel init op. " 158 "Found: {}".format(init_op_list)) 159 init_op = ops.get_collection(init_op_key)[0] 160 return init_op 161 162 163def _get_op_from_collection(meta_graph_def, op_key): 164 return _get_main_op_tensor(meta_graph_def, op_key) 165 166 167def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope): 168 """Retrieve op stored in the imported meta graph's signature def.""" 169 if op_signature_key in meta_graph_def.signature_def: 170 return signature_def_utils.load_op_from_signature_def( 171 meta_graph_def.signature_def[op_signature_key], op_signature_key, 172 import_scope) 173 else: 174 return None 175 176 177def get_init_op(meta_graph_def, import_scope=None): 178 return (_get_op_from_signature_def( 179 meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope) or 180 _get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY) or 181 _get_op_from_collection(meta_graph_def, constants.LEGACY_INIT_OP_KEY)) 182 183 184def get_train_op(meta_graph_def, import_scope=None): 185 train_op = _get_op_from_signature_def( 186 meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope) 187 if train_op is None: 188 train_op = _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY) 189 return train_op 190 191 192@tf_export(v1=[ 193 "saved_model.contains_saved_model", 194 "saved_model.maybe_saved_model_directory", 195 "saved_model.loader.maybe_saved_model_directory" 196]) 197@deprecation.deprecated_endpoints( 198 "saved_model.loader.maybe_saved_model_directory") 199def maybe_saved_model_directory(export_dir): 200 """Checks whether the provided export directory could contain a SavedModel. 201 202 Note that the method does not load any data by itself. If the method returns 203 `false`, the export directory definitely does not contain a SavedModel. If the 204 method returns `true`, the export directory may contain a SavedModel but 205 provides no guarantee that it can be loaded. 206 207 Args: 208 export_dir: Absolute string path to possible export location. For example, 209 '/my/foo/model'. 210 211 Returns: 212 True if the export directory contains SavedModel files, False otherwise. 213 """ 214 txt_path = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PBTXT) 215 pb_path = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB) 216 return file_io.file_exists(txt_path) or file_io.file_exists(pb_path) 217 218 219@tf_export("saved_model.contains_saved_model", v1=[]) 220def contains_saved_model(export_dir): 221 """Checks whether the provided export directory could contain a SavedModel. 222 223 Note that the method does not load any data by itself. If the method returns 224 `false`, the export directory definitely does not contain a SavedModel. If the 225 method returns `true`, the export directory may contain a SavedModel but 226 provides no guarantee that it can be loaded. 227 228 Args: 229 export_dir: Absolute string path to possible export location. For example, 230 '/my/foo/model'. 231 232 Returns: 233 True if the export directory contains SavedModel files, False otherwise. 234 """ 235 return maybe_saved_model_directory(export_dir) 236 237 238@tf_export(v1=["saved_model.load", "saved_model.loader.load"]) 239@deprecation.deprecated( 240 None, 241 "This function will only be available through the v1 compatibility " 242 "library as tf.compat.v1.saved_model.loader.load or " 243 "tf.compat.v1.saved_model.load. There will be a new function for importing " 244 "SavedModels in Tensorflow 2.0.") 245def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): 246 """Loads the model from a SavedModel as specified by tags. 247 248 Args: 249 sess: The TensorFlow session to restore the variables. 250 tags: Set of string tags to identify the required MetaGraphDef. These should 251 correspond to the tags used when saving the variables using the 252 SavedModel `save()` API. 253 export_dir: Directory in which the SavedModel protocol buffer and variables 254 to be loaded are located. 255 import_scope: Optional `string` -- if specified, prepend this string 256 followed by '/' to all loaded tensor names. This scope is applied to 257 tensor instances loaded into the passed session, but it is *not* written 258 through to the static `MetaGraphDef` protocol buffer that is returned. 259 **saver_kwargs: Optional keyword arguments passed through to Saver. 260 261 Returns: 262 The `MetaGraphDef` protocol buffer loaded in the provided session. This 263 can be used to further extract signature-defs, collection-defs, etc. 264 265 Raises: 266 RuntimeError: MetaGraphDef associated with the tags cannot be found. 267 """ 268 loader = SavedModelLoader(export_dir) 269 return loader.load(sess, tags, import_scope, **saver_kwargs) 270 271 272class SavedModelLoader(object): 273 """Load graphs and restore variable values from a `SavedModel`.""" 274 275 def __init__(self, export_dir): 276 """Creates a `SavedModelLoader`. 277 278 Args: 279 export_dir: Directory in which the SavedModel protocol buffer and 280 variables to be loaded are located. 281 """ 282 self._export_dir = export_dir 283 self._variables_path = saved_model_utils.get_variables_path(export_dir) 284 self._saved_model = parse_saved_model(export_dir) 285 286 @property 287 def export_dir(self): 288 """Directory containing the SavedModel.""" 289 return self._export_dir 290 291 @property 292 def variables_path(self): 293 """Path to variable checkpoint files.""" 294 return self._variables_path 295 296 @property 297 def saved_model(self): 298 """SavedModel object parsed from the export directory.""" 299 return self._saved_model 300 301 def get_meta_graph_def_from_tags(self, tags): 302 """Return MetaGraphDef with the exact specified tags. 303 304 Args: 305 tags: A list or set of string tags that identify the MetaGraphDef. 306 307 Returns: 308 MetaGraphDef with the same tags. 309 310 Raises: 311 RuntimeError: if no metagraphs were found with the associated tags. 312 """ 313 found_match = False 314 for meta_graph_def in self._saved_model.meta_graphs: 315 if set(meta_graph_def.meta_info_def.tags) == set(tags): 316 meta_graph_def_to_load = meta_graph_def 317 found_match = True 318 break 319 320 if not found_match: 321 raise RuntimeError( 322 "MetaGraphDef associated with tags " + str(tags).strip("[]") + 323 " could not be found in SavedModel. To inspect available tag-sets in" 324 " the SavedModel, please use the SavedModel CLI: `saved_model_cli`" 325 ) 326 return meta_graph_def_to_load 327 328 def load_graph(self, graph, tags, import_scope=None, **saver_kwargs): 329 """Load ops and nodes from SavedModel MetaGraph into graph. 330 331 Args: 332 graph: tf.Graph object. 333 tags: a set of string tags identifying a MetaGraphDef. 334 import_scope: Optional `string` -- if specified, prepend this string 335 followed by '/' to all loaded tensor names. This scope is applied to 336 tensor instances loaded into the passed session, but it is *not* written 337 through to the static `MetaGraphDef` protocol buffer that is returned. 338 **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. 339 340 Returns: 341 A tuple of 342 * Saver defined by the MetaGraph, which can be used to restore the 343 variable values. 344 * List of `Operation`/`Tensor` objects returned from 345 `tf.import_graph_def` (may be `None`). 346 """ 347 meta_graph_def = self.get_meta_graph_def_from_tags(tags) 348 with graph.as_default(): 349 return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access 350 meta_graph_def, import_scope=import_scope, **saver_kwargs) 351 352 def restore_variables(self, sess, saver, import_scope=None): 353 """Restore SavedModel variable values into the session. 354 355 Args: 356 sess: tf.Session to restore variable values. 357 saver: a tf.train.Saver object. Can be None if there are no variables in 358 graph. This may be the saver returned by the load_graph() function, or a 359 default `tf.train.Saver()`. 360 import_scope: Optional `string` -- if specified, prepend this string 361 followed by '/' to all loaded tensor names. This scope is applied to 362 tensor instances loaded into the passed session, but it is *not* written 363 through to the static `MetaGraphDef` protocol buffer that is returned. 364 365 Raises: 366 ValueError: if no saver was passed to the saver argument, and there are 367 variables in the graph. 368 """ 369 with sess.graph.as_default(): 370 if (saver is None and 371 not variables._all_saveable_objects(scope=import_scope)): # pylint: disable=protected-access 372 tf_logging.info("The specified SavedModel has no variables; no " 373 "checkpoints were restored.") 374 elif isinstance(saver, tf_saver.Saver): 375 saver.restore(sess, self._variables_path) 376 else: 377 raise ValueError( 378 "No tf.train.Saver object was passed to the function " 379 "SavedModelLoader.restore_variables. Since there are variables in " 380 "the graph, a saver is required.") 381 382 def run_init_ops(self, sess, tags, import_scope=None): 383 """Run initialization ops defined in the `MetaGraphDef`. 384 385 Args: 386 sess: tf.Session to restore variable values. 387 tags: a set of string tags identifying a MetaGraphDef. 388 import_scope: Optional `string` -- if specified, prepend this string 389 followed by '/' to all loaded tensor names. This scope is applied to 390 tensor instances loaded into the passed session, but it is *not* written 391 through to the static `MetaGraphDef` protocol buffer that is returned. 392 """ 393 meta_graph_def = self.get_meta_graph_def_from_tags(tags) 394 with sess.graph.as_default(): 395 # Get asset tensors, if any. 396 asset_tensors_dictionary = get_asset_tensors( 397 self._export_dir, meta_graph_def, import_scope=import_scope) 398 399 init_op = get_init_op(meta_graph_def, import_scope) 400 if init_op is not None: 401 sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary) 402 403 def load(self, sess, tags, import_scope=None, **saver_kwargs): 404 """Load the MetaGraphDef graph and restore variable values into the session. 405 406 Args: 407 sess: tf.Session to restore variable values. 408 tags: a set of string tags identifying a MetaGraphDef. 409 import_scope: Optional `string` -- if specified, prepend this string 410 followed by '/' to all loaded tensor names. This scope is applied to 411 tensor instances loaded into the passed session, but it is *not* written 412 through to the static `MetaGraphDef` protocol buffer that is returned. 413 **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph. 414 415 Returns: 416 `MetagraphDef` proto of the graph that was loaded. 417 """ 418 with sess.graph.as_default(): 419 saver, _ = self.load_graph(sess.graph, tags, import_scope, 420 **saver_kwargs) 421 self.restore_variables(sess, saver, import_scope) 422 self.run_init_ops(sess, tags, import_scope) 423 return self.get_meta_graph_def_from_tags(tags) 424