# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel builder implementation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

from google.protobuf.any_pb2 import Any

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export

# API label for SavedModel metrics.
_SAVE_BUILDER_LABEL = "save_v1_builder"


# Base class for the SavedModelBuilder that is only used by Tensorflow
# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
@tf_export("__internal__.saved_model.SavedModelBuilder", v1=[])
class _SavedModelBuilder(object):
  """Builds the `SavedModel` protocol buffer and saves variables and assets.

  The `SavedModelBuilder` class provides the functionality to build a
  `SavedModel` protocol buffer. Specifically, this allows multiple meta
  graphs to be saved as part of a single language-neutral `SavedModel`,
  while sharing variables and assets.

  To build a SavedModel, the first meta graph must be saved with variables.
  Subsequent meta graphs will simply be saved with their graph definitions. If
  assets need to be saved and written or copied to disk, they can be provided
  when the meta graph def is added. If multiple meta graph defs are associated
  with an asset of the same name, only the first version is retained.

  Each meta graph added to the SavedModel must be annotated with tags. The tags
  provide a means to identify the specific meta graph to load and restore, along
  with the shared set of variables and assets.

  Typical usage for the `SavedModelBuilder`:

  ```python
  ...
  builder = tf.compat.v1.saved_model.Builder(export_dir)

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph_and_variables(sess,
                                    ["foo-tag"],
                                    signature_def_map=foo_signatures,
                                    assets_list=foo_assets)
  ...

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph(["bar-tag", "baz-tag"])
  ...

  builder.save()
  ```

  Note: This class will only be available through the v1 compatibility
  library as tf.compat.v1.saved_model.builder.SavedModelBuilder or
  tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new
  object-based method of creating SavedModels.
  """

  def __init__(self, export_dir):
    """Creates a builder that writes its SavedModel into `export_dir`.

    Args:
      export_dir: Directory to export into. It is created (recursively) if it
        does not exist yet.

    Raises:
      AssertionError: If `export_dir` already exists and is not empty.
    """
    self._saved_model = saved_model_pb2.SavedModel()
    self._saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)

    self._export_dir = export_dir
    if file_io.file_exists(export_dir):
      if file_io.list_directory(export_dir):
        raise AssertionError(
            "Export directory already exists, and isn't empty. Please choose "
            "a different export directory, or delete all the contents of the "
            "specified directory: %s" % export_dir)
    else:
      file_io.recursive_create_dir(self._export_dir)

    # Boolean to track whether variables and assets corresponding to the
    # SavedModel have been saved. Specifically, the first meta graph to be added
    # MUST use the add_meta_graph_and_variables() API. Subsequent add operations
    # on the SavedModel MUST use the add_meta_graph() API which does not save
    # weights.
    self._has_saved_variables = False

  def _save_and_write_assets(self, meta_graph_def, assets_list=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      meta_graph_def: The meta graph def to which the assets will be added.
      assets_list: The list where the asset paths are setup.
    """
    # Creates a function that adds assets into the meta graph def.
    write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
    asset_filename_map = _maybe_save_assets(write_fn, assets_list)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
    """Tags the meta graph def and adds it to the SavedModel.

    Tags the meta graph def with the supplied tags, adds signature defs to it if
    provided and appends the meta graph def to the SavedModel proto.

    Args:
      meta_graph_def: The meta graph def to add to the SavedModel.
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
    """
    for tag in tags:
      meta_graph_def.meta_info_def.tags.append(tag)

    if signature_def_map is not None:
      for key in signature_def_map:
        meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])

    # `add()` appends a fresh MetaGraphDef to the SavedModel proto; the
    # populated def is copied into it.
    proto_meta_graph_def = self._saved_model.meta_graphs.add()
    proto_meta_graph_def.CopyFrom(meta_graph_def)

  def _validate_tensor_info(self, tensor_info):
    """Validates the `TensorInfo` proto.

    Checks that one of the `encoding` oneof fields is set. For a
    `composite_tensor` encoding each component is validated recursively;
    for any other encoding the `dtype` field must be set.

    Args:
      tensor_info: `TensorInfo` protocol buffer to validate.

    Raises:
      AssertionError: If `tensor_info` is None, or its `encoding` or `dtype`
        fields are not populated.
    """
    if tensor_info is None:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the name "
          "and dtype fields set.")
    if tensor_info.WhichOneof("encoding") is None:
      # TODO(soergel) validate each of the fields of coo_sparse
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have one of "
          "the 'encoding' fields (e.g., name or coo_sparse) set: %s"
          % tensor_info)
    if tensor_info.WhichOneof("encoding") == "composite_tensor":
      for component in tensor_info.composite_tensor.components:
        self._validate_tensor_info(component)
    elif tensor_info.dtype == types_pb2.DT_INVALID:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the dtype "
          "field set: %s" % tensor_info)

  def _validate_signature_def_map(self, signature_def_map):
    """Validates the `SignatureDef` entries in the signature def map.

    Validation of entries in the signature def map includes ensuring that the
    `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
    `outputs` of each `SignatureDef` are populated. Also ensures that reserved
    SignatureDef keys for the initialization and train ops are not used.

    Args:
      signature_def_map: The map of signature defs to be validated.

    Raises:
      AssertionError: If a TensorInfo is not valid.
      KeyError: If a reserved signature key is used in the map.
    """
    for signature_def_key in signature_def_map:
      signature_def = signature_def_map[signature_def_key]
      inputs = signature_def.inputs
      outputs = signature_def.outputs
      for inputs_key in inputs:
        self._validate_tensor_info(inputs[inputs_key])
      for outputs_key in outputs:
        self._validate_tensor_info(outputs[outputs_key])
    if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          "SignatureDef map key \"{}\" is reserved for initialization. Please "
          "use a different key.".format(constants.INIT_OP_SIGNATURE_KEY))
    if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          "SignatureDef map key \"{}\" is reserved for the train op. Please "
          "use a different key.".format(constants.TRAIN_OP_SIGNATURE_KEY))

  def _maybe_create_saver(self, saver=None):
    """Creates a sharded saver if one does not already exist."""
    if not saver:
      # Initialize a saver to generate a sharded output for all saveables in the
      # current scope.
      saver = tf_saver.Saver(
          variables._all_saveable_objects(),  # pylint: disable=protected-access
          sharded=True,
          write_version=saver_pb2.SaverDef.V2,
          allow_empty=True)
    return saver

  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_list=None,
                     clear_devices=False,
                     init_op=None,
                     train_op=None,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Creates a Saver in the current scope and uses the Saver to export the meta
    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
    API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def. The caller's map is not modified.
      assets_list: Assets to be saved with SavedModel. Note
          that this list should be a subset of the assets saved as part of
          the first meta graph in the SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
        yet, or if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a key reserved for the init or
        train op.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Work on a shallow copy: the init/train op SignatureDefs are inserted
    # below, and writing them into the caller's map would make a later call
    # that reuses the same map fail validation on the reserved keys.
    signature_def_map = dict(signature_def_map or {})

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=True)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_list=None,
                                   clear_devices=False,
                                   init_op=None,
                                   train_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    # pylint: disable=line-too-long
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature defs to be added to the meta graph
        def. The caller's map is not modified.
      assets_list: Assets to be saved with SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      saver: An instance of tf.compat.v1.train.Saver that will be used to export the
        metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved by this builder, or
        if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a key reserved for the init or
        train op.
    """
    # pylint: enable=line-too-long
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Work on a shallow copy: the init/train op SignatureDefs are inserted
    # below, and writing them into the caller's map would make a later call
    # that reuses the same map fail validation on the reserved keys.
    signature_def_map = dict(signature_def_map or {})

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in a serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to
        disk. Protocol buffers in text format are useful for debugging, but
        parsing fails when it encounters an unknown field and so is not forward
        compatible. This means changes to TensorFlow may prevent deployment of
        new text format SavedModels to existing serving binaries. Do not deploy
        `as_text` SavedModels to production.

    Returns:
      The path (as bytes) to which the SavedModel protocol buffer was written.
    """
    metrics.IncrementWriteApi(_SAVE_BUILDER_LABEL)
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      # Deterministic serialization keeps map-field ordering stable, so the
      # same model produces byte-identical output.
      file_io.write_string_to_file(
          path, self._saved_model.SerializeToString(deterministic=True))
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))
    metrics.IncrementWrite(write_version="1")
    return path


@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])  # pylint: disable=missing-docstring
class SavedModelBuilder(_SavedModelBuilder):
  __doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
                                               "assets_collection")

  def __init__(self, export_dir):
    super(SavedModelBuilder, self).__init__(export_dir=export_dir)

  def _add_collections(self, assets_collection, main_op, train_op):
    """Add asset and op collections to be saved."""
    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    self._maybe_add_main_op(main_op)

    self._add_train_op(train_op)

  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    # Add assets to the collection with key `saved_model.ASSETS_KEY`, in the
    # graph.
    asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
                                            assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _maybe_add_main_op(self, main_op):
    """Adds main op to the SavedModel.

    Args:
      main_op: Main op to run as part of graph initialization. If None, no main
        op will be added to the graph.

    Raises:
      TypeError: If the main op is provided but is not of type `Operation`.
      ValueError: if the Graph already contains an init op.
    """
    if main_op is None:
      return

    if not isinstance(main_op, ops.Operation):
      raise TypeError("main_op needs to be an Operation: %r" % main_op)

    # Validate that no other init ops have been added to this graph already.
    # We check main_op and legacy_init_op for thoroughness and explicitness.
    for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
      if ops.get_collection(init_op_key):
        raise ValueError(
            "Graph already contains one or more main ops under the "
            "collection {}.".format(init_op_key))

    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

  def _add_train_op(self, train_op):
    """Add train op to the SavedModel.

    Note that this functionality is in development, and liable to be
    moved elsewhere.

    Args:
      train_op: Op or group of ops that are used for training. These are stored
        as a collection with key TRAIN_OP_KEY, but not executed.

    Raises:
      TypeError: If `train_op` is neither a `Tensor` nor an `Operation`.
    """
    if train_op is not None:
      if not isinstance(train_op, (ops.Tensor, ops.Operation)):
        raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
      ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_collection=None,
                     legacy_init_op=None,
                     clear_devices=False,
                     main_op=None,
                     strip_default_attrs=False,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Creates a Saver in the current scope and uses the Saver to export the meta
    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
    API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
      assets_collection: Iterable of scalar string constant `Tensor`s holding
        the paths of assets to be saved with the SavedModel. Note that these
        should be a subset of the assets saved as part of the first meta graph
        in the SavedModel.
      legacy_init_op: Deprecated. Used in place of `main_op` only when
        `main_op` is None.
      clear_devices: Set to true if the device info on the default graph should
        be cleared.
      main_op: An `Operation` to run as part of graph initialization. It is
        added to the graph collection keyed by `MAIN_OP_KEY`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
        yet, or if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a key reserved for the init or
        train op.
      TypeError: If the resolved main op is not an `Operation`.
      ValueError: If the graph already contains a main op or legacy init op.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None,
                                   legacy_init_op=None,
                                   clear_devices=False,
                                   main_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    # pylint: disable=line-too-long
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
      assets_collection: Iterable of scalar string constant `Tensor`s holding
        the paths of assets to be saved with the SavedModel.
      legacy_init_op: Deprecated. Used in place of `main_op` only when
        `main_op` is None.
      clear_devices: Set to true if the device info on the default graph should
        be cleared.
      main_op: An `Operation` to run as part of graph initialization. It is
        added to the graph collection keyed by `MAIN_OP_KEY`.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      saver: An instance of tf.compat.v1.train.Saver that will be used to export the
        metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved by this builder, or
        if a TensorInfo in `signature_def_map` is not valid.
      KeyError: If `signature_def_map` uses a key reserved for the init or
        train op.
      TypeError: If the resolved main op is not an `Operation`.
      ValueError: If the graph already contains a main op or legacy init op.
    """
    # pylint: enable=line-too-long
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless. Use an
    # explicit None check (as add_meta_graph does) rather than truthiness.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True


def _maybe_save_assets(write_fn, assets_to_add=None):
  """Saves assets to the meta graph.

  Args:
    write_fn: A function callback that writes assets into meta graph.
    assets_to_add: The list where the asset paths are setup.

  Returns:
    A dict of asset basenames for saving to the original full path to the asset.

  Raises:
    ValueError: Indicating an invalid filepath tensor.
    TypeError: If an asset is not a scalar string constant `Tensor`.
  """
  # Map of target file names to original filenames
  asset_filename_map = {}

  if assets_to_add is None:
    tf_logging.info("No assets to save.")
    return asset_filename_map

  # Iterate over the supplied assets, build the `AssetFile` proto and add them
  # to the meta graph.
  for asset_tensor in assets_to_add:
    asset_source_filepath = _asset_path_from_tensor(asset_tensor)
    if not asset_source_filepath:
      raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)

    asset_filename = get_asset_filename_to_add(
        asset_source_filepath, asset_filename_map)

    # Call the passed-in function that builds AssetFileDef proto and adds it
    # to either the collection or asset_file_def field of the meta graph.
    # Note that this should be done even when the file is a duplicate of an
    # already-added file, as the tensor reference should still exist.
    write_fn(asset_filename, asset_tensor)

    # In the cases where we are adding a duplicate, this will result in the
    # last of the filepaths being the one used for copying the file to the
    # SavedModel. Since the files in question are the same, it doesn't matter
    # either way.
    asset_filename_map[asset_filename] = asset_source_filepath

  tf_logging.info("Assets added to graph.")
  return asset_filename_map


def get_asset_filename_to_add(asset_filepath, asset_filename_map):
  """Get a unique basename to add to the SavedModel if this file is unseen.

  Assets come from users as full paths, and we save them out to the
  SavedModel as basenames. In some cases, the basenames collide. Here,
  we dedupe asset basenames by first checking if the file is the same,
  and, if different, generate and return an index-suffixed basename
  that can be used to add the asset to the SavedModel.

  Args:
    asset_filepath: the full path to the asset that is being saved
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.

  Returns:
    Uniquified filename string if the file is not a duplicate, or the original
    filename if the file has already been seen and saved.
  """
  asset_filename = os.path.basename(asset_filepath)

  if asset_filename not in asset_filename_map:
    # This is an unseen asset. Safe to add.
    return asset_filename

  other_asset_filepath = asset_filename_map[asset_filename]
  if other_asset_filepath == asset_filepath:
    # This is the same file, stored twice in the list. No need
    # to make unique.
    return asset_filename

  # Else, asset_filename is in the map, and the filepath is different. Dedupe.
  if not file_io.filecmp(asset_filepath, other_asset_filepath):
    # Files are different; dedupe filenames.
    return _get_unique_asset_filename(asset_filename, asset_filename_map)

  # Files are the same; don't make unique.
  return asset_filename


def _get_unique_asset_filename(asset_filename, asset_filename_map):
  """Returns `asset_filename` suffixed with `_<i>` so it is not in the map.

  Tries i = 1, 2, ... and returns the first `b"<asset_filename>_<i>"` (bytes)
  that is not already a key of `asset_filename_map`. If `asset_filename` itself
  is not a key, it is returned unchanged.
  """
  i = 1
  unique_filename = asset_filename
  while unique_filename in asset_filename_map:
    unique_filename = compat.as_bytes("_").join(
        [compat.as_bytes(asset_filename), compat.as_bytes(str(i))])
    i += 1
  return unique_filename


def _asset_path_from_tensor(path_tensor):
  """Returns the filepath value stored in constant `path_tensor`.

  Args:
    path_tensor: Tensor of a file-path.

  Returns:
    The string value i.e. path of the tensor, if valid.

  Raises:
    TypeError: If the tensor is not a `Const` op of dtype string holding
      exactly one value.
  """
  if not isinstance(path_tensor, ops.Tensor):
    raise TypeError("Asset path tensor must be a Tensor.")
  if path_tensor.op.type != "Const":
    raise TypeError("Asset path tensor must be of type constant.")
  if path_tensor.dtype != dtypes.string:
    raise TypeError("Asset path tensor must be of dtype string.")
  str_values = path_tensor.op.get_attr("value").string_val
  if len(str_values) != 1:
    raise TypeError("Asset path tensor must be a scalar.")
  return str_values[0]


def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
  """Builds an asset proto and adds it to the meta graph def.

  Args:
    meta_graph_def: The meta graph def to which the asset will be added.
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor used to populate the tensor info of the asset
      proto.
  """
  asset_proto = meta_graph_def.asset_file_def.add()
  asset_proto.filename = asset_filename
  asset_proto.tensor_info.name = asset_tensor.name


def copy_assets_to_destination_dir(asset_filename_map, destination_dir):
  """Copy all assets from source path to destination path.

  Args:
    asset_filename_map: Dict mapping destination basenames to the source paths
      of the asset files.
    destination_dir: SavedModel export directory; files are copied into its
      assets subdirectory, which is created if needed.
  """
  assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
      destination_dir)

  # Copy each asset from source path to destination path.
  for asset_basename, asset_source_filepath in asset_filename_map.items():
    asset_destination_filepath = os.path.join(
        compat.as_bytes(assets_destination_dir),
        compat.as_bytes(asset_basename))

    # Only copy the asset file to the destination if it does not already
    # exist. This is to ensure that an asset with the same name defined as
    # part of multiple graphs is only copied the first time.
    if not file_io.file_exists(asset_destination_filepath):
      file_io.copy(asset_source_filepath, asset_destination_filepath)

  tf_logging.info("Assets written to: %s",
                  compat.as_text(assets_destination_dir))


def _add_asset_to_collection(asset_filename, asset_tensor):
  """Builds an asset proto and adds it to the asset collection of the graph.

  Args:
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor used to populate the tensor info of the
      asset proto.
  """
  asset_proto = meta_graph_pb2.AssetFileDef()
  asset_proto.filename = asset_filename
  asset_proto.tensor_info.name = asset_tensor.name

  # Graph collections store protos wrapped in `Any`.
  asset_any_proto = Any()
  asset_any_proto.Pack(asset_proto)
  ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)


def _add_op_to_signature_def_map(signature_def_map, op, key):
  """Stores an op SignatureDef for `op` under `key`, mutating the map.

  Does nothing when `op` is None.
  """
  if op is not None:
    signature_def_map[key] = signature_def_utils.op_signature_def(op, key)