1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""SavedModel builder implementation.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import os 23 24from google.protobuf.any_pb2 import Any 25 26from tensorflow.core.framework import types_pb2 27from tensorflow.core.protobuf import meta_graph_pb2 28from tensorflow.core.protobuf import saved_model_pb2 29from tensorflow.core.protobuf import saver_pb2 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.lib.io import file_io 33from tensorflow.python.ops import variables 34from tensorflow.python.platform import tf_logging 35from tensorflow.python.saved_model import constants 36from tensorflow.python.saved_model import signature_def_utils 37from tensorflow.python.saved_model import utils_impl as saved_model_utils 38from tensorflow.python.training import saver as tf_saver 39from tensorflow.python.util import compat 40from tensorflow.python.util.deprecation import deprecated_args 41from tensorflow.python.util.tf_export import tf_export 42 43 44# Base class for the SavedModelBuilder that is only used by Tensorflow 45# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead. 
class _SavedModelBuilder(object):
  """Builds the `SavedModel` protocol buffer and saves variables and assets.

  The `SavedModelBuilder` class provides the functionality to build a
  `SavedModel` protocol buffer. Specifically, this allows multiple meta
  graphs to be saved as part of a single language-neutral `SavedModel`,
  while sharing variables and assets.

  To build a SavedModel, the first meta graph must be saved with variables.
  Subsequent meta graphs will simply be saved with their graph definitions. If
  assets need to be saved and written or copied to disk, they can be provided
  when the meta graph def is added. If multiple meta graph defs are associated
  with an asset of the same name, only the first version is retained.

  Each meta graph added to the SavedModel must be annotated with tags. The tags
  provide a means to identify the specific meta graph to load and restore, along
  with the shared set of variables and assets.

  Typical usage for the `SavedModelBuilder`:

  ```python
  ...
  builder = tf.compat.v1.saved_model.Builder(export_dir)

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph_and_variables(sess,
                                    ["foo-tag"],
                                    signature_def_map=foo_signatures,
                                    assets_list=foo_assets)
  ...

  with tf.compat.v1.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph(["bar-tag", "baz-tag"])
  ...

  builder.save()
  ```

  Note: This function will only be available through the v1 compatibility
  library as tf.compat.v1.saved_model.builder.SavedModelBuilder or
  tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new
  object-based method of creating SavedModels.
  """

  def __init__(self, export_dir):
    """Creates a builder that writes a SavedModel into `export_dir`.

    Args:
      export_dir: Directory in which the SavedModel will be written. It is
        created (recursively) if it does not exist yet.

    Raises:
      AssertionError: If `export_dir` already exists and is not empty.
    """
    self._saved_model = saved_model_pb2.SavedModel()
    self._saved_model.saved_model_schema_version = (
        constants.SAVED_MODEL_SCHEMA_VERSION)

    self._export_dir = export_dir
    if file_io.file_exists(export_dir):
      if file_io.list_directory(export_dir):
        raise AssertionError(
            "Export directory already exists, and isn't empty. Please choose "
            "a different export directory, or delete all the contents of the "
            "specified directory: %s" % export_dir)
    else:
      file_io.recursive_create_dir(self._export_dir)

    # Boolean to track whether variables and assets corresponding to the
    # SavedModel have been saved. Specifically, the first meta graph to be added
    # MUST use the add_meta_graph_and_variables() API. Subsequent add operations
    # on the SavedModel MUST use the add_meta_graph() API which does not save
    # weights.
    self._has_saved_variables = False

  def _save_and_write_assets(self, meta_graph_def, assets_list=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      meta_graph_def: The meta graph def to which the assets will be added.
      assets_list: The list where the asset paths are setup.
    """
    # Creates a function that adds assets into the meta graph def.
    write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
    asset_filename_map = _maybe_save_assets(write_fn, assets_list)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
    """Tags the meta graph def and adds it to the SavedModel.

    Tags the meta graph def with the supplied tags, adds signature defs to it if
    provided and appends the meta graph def to the SavedModel proto.

    Args:
      meta_graph_def: The meta graph def to add to the SavedModel.
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
        def.
    """
    for tag in tags:
      meta_graph_def.meta_info_def.tags.append(tag)

    if signature_def_map is not None:
      for key in signature_def_map:
        meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])

    proto_meta_graph_def = self._saved_model.meta_graphs.add()
    proto_meta_graph_def.CopyFrom(meta_graph_def)

  def _validate_tensor_info(self, tensor_info):
    """Validates the `TensorInfo` proto.

    Checks if the `encoding` (`name` or `coo_sparse` or `type_spec`) and
    `dtype` fields exist and are non-empty.

    Args:
      tensor_info: `TensorInfo` protocol buffer to validate.

    Raises:
      AssertionError: If the `encoding` or `dtype` fields of the supplied
          `TensorInfo` proto are not populated.
    """
    if tensor_info is None:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the name "
          "and dtype fields set.")
    if tensor_info.WhichOneof("encoding") is None:
      # TODO(soergel) validate each of the fields of coo_sparse
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have one of "
          "the 'encoding' fields (e.g., name or coo_sparse) set: %s"
          % tensor_info)
    if tensor_info.WhichOneof("encoding") == "composite_tensor":
      # Composite tensors carry no dtype of their own; validate each component.
      for component in tensor_info.composite_tensor.components:
        self._validate_tensor_info(component)
    elif tensor_info.dtype == types_pb2.DT_INVALID:
      raise AssertionError(
          "All TensorInfo protos used in the SignatureDefs must have the dtype "
          "field set: %s" % tensor_info)

  def _validate_signature_def_map(self, signature_def_map):
    """Validates the `SignatureDef` entries in the signature def map.

    Validation of entries in the signature def map includes ensuring that the
    `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
    `outputs` of each `SignatureDef` are populated. Also ensures that reserved
    SignatureDef keys for the initialization and train ops are not used.

    Args:
      signature_def_map: The map of signature defs to be validated.

    Raises:
      AssertionError: If a TensorInfo is not valid.
      KeyError: If a reserved signature key is used in the map.
    """
    for signature_def_key in signature_def_map:
      signature_def = signature_def_map[signature_def_key]
      inputs = signature_def.inputs
      outputs = signature_def.outputs
      for inputs_key in inputs:
        self._validate_tensor_info(inputs[inputs_key])
      for outputs_key in outputs:
        self._validate_tensor_info(outputs[outputs_key])
    if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          "SignatureDef map key \"{}\" is reserved for initialization. Please "
          "use a different key.".format(constants.INIT_OP_SIGNATURE_KEY))
    if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
      raise KeyError(
          "SignatureDef map key \"{}\" is reserved for the train op. Please "
          "use a different key.".format(constants.TRAIN_OP_SIGNATURE_KEY))

  def _maybe_create_saver(self, saver=None):
    """Creates a sharded saver if one does not already exist.

    Args:
      saver: An existing Saver, or None.

    Returns:
      `saver` if it was provided, otherwise a new sharded V2 Saver over all
      saveable objects in the current scope (empty graphs are allowed).
    """
    if not saver:
      # Initialize a saver to generate a sharded output for all saveables in the
      # current scope.
      saver = tf_saver.Saver(
          variables._all_saveable_objects(),  # pylint: disable=protected-access
          sharded=True,
          write_version=saver_pb2.SaverDef.V2,
          allow_empty=True)
    return saver

  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_list=None,
                     clear_devices=False,
                     init_op=None,
                     train_op=None,
                     saver=None):
    """Adds the current meta graph to the SavedModel.

    Creates a Saver in the current scope and uses the Saver to export the meta
    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
    API to have been invoked before.

    Args:
      tags: The set of tags to annotate the meta graph def with.
      signature_def_map: The map of signature defs to be added to the meta graph
          def.
      assets_list: Assets to be saved with SavedModel. Note
          that this list should be a subset of the assets saved as part of
          the first meta graph in the SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      saver: An instance of tf.compat.v1.train.Saver that will be used to export
        the metagraph. If None, a sharded Saver that restores all variables will
        be used.

    Raises:
      AssertionError: If the variables for the SavedModel have not been saved
          yet.
    """
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=True)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_list=None,
                                   clear_devices=False,
                                   init_op=None,
                                   train_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    # pylint: disable=line-too-long
    """Adds the current meta graph to the SavedModel and saves variables.

    Creates a Saver to save the variables from the provided session. Exports the
    corresponding meta graph def. This function assumes that the variables to be
    saved have been initialized. For a given `SavedModelBuilder`, this API must
    be called exactly once and for the first meta graph to save. For subsequent
    meta graph defs to be added, the `add_meta_graph()` API must be used.

    Args:
      sess: The TensorFlow session from which to save the meta graph and
        variables.
      tags: The set of tags with which to save the meta graph.
      signature_def_map: The map of signature def map to add to the meta graph
        def.
      assets_list: Assets to be saved with SavedModel.
      clear_devices: Set to true if the device info on the default graph should
          be cleared.
      init_op: Op or group of ops to execute when the graph is loaded. Note
          that when the init_op is specified it is run after the restore op at
          load-time.
      train_op: Op or group of ops that trains the model when run. This will
        not be run automatically when the graph is loaded, instead saved in
        a SignatureDef accessible through the exported MetaGraph.
      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
        removed from the NodeDefs. For a detailed guide, see
        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
      saver: An instance of tf.compat.v1.train.Saver that will be used to export the
        metagraph and save variables. If None, a sharded Saver that restores
        all variables will be used.

    Raises:
      AssertionError: If variables have already been saved by a previous call
          to this method on the same builder.
    """
    # pylint: enable=line-too-long
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # Create a SignatureDef pointing to the graph initialization op, which will
    # be added to the MetaGraphDef.
    _add_op_to_signature_def_map(signature_def_map, init_op,
                                 constants.INIT_OP_SIGNATURE_KEY)
    _add_op_to_signature_def_map(signature_def_map, train_op,
                                 constants.TRAIN_OP_SIGNATURE_KEY)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(meta_graph_def, assets_list)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  def save(self, as_text=False):
    """Writes a `SavedModel` protocol buffer to disk.

    The function writes the SavedModel protocol buffer to the export directory
    in a serialized format.

    Args:
      as_text: Writes the SavedModel protocol buffer in text format to
        disk. Protocol buffers in text format are useful for debugging, but
        parsing fails when it encounters an unknown field and so is not forward
        compatible. This means changes to TensorFlow may prevent deployment of
        new text format SavedModels to existing serving binaries. Do not deploy
        `as_text` SavedModels to production.

    Returns:
      The path to which the SavedModel protocol buffer was written.
    """
    if not file_io.file_exists(self._export_dir):
      file_io.recursive_create_dir(self._export_dir)

    if as_text:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
      file_io.write_string_to_file(path, str(self._saved_model))
    else:
      path = os.path.join(
          compat.as_bytes(self._export_dir),
          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
      file_io.write_string_to_file(
          path, self._saved_model.SerializeToString(deterministic=True))
    tf_logging.info("SavedModel written to: %s", compat.as_text(path))

    return path


@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])  # pylint: disable=missing-docstring
class SavedModelBuilder(_SavedModelBuilder):
  __doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
                                               "assets_collection")

  def __init__(self, export_dir):
    super(SavedModelBuilder, self).__init__(export_dir=export_dir)

  def _add_collections(self, assets_collection, main_op, train_op):
    """Add asset and op collections to be saved."""
    # Save asset files and write them to disk, if any.
    self._save_and_write_assets(assets_collection)

    self._maybe_add_main_op(main_op)

    self._add_train_op(train_op)

  def _save_and_write_assets(self, assets_collection_to_add=None):
    """Saves asset to the meta graph and writes asset files to disk.

    Args:
      assets_collection_to_add: The collection where the asset paths are setup.
    """
    # Add assets to the collection with key `saved_model.ASSETS_KEY`, in the
    # graph.
    asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
                                            assets_collection_to_add)

    # Return if there are no assets to write.
    if not asset_filename_map:
      tf_logging.info("No assets to write.")
      return

    # Copy assets from source path to destination path.
    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)

  def _maybe_add_main_op(self, main_op):
    """Adds main op to the SavedModel.

    Args:
      main_op: Main op to run as part of graph initialization. If None, no main
        op will be added to the graph.

    Raises:
      TypeError: If the main op is provided but is not of type `Operation`.
      ValueError: if the Graph already contains an init op.
    """
    if main_op is None:
      return

    if not isinstance(main_op, ops.Operation):
      raise TypeError("main_op needs to be an Operation: %r" % main_op)

    # Validate that no other init ops have been added to this graph already.
    # We check main_op and legacy_init_op for thoroughness and explicitness.
    for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
      if ops.get_collection(init_op_key):
        raise ValueError(
            "Graph already contains one or more main ops under the "
            "collection {}.".format(init_op_key))

    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)

  def _add_train_op(self, train_op):
    """Add train op to the SavedModel.

    Note that this functionality is in development, and liable to be
    moved elsewhere.

    Args:
      train_op: Op or group of ops that are used for training. These are stored
        as a collection with key TRAIN_OP_KEY, but not executed.

    Raises:
      TypeError: If `train_op` is provided but is neither a `Tensor` nor an
        `Operation`.
    """
    if train_op is not None:
      if (not isinstance(train_op, ops.Tensor) and
          not isinstance(train_op, ops.Operation)):
        raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
      ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph(self,
                     tags,
                     signature_def_map=None,
                     assets_collection=None,
                     legacy_init_op=None,
                     clear_devices=False,
                     main_op=None,
                     strip_default_attrs=False,
                     saver=None):
    if not self._has_saved_variables:
      raise AssertionError(
          "Graph state including variables and assets has not been saved yet. "
          "Please invoke `add_meta_graph_and_variables()` first.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saver = self._maybe_create_saver(saver)

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

  @deprecated_args(None,
                   "Pass your op to the equivalent parameter main_op instead.",
                   "legacy_init_op")
  def add_meta_graph_and_variables(self,
                                   sess,
                                   tags,
                                   signature_def_map=None,
                                   assets_collection=None,
                                   legacy_init_op=None,
                                   clear_devices=False,
                                   main_op=None,
                                   strip_default_attrs=False,
                                   saver=None):
    if self._has_saved_variables:
      raise AssertionError("Graph state including variables and assets has "
                           "already been saved. Please invoke "
                           "`add_meta_graph()` instead.")

    # Validate the signature def map to ensure all included TensorInfos are
    # properly populated.
    signature_def_map = signature_def_map or {}
    self._validate_signature_def_map(signature_def_map)

    # legacy_init_op is deprecated, and going away in TF 2.0.
    # Re-mapping to main_op, as treatment is identical regardless.
    # Use an explicit None check (as add_meta_graph does) rather than `or`, so
    # the op is never coerced to bool; a mistakenly passed Tensor then reaches
    # the intended type check in _maybe_add_main_op.
    main_op = main_op if main_op is not None else legacy_init_op

    # Add assets and ops
    self._add_collections(assets_collection, main_op, None)

    saved_model_utils.get_or_create_variables_dir(self._export_dir)
    variables_path = saved_model_utils.get_variables_path(self._export_dir)

    saver = self._maybe_create_saver(saver)

    # Save the variables. Also, disable writing the checkpoint state proto. The
    # file is not used during SavedModel loading. In addition, since a
    # SavedModel can be copied or moved, this avoids the checkpoint state to
    # become outdated.
    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)

    # Export the meta graph def.

    # The graph almost certainly previously contained at least one Saver, and
    # possibly several (e.g. one for loading a pretrained embedding, and another
    # for the model weights).  Removing the preexisting ones was the
    # motivation for the clear_extraneous_savers option, but it turns out that
    # there are edge cases where that option breaks the graph.  Until that is
    # resolved, we just leave the option set to False for now.
    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
    meta_graph_def = saver.export_meta_graph(
        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)

    # Tag the meta graph def and add it to the SavedModel.
    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)

    # Mark this instance of SavedModel as having saved variables, such that
    # subsequent attempts to save variables will fail.
    self._has_saved_variables = True

  add_meta_graph.__doc__ = _SavedModelBuilder.add_meta_graph.__doc__.replace(
      "assets_list", "assets_collection")
  add_meta_graph_and_variables.__doc__ = \
      _SavedModelBuilder.add_meta_graph_and_variables.__doc__.replace(
          "assets_list", "assets_collection")


def _maybe_save_assets(write_fn, assets_to_add=None):
  """Saves assets to the meta graph.

  Args:
    write_fn: A function callback that writes assets into meta graph.
    assets_to_add: The list where the asset paths are setup.

  Returns:
    A dict of asset basenames for saving to the original full path to the asset.

  Raises:
    ValueError: Indicating an invalid filepath tensor.
  """
  # Map of target file names to original filenames
  asset_filename_map = {}

  if assets_to_add is None:
    tf_logging.info("No assets to save.")
    return asset_filename_map

  # Iterate over the supplied assets, build the `AssetFile` proto and add them
  # to the meta graph.
  for asset_tensor in assets_to_add:
    asset_source_filepath = _asset_path_from_tensor(asset_tensor)
    if not asset_source_filepath:
      raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)

    asset_filename = get_asset_filename_to_add(
        asset_source_filepath, asset_filename_map)

    # Call the passed-in function that builds AssetFileDef proto and adds it
    # to either the collection or asset_file_def field of the meta graph.
    # Note that this should be done even when the file is a duplicate of an
    # already-added file, as the tensor reference should still exist.
    write_fn(asset_filename, asset_tensor)

    # In the cases where we are adding a duplicate, this will result in the
    # last of the filepaths being the one used for copying the file to the
    # SavedModel. Since the files in question are the same, it doesn't matter
    # either way.
    asset_filename_map[asset_filename] = asset_source_filepath

  tf_logging.info("Assets added to graph.")
  return asset_filename_map


def get_asset_filename_to_add(asset_filepath, asset_filename_map):
  """Get a unique basename to add to the SavedModel if this file is unseen.

  Assets come from users as full paths, and we save them out to the
  SavedModel as basenames. In some cases, the basenames collide. Here,
  we dedupe asset basenames by first checking if the file is the same,
  and, if different, generate and return an index-suffixed basename
  that can be used to add the asset to the SavedModel.

  Args:
    asset_filepath: the full path to the asset that is being saved
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.

  Returns:
    Uniquified filename string if the file is not a duplicate, or the filename
    already assigned to it if the file has already been seen and saved.
  """
  asset_filename = os.path.basename(asset_filepath)

  if asset_filename not in asset_filename_map:
    # This is an unseen asset. Safe to add.
    return asset_filename

  other_asset_filepath = asset_filename_map[asset_filename]
  if other_asset_filepath == asset_filepath:
    # This is the same file, stored twice in the list. No need
    # to make unique.
    return asset_filename

  # Else, asset_filename is in the map, and the filepath is different. Dedupe.
  if not file_io.filecmp(asset_filepath, other_asset_filepath):
    # Files are different. If this exact path was already given a uniquified
    # name earlier, reuse that name rather than minting a fresh suffix (and
    # copying the same file again) for every repeated occurrence of the path.
    for existing_filename, existing_filepath in asset_filename_map.items():
      if existing_filepath == asset_filepath:
        return existing_filename
    return _get_unique_asset_filename(asset_filename, asset_filename_map)

  # Files are the same; don't make unique.
  return asset_filename


def _get_unique_asset_filename(asset_filename, asset_filename_map):
  """Returns `asset_filename` suffixed with `_<i>` so it is not in the map.

  Tries `_1`, `_2`, ... in order and returns the first candidate that is not
  already a key of `asset_filename_map`. The result is always bytes.
  """
  i = 1
  unique_filename = asset_filename
  while unique_filename in asset_filename_map:
    unique_filename = compat.as_bytes("_").join(
        [compat.as_bytes(asset_filename), compat.as_bytes(str(i))])
    i += 1
  return unique_filename


def _asset_path_from_tensor(path_tensor):
  """Returns the filepath value stored in constant `path_tensor`.

  Args:
    path_tensor: Tensor of a file-path.

  Returns:
    The string value i.e. path of the tensor, if valid.

  Raises:
    TypeError if tensor does not match expected op type, dtype or value.
  """
  if not isinstance(path_tensor, ops.Tensor):
    raise TypeError("Asset path tensor must be a Tensor.")
  if path_tensor.op.type != "Const":
    raise TypeError("Asset path tensor must be of type constant.")
  if path_tensor.dtype != dtypes.string:
    raise TypeError("Asset path tensor must be of dtype string.")
  str_values = path_tensor.op.get_attr("value").string_val
  if len(str_values) != 1:
    raise TypeError("Asset path tensor must be a scalar.")
  return str_values[0]


def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
  """Builds an asset proto and adds it to the meta graph def.

  Args:
    meta_graph_def: The meta graph def to which the asset will be added.
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor used to populate the tensor info of the asset
      proto.
  """
  asset_proto = meta_graph_def.asset_file_def.add()
  asset_proto.filename = asset_filename
  asset_proto.tensor_info.name = asset_tensor.name


def copy_assets_to_destination_dir(asset_filename_map, destination_dir):
  """Copy all assets from source path to destination path."""
  assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
      destination_dir)

  # Copy each asset from source path to destination path.
  for asset_basename, asset_source_filepath in asset_filename_map.items():
    asset_destination_filepath = os.path.join(
        compat.as_bytes(assets_destination_dir),
        compat.as_bytes(asset_basename))

    # Only copy the asset file to the destination if it does not already
    # exist. This is to ensure that an asset with the same name defined as
    # part of multiple graphs is only copied the first time.
    if not file_io.file_exists(asset_destination_filepath):
      file_io.copy(asset_source_filepath, asset_destination_filepath)

  tf_logging.info("Assets written to: %s",
                  compat.as_text(assets_destination_dir))


def _add_asset_to_collection(asset_filename, asset_tensor):
  """Builds an asset proto and adds it to the asset collection of the graph.

  Args:
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor used to populate the tensor info of the
        asset proto.
  """
  asset_proto = meta_graph_pb2.AssetFileDef()
  asset_proto.filename = asset_filename
  asset_proto.tensor_info.name = asset_tensor.name

  asset_any_proto = Any()
  asset_any_proto.Pack(asset_proto)
  ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)


def _add_op_to_signature_def_map(signature_def_map, op, key):
  """Stores an op-only SignatureDef for `op` under `key`, if `op` is given.

  Args:
    signature_def_map: Dict of signature defs; mutated in place.
    op: The op to wrap, or None (in which case nothing is added).
    key: The signature key under which to store the generated SignatureDef.
  """
  if op is not None:
    signature_def_map[key] = signature_def_utils.op_signature_def(op, key)