1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A Network is a composition of Layers.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import os 23import weakref 24 25from tensorflow.python.eager import context 26from tensorflow.python.framework import ops 27from tensorflow.python.keras.engine import base_layer_utils 28from tensorflow.python.layers import base 29from tensorflow.python.ops import variable_scope 30from tensorflow.python.platform import tf_logging as logging 31from tensorflow.python.training import checkpoint_utils 32from tensorflow.python.training import saver as saver_lib 33from tensorflow.python.training import training_util 34from tensorflow.python.util import deprecation 35from tensorflow.python.util import function_utils 36 37# pylint: disable=protected-access 38# Explanation for protected-access disable: Network has lots of same-class and 39# parent-class references across different objects, and some to private 40# functions in base.py which should be reused. 41 42 43def _network_name_scope_naming(current_variable_scope): 44 """Name scope naming to match operation names to variable names. 45 46 Used in Networks and also applied to non-Network Layers which are added to 47 Networks before being built. 48 49 Args: 50 current_variable_scope: A VariableScope object. 51 Returns: 52 A name scope name. 53 """ 54 return current_variable_scope.name + "/" 55 56 57_NETWORK_DEPRECATION_MESSAGE = ( 58 "Please inherit from `tf.keras.Model`, and see its documentation for " 59 "details. `tf.keras.Model` should be a drop-in replacement for " 60 "`tfe.Network` in most cases, but note that `track_layer` is no longer " 61 "necessary or supported. Instead, `Layer` instances are tracked on " 62 "attribute assignment (see the section of `tf.keras.Model`'s documentation " 63 "on subclassing). Since the output of `track_layer` is often assigned to " 64 "an attribute anyway, most code can be ported by simply removing the " 65 "`track_layer` calls.\n\n`tf.keras.Model` works with all TensorFlow " 66 "`Layer` instances, including those from `tf.layers`, but switching to " 67 "the `tf.keras.layers` versions along with the migration to " 68 "`tf.keras.Model` is recommended, since it will preserve variable names. " 69 "Feel free to import it with an alias to avoid excess typing :)." 70) 71 72 73class Network(base.Layer): 74 """Represents the composition of a set of Layers. 75 76 *Deprecated*. Please inherit from `tf.keras.Model`, and see its documentation 77 for details. `tf.keras.Model` should be a drop-in replacement for 78 `tfe.Network` in most cases, but note that `track_layer` is no longer 79 necessary or supported. Instead, `Layer` instances are tracked on attribute 80 assignment (see the section of `tf.keras.Model`'s documentation on 81 subclassing). Since the output of `track_layer` is often assigned to an 82 attribute anyway, most code can be ported by simply removing the `track_layer` 83 calls. 84 85 `tf.keras.Model` works with all TensorFlow `Layer` instances, including those 86 from `tf.layers`, but switching to the `tf.keras.layers` versions along with 87 the migration to `tf.keras.Model` is recommended, since it will preserve 88 variable names. Feel free to import it with an alias to avoid excess typing 89 :). 90 91 `Network` implements the `Layer` interface and adds convenience methods for 92 managing sub-`Layer`s, such as listing variables. 93 94 `Layer`s (including other `Network`s) should be added via `track_layer`. They 95 can then be used when overriding the `Network.call` method: 96 97 ```python 98 class TwoLayerNetwork(tfe.Network): 99 100 def __init__(self, name): 101 super(TwoLayerNetwork, self).__init__(name=name) 102 self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,))) 103 self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,))) 104 105 def call(self, inputs): 106 return self.layer_two(self.layer_one(inputs)) 107 ``` 108 109 After constructing an object and calling the `Network`, a list of variables 110 created by tracked `Layer`s is available via `Network.variables`: 111 112 ```python 113 net = TwoLayerNetwork(name="net") 114 output = net(tf.ones([1, 8])) 115 print([v.name for v in net.variables]) 116 ``` 117 118 This example prints variable names, one kernel and one bias per 119 `tf.layers.Dense` layer: 120 121 ``` 122 ['net/dense/kernel:0', 123 'net/dense/bias:0', 124 'net/dense_1/kernel:0', 125 'net/dense_1/bias:0'] 126 ``` 127 128 These variables can be passed to a `Saver` (`tf.train.Saver`, or 129 `tf.contrib.eager.Saver` when executing eagerly) to save or restore the 130 `Network`, typically alongside a global step and `tf.train.Optimizer` 131 variables when checkpointing during training. 132 133 Note that the semantics of calling a `Network` with graph execution (i.e. not 134 executing eagerly) may change slightly in the future. Currently stateful ops 135 are pruned from the graph unless they or something that depends on them is 136 executed in a session, but this behavior is not consistent with eager 137 execution (where stateful ops are executed eagerly). `Layer`s from `tf.layers` 138 do not depend on this pruning and so will not be affected, but `Network`s 139 which rely on stateful ops being added to the graph but not executed (e.g. via 140 custom `Layer`s which manage stateful ops) may break with this change. 141 """ 142 # TODO(josh11b,ashankar,allenl): 143 # - Should 'trainable' be changeable on the Network object? 144 # - Do we allow add_variable in Network? 145 # - Detect layers used in __call__ that weren't registered with track_layer. 146 # - Convert inputs to __call__ to tensors. 147 148 @deprecation.deprecated(date=None, instructions=_NETWORK_DEPRECATION_MESSAGE) 149 def __init__(self, name=None): 150 """Configure the `Network`. 151 152 Args: 153 name: The name to use for this `Network`. If specified, it must be unique 154 in the context where this `Network` is first 155 (1) added to another `Network` (in which case it must not share a name 156 with other `Layers` added to that `Network`), or 157 (2) built/called (in which case no other 'top-level' `Network`s may 158 share this name). 159 If unspecified or None, the `Network` will be named using its class 160 name, with a number appended if necessary for uniqueness (e.g. MyNetwork 161 -> 'my_network_1'). 162 163 Raises: 164 ValueError: If `name` is not valid. Note that some naming errors will 165 instead be raised when the `Network` is called. 166 """ 167 if context.executing_eagerly(): 168 logging.warning( 169 ("** tfe.Network is deprecated and will be removed in a future " 170 "version.\n\n%s") % _NETWORK_DEPRECATION_MESSAGE) 171 if isinstance(name, variable_scope.VariableScope): 172 raise ValueError("VariableScopes are not valid Network names.") 173 if name is not None and "/" in name: 174 raise ValueError( 175 "Forward slashes ('/') are not allowed in Network names.") 176 super(Network, self).__init__(name=name) 177 self._layers = [] 178 self._sub_layer_name_uids = collections.defaultdict(int) 179 # Initially None, but set to False for networks which are first built as 180 # top-level. 181 self._first_parent = None # A weak reference to our first parent. 182 self._non_network_sublayers = [] 183 self._owned_layers = {} 184 # The scope to use if we end up without a parent. 185 self._default_parent_variable_scope = variable_scope.get_variable_scope() 186 # Hold on to the variable scope counts from init to check whether a scope 187 # with the name we want was ever created in our parent scope. Without this 188 # check we might have name collisions if the parent scope on init gets 189 # closed before build is called. 190 self._variable_scope_counts_on_init = ( 191 variable_scope.get_variable_scope_store().variable_scopes_count) 192 193 def _gather_saveables_for_checkpoint(self): 194 raise NotImplementedError( 195 "tfe.Network does not support object-based checkpointing.\n\n%s" 196 % _NETWORK_DEPRECATION_MESSAGE) 197 198 def _name_scope_name(self, current_variable_scope): 199 """Overrides Layer op naming to match variable naming.""" 200 return _network_name_scope_naming( 201 current_variable_scope=current_variable_scope) 202 203 def _init_set_name(self, name): 204 # Anonymous Networks (name=None) defer setting a final name until they are 205 # (1) added to another Network, or (2) built/called (where (2) is only used 206 # for a "top level" network). 207 # 208 # However, if we were provided an explicit name (name is not None), that 209 # will always be the final name of the Network; if it turns out not to be 210 # unique or if variable names can't be prefixed by it we will throw an 211 # error. 212 self._name = name 213 self._base_name = None 214 215 def _finalize_name(self, parent_network): 216 if not self._name: 217 # Were were not passed a name explicitly (or it was blank), so this is an 218 # anonymous Network. We make up a unique name. 219 if parent_network: 220 avoid_names = parent_network._owned_layers 221 name_uid_map = parent_network._sub_layer_name_uids 222 else: 223 name_uid_map = base_layer_utils.get_default_graph_uid_map() 224 # Figure out which names we have to avoid based on which variable scope 225 # we're nested in. 226 strip_name = self._default_parent_variable_scope.name 227 if strip_name: 228 strip_name += "/" 229 def _strip_on_init_scope(name): 230 if name.startswith(strip_name): 231 return name[len(strip_name):] 232 else: 233 return None 234 avoid_names = set( 235 _strip_on_init_scope(name) 236 for name in self._variable_scope_counts_on_init.keys() if name) 237 self._name, self._base_name = self._make_unique_name( 238 name_uid_map=name_uid_map, avoid_names=avoid_names, 239 namespace=self._default_parent_variable_scope.name, 240 zero_based=True) 241 if self._first_parent is None or (self._first_parent # False = no parent 242 and self._first_parent() is None): 243 # Save a pointer to the parent Network so that we can later check that the 244 # scope name we get is correct. 245 if not parent_network: 246 self._first_parent = parent_network 247 else: 248 self._first_parent = weakref.ref(parent_network) 249 250 def _set_scope(self, scope=None): 251 if self._scope is None: 252 if not self._first_parent: 253 first_parent = self._first_parent 254 else: 255 first_parent = self._first_parent() 256 if first_parent is None: 257 # If we were never added to another Network, or that Network has beed 258 # garbage collected before being called, then we're a top-level Network. 259 self._finalize_name( 260 # Use False to make sure the value sticks and we don't inherit a 261 # parent if we're added to a network later. 262 parent_network=False) 263 if scope is not None: 264 raise ValueError("Networks may not be created with explicit scopes.") 265 if first_parent: 266 first_parent._set_scope() 267 parent_scope = first_parent._scope 268 else: 269 parent_scope = self._default_parent_variable_scope 270 with variable_scope.variable_scope(parent_scope) as parent_vs: 271 expected_scope_name = parent_vs.name + "/" + self._name 272 if expected_scope_name in self._variable_scope_counts_on_init: 273 raise ValueError( 274 ("A Network named '%s' already exists (or a variable_scope was " 275 "created with this name). Names must be unique.") % ( 276 self._name,)) 277 # Make sure variables with this prefix will be unique. 278 with variable_scope.variable_scope( 279 None, use_resource=True, default_name=self._name) as scope: 280 self._scope = scope 281 scope_name = scope.name 282 suffix_start = scope_name.rfind("/") + 1 283 # rfind is -1 if there is no slash in the string, in which case the 284 # suffix starts at the beginning of the string (there is no prefix). 285 scope_suffix = scope_name[suffix_start:] 286 scope_prefix = scope_name[:suffix_start] 287 if scope_suffix != self._name: 288 raise ValueError( 289 ("A Network named '%s' already exists (or a variable_scope was " 290 "created with this name). Names must be unique.") % ( 291 self._name,)) 292 if (first_parent 293 and scope_prefix[:-1] != first_parent.scope_name): 294 raise ValueError( 295 ("Network variable names must match a nesting of sub-Network " 296 "names. Expected prefix '%s' from parent network, but got " 297 "'%s' when attempting to create a variable_scope for Network " 298 "'%s'. Likely an explicit variable_scope was inserted into " 299 "the nesting.") % ( 300 first_parent.scope_name, 301 scope_prefix[:-1], 302 self._name)) 303 elif not first_parent and scope_prefix: 304 # For the case when this Network is not nested inside any other 305 # Network, but is in a variable_scope. This Network's name takes on 306 # the full variable scope prefix. 307 self._name = scope_name 308 309 for non_network_sublayer in self._non_network_sublayers: 310 self._set_scope_for_nonnetwork_sublayer(non_network_sublayer) 311 312 def _set_scope_for_nonnetwork_sublayer(self, sublayer): 313 if sublayer._scope is None: 314 if sublayer._first_parent is None: 315 constituent_first_parent = None 316 else: 317 constituent_first_parent = sublayer._first_parent() 318 if constituent_first_parent: 319 constituent_first_parent._set_scope() 320 parent_scope = constituent_first_parent._scope 321 else: 322 self._finalize_name(False) 323 raise ValueError( 324 ("The parent of a Layer added to Network %s was garbage collected " 325 "before the Layer was built. If this limitation bothers you " 326 "please file a feature request.") % 327 (self.name,)) 328 with variable_scope.variable_scope(parent_scope): 329 # Horrid hack to make Layer variable names which are direct 330 # sub-layers of Networks conform to the Network variable naming 331 # conventions. 332 with variable_scope.variable_scope( 333 None, use_resource=True, 334 default_name=sublayer.name) as sub_scope: 335 sublayer._scope = sub_scope 336 # Also switch op naming for this Layer to match Network conventions, 337 # i.e. op naming matching variable naming. 338 sublayer._name_scope_name = _network_name_scope_naming 339 340 @base.Layer.name.getter 341 def name(self): 342 if self._name is None: 343 raise ValueError( 344 "The network does not yet have a final name, but a name was " 345 "requested for it. Networks get a name when they are added to " 346 "another Network via track_layer, or when they are first " 347 "called/built.") 348 return self._name 349 350 def track_layer(self, layer): 351 """Track a Layer in this Network. 352 353 `Network` requires that all `Layer`s used in `call()` be tracked so that the 354 `Network` can export a complete list of variables. 355 356 Args: 357 layer: A `tf.layers.Layer` object. 358 359 Returns: 360 The passed in `layer`. 361 362 Raises: 363 RuntimeError: If __init__ has not been called. 364 TypeError: If `layer` is the wrong type. 365 ValueError: If a `Layer` with the same name has already been added. 366 """ 367 if not hasattr(self, "_layers"): 368 raise RuntimeError("Need to call Network.__init__ before adding layers") 369 if not isinstance(layer, base.Layer): 370 raise TypeError( 371 "Network.track_layer() passed type %s, not a tf.layers.Layer" % 372 (type(layer),)) 373 # Always use `ResourceVariable` with legacy layers. 374 layer._use_resource_variables = True 375 if isinstance(layer, Network): 376 layer._finalize_name(parent_network=self) 377 else: 378 # `layer` is a non-Network, so it hasn't been named to follow Network 379 # conventions for contained Layers (i.e. the same conventions as for 380 # sub-Networks). This renaming is necessary to isolate Network variable 381 # naming from Layers constructed outside the Network and never added to it 382 # (because Layers are named globally). 383 if not layer.built: 384 if not hasattr(layer, "_first_parent"): 385 dereferenced_layer_first_parent = None 386 else: 387 dereferenced_layer_first_parent = layer._first_parent() 388 if dereferenced_layer_first_parent is None: 389 if layer._name != layer._base_name: 390 # If name and base_name do not match, then this Layer used anonymous 391 # naming and we have to rename it. Otherwise there's an explicit 392 # name, and we should respect it (subject to error checking). 393 layer._name, layer._base_name = layer._make_unique_name( 394 name_uid_map=self._sub_layer_name_uids, 395 avoid_names=self._owned_layers, 396 zero_based=True 397 # No namespace required, since we've specified our own UID map. 398 ) 399 layer._first_parent = weakref.ref(self) 400 self._non_network_sublayers.append(layer) 401 if (not layer.built 402 and layer._first_parent 403 and self is layer._first_parent()): 404 if layer.name in self._owned_layers: 405 if self._owned_layers[layer.name] is layer: 406 return layer 407 raise ValueError( 408 "Attempt to add two Layers with the name '%s' to the same Network." 409 % (layer.name)) 410 self._owned_layers[layer.name] = layer 411 self._layers.append(layer) 412 return layer 413 414 def get_layer(self, name=None, index=None): 415 """Get a contained `tf.layers.Layer` either by name or index. 416 417 Args: 418 name: String matching one of the names of a contained `Layer`. Note that 419 the names of `Layer`s added to `Network`s may not be unique when doing 420 layer sharing (i.e. adding a `Layer` to this `Network` which was already 421 added to another `Network`). The lowest index `Layer` with a matching 422 name will be returned. 423 index: Integer in [0, number of layers). Layers are assigned an index 424 by the order they are added. 425 426 Returns: 427 A `tf.layers.Layer` object. 428 429 Raises: 430 ValueError: If neither or both of 'index' or 'name' is specified, or the 431 lookup failed. 432 """ 433 if index is not None: 434 if name is not None: 435 raise ValueError("Exactly one of 'index' or 'name' must be provided") 436 if len(self._layers) <= index: 437 raise ValueError("Was asked to retrieve layer at index " + str(index) + 438 " but model only has " + str(len(self._layers)) + 439 " layers.") 440 else: 441 return self._layers[index] 442 else: 443 if not name: 444 raise ValueError("Provide either a layer name or layer index.") 445 for layer in self._layers: 446 if layer.name == name: 447 return layer 448 raise ValueError("No such layer: " + name) 449 450 # The following methods are for implementing the Layer interface. 451 452 @property 453 def weights(self): 454 # TODO(josh11b): Should this return a set or perform de-duplication of 455 # variables in the case of shared layers/variables that appear in 456 # multiple places in the Network? 457 weights = [] 458 for layer in self._layers: 459 weights += layer.weights 460 return weights 461 462 @property 463 def trainable_weights(self): 464 weights = [] 465 for layer in self._layers: 466 weights += layer.trainable_weights 467 return weights 468 469 @property 470 def non_trainable_weights(self): 471 weights = [] 472 for layer in self._layers: 473 weights += layer.non_trainable_weights 474 return weights 475 476 @property 477 def trainable(self): 478 return True 479 480 @trainable.setter 481 def trainable(self, value): 482 if not value: 483 # We believe it better to decide which layers & networks are trainable 484 # at the Trainer level than here. Otherwise you can run into trouble if a 485 # layer/network is shared between two models, but is trainable in one 486 # but not the other (like with adversarial networks). 487 raise AttributeError("cannot mark Network as not trainable") 488 489 @property 490 def layers(self): 491 return self._layers 492 493 def add_variable(self, name, shape, dtype=None, initializer=None, 494 regularizer=None, trainable=True, constraint=None): 495 raise RuntimeError( 496 "add_variable not supported in Network class yet. Please file an issue " 497 "at https://github.com/tensorflow/tensorflow/issues/new if this is " 498 "important to you") 499 500 def add_loss(self, losses, inputs=None): 501 raise RuntimeError( 502 "add_loss is not supported in Network class yet. Please file an issue " 503 "at https://github.com/tensorflow/tensorflow/issues/new if this is " 504 "important to you") 505 506 @property 507 def losses(self): 508 """Gather losses from `Layer`s in the `Network`. 509 510 Note that when executing eagerly, `Layer.losses` evaluates 511 regularizers. When using graph execution, variable regularization ops have 512 already been created and are simply returned here. 513 514 Returns: 515 A list of tensors. 516 """ 517 layer_losses = [] 518 for layer in self.layers: 519 layer_losses.extend(layer.losses) 520 return layer_losses 521 522 # TODO(allenl): Support other Layer methods needed for graph mode, such as for 523 # updates 524 525 526class Sequential(Network): 527 """Represents a linear sequence of Layers or functions. 528 529 The output of each layer/function is provided as the input to the next. 530 The inputs passed to `__call__` are passed to the inputs of the first 531 Layer, and it returns the outputs of the last Layer. 532 533 Args: 534 layers_funcs: An optional sequence where each element is either a 535 tf.layers.Layer object or a callable. 536 name: An optional string name to use for this Network. 537 """ 538 539 def __init__(self, layers_funcs=None, name=None): 540 super(Sequential, self).__init__(name=name) 541 self._layers_funcs = [] 542 if layers_funcs: 543 for l in layers_funcs: 544 self.add(l) 545 546 def add(self, layer_func): 547 if isinstance(layer_func, base.Layer): 548 args = function_utils.fn_args(layer_func.call) 549 self.track_layer(layer_func) 550 elif callable(layer_func): 551 args = function_utils.fn_args(layer_func) 552 else: 553 raise TypeError( 554 "Sequential.add() takes only tf.layers.Layer objects or callables; " 555 "not '%s' of type '%s'." % (layer_func, type(layer_func))) 556 self._layers_funcs.append((("training" in args), layer_func)) 557 558 def call(self, inputs, training=None): 559 """Call each Layer in the order they were added.""" 560 # TODO(josh11b): Support "mode" and maybe other arguments 561 if training is None: 562 for _, l in self._layers_funcs: 563 inputs = l(inputs) 564 else: 565 for has_training_arg, l in self._layers_funcs: 566 if has_training_arg: 567 inputs = l(inputs, training) 568 else: 569 inputs = l(inputs) 570 return inputs 571 572 573_DeferredRestoration = collections.namedtuple( 574 575 "_DeferredRestoration", 576 [ 577 # The map_func to use (either user-specified or the default). 578 "map_func", 579 # Boolean, True if the user specified an explicit map_func, for error 580 # messages. 581 "map_func_is_user", 582 # A mapping from checkpoint names to initial values of not-yet-created 583 # variables which should be restored. These values come from parsing a 584 # checkpoint. 585 "checkpointed_variables_to_restore", 586 # A mapping from checkpoint name to variable objects of variables which 587 # have already been restored, for error checking. 588 "restored_variables", 589 # The session to restore with (if in graph mode). 590 "session", 591 # Names of the Network where the restore was requested, for error 592 # messages. 593 "network_name", 594 "network_scope_name" 595 ]) 596 597 598def _default_naming_conflict_error_message( 599 mapped_name, first_variable, second_variable, 600 network_name, network_scope_name): 601 return ( 602 ("The default checkpoint variable name mapping strategy for Network " 603 "'%s' resulted in a naming conflict. We attempted to strip off the " 604 "variable prefix for the Network ('%s'), but this resulted in two " 605 "variables named '%s' (originally '%s' and '%s'). This should only " 606 "happen when using variable sharing (i.e. the Network contains Networks " 607 "or Layers which were first added to another Network, and therefore " 608 "have that Network's variable prefix). One solution is to pass " 609 "`map_func=lambda n: n` to save and restore to use fully qualified " 610 "variable names in the checkpoint, although this will require that the " 611 "variable prefix of the Network being restored into is also '%s'. You " 612 "may alternatively write an arbitrary mapping.") 613 % ( 614 network_name, network_scope_name, mapped_name, 615 first_variable._shared_name, 616 second_variable._shared_name, network_scope_name 617 )) 618 619 620def _restore_custom_map_func_error_message( 621 mapped_name, first_variable, second_variable, 622 network_name, network_scope_name): 623 return ( 624 ("The map_func passed to restore_network_checkpoint for the Network '%s' " 625 "resulted in two variables named '%s' (originally '%s' and '%s'). Since " 626 "this is also an error when saving, this Network was " 627 "probably not saved with this map_func. Note that map_func " 628 "always maps from full variable names to checkpoint names; " 629 "there is no need to specify an inverse mapping.\n\n" 630 "Try stripping less from the variable names, or renaming parts " 631 "of the Network. For reference, variables created by sub-Layers " 632 "of this Network are prefixed with '%s', but if they are " 633 "re-used after being added to another Network they will have " 634 "that Network's full variable prefix instead.") % ( 635 network_name, mapped_name, 636 first_variable._shared_name, 637 second_variable._shared_name, 638 network_scope_name)) 639 640 641def _make_custom_getter_for_deferred_restorations(): 642 """Returns a custom getter which searches `deferred_restorations`. 643 644 Returns: A tuple of (_custom_getter, deferred_restorations) 645 _custom_getter: The getter which should be added to variable_scopes where 646 variables will be created. 647 deferred_restorations: A list for _DeferredRestoration objects. Typically 648 empty when the getter is set, and expanded as deferred restorations are 649 requested. All new deferred restorations should be appended to the end of 650 the list, where they will have priority over older deferred restorations. 651 """ 652 deferred_restorations = [] 653 654 def _custom_getter(getter, name, shape=None, dtype=None, 655 initializer=None, 656 *args, **kwargs): 657 """A custom getter which processes deferred restorations.""" 658 # Iterate over restorations, newest first (newer restorations will take 659 # precedence over older restorations, just like with immediate restorations 660 # into existing variables). 661 delayed_restoration = None 662 found_value = False 663 value_to_restore = None 664 for delayed_restoration in reversed( 665 deferred_restorations): 666 checkpoint_name = delayed_restoration.map_func(name) 667 if (checkpoint_name 668 in delayed_restoration.checkpointed_variables_to_restore): 669 found_value = True 670 value_to_restore = ( 671 delayed_restoration.checkpointed_variables_to_restore[ 672 checkpoint_name]) 673 if found_value: 674 break 675 # value_to_restore may be False because this variable is not in any 676 # checkpoint we are restoring, or None because we have explicitly set it to 677 # None when it was previously fetched. In either case, we don't need to 678 # set an initializer. 679 if found_value and value_to_restore is not None: 680 initializer = value_to_restore 681 shape = None 682 variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, 683 *args, **kwargs) 684 if found_value and value_to_restore is not None: 685 # Mark as already restored from this checkpoint. 686 delayed_restoration.checkpointed_variables_to_restore[ 687 checkpoint_name] = None 688 if not context.executing_eagerly(): 689 delayed_restoration.session.run(variable.initializer) 690 if found_value: 691 # Error checking should run even if we've already restored a value. 692 if delayed_restoration.restored_variables.setdefault( 693 checkpoint_name, variable) is not variable: 694 # Naming conflict. We've tried to initialize two variables with the 695 # same value from the checkpoint. 696 if delayed_restoration.map_func_is_user: 697 raise ValueError( 698 _restore_custom_map_func_error_message( 699 mapped_name=checkpoint_name, 700 first_variable=delayed_restoration.restored_variables[ 701 checkpoint_name], 702 second_variable=variable, 703 network_name=delayed_restoration.network_name, 704 network_scope_name=delayed_restoration.network_scope_name)) 705 else: 706 raise ValueError( 707 _default_naming_conflict_error_message( 708 mapped_name=checkpoint_name, 709 first_variable=delayed_restoration.restored_variables[ 710 checkpoint_name], 711 second_variable=variable, 712 network_name=delayed_restoration.network_name, 713 network_scope_name=delayed_restoration.network_scope_name)) 714 return variable 715 return _custom_getter, deferred_restorations 716 717 718def _make_prefix_stripping_map_fn(scope_name): 719 """Closure for stripping the scope name of a Network. 720 721 Implemented as a closure rather than a member function to avoid reference 722 cycles in deferred restorations (this function should not have a reference to 723 the Network which created it). 724 725 Args: 726 scope_name: The Network.scope_name to strip from variables. 727 Returns: 728 A scope_name-stripping default `map_fn` for the Network. 729 """ 730 731 def _strip_variable_prefix(original_variable_name): 732 """The default map_func for saving or restoring variables. 733 734 Strips the variable prefix for the Network on which save/restore was called, 735 and leaves other variable names fully qualified in the checkpoint. 736 737 Args: 738 original_variable_name: The _shared_name of the variable (no :0 739 suffix) to map. 740 Returns: 741 The checkpoint name of the variable. 742 """ 743 scope_name_with_slash = scope_name + "/" 744 if original_variable_name.startswith(scope_name_with_slash): 745 return original_variable_name[len(scope_name_with_slash):] 746 else: 747 return original_variable_name 748 749 return _strip_variable_prefix 750 751 752@deprecation.deprecated(date=None, instructions=( 753 "Please inherit from tf.keras.Model instead of tfe.Network, and use " 754 "tf.keras.Model.save_weights.")) 755def save_network_checkpoint( 756 network, save_path, global_step=None, map_func=None): 757 """Save variables from the Network to a checkpoint. 758 759 Args: 760 network: A Network object to save. 761 save_path: Either a checkpoint prefix or the name of a directory to save 762 the checkpoint in (in which case the checkpoint will be named based on 763 the Network name). 764 global_step: The global step to use when naming the checkpoint. If None 765 (default), we will first try to get the default global step. If that 766 fails because no default global step exists, then the checkpoint is 767 created without a global step suffix. 768 map_func: A function mapping fully qualified variable names 769 (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By 770 default (if `map_func=None`), the variable prefix for the network being 771 restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped 772 and all other variable names (shared with other Networks) are left 773 unchanged. 774 Returns: 775 The checkpoint prefix for the saved checkpoint, which may be passed to 776 `Network.restore`. 777 Raises: 778 ValueError: If the Network has not yet been called, or if map_func results 779 in a name collision. 780 """ 781 if not network.built: 782 raise ValueError( 783 "Attempt to save the Network before it was first called. This means " 784 "variables have not yet been created, so there is nothing to save.") 785 network._set_scope() # scope_name should be available to map_funcs 786 if global_step is None: 787 global_step = training_util.get_global_step() 788 if os.path.isdir(save_path): 789 # If we were passed a directory, default to naming based on the Network 790 # name. 791 save_path = os.path.join(save_path, network.name.replace("/", "_")) 792 user_map_func = map_func 793 if map_func is None: 794 map_func = _make_prefix_stripping_map_fn(network.scope_name) 795 variable_map = {} 796 for variable in network.variables: 797 mapped_name = map_func(variable._shared_name) 798 if variable_map.setdefault(mapped_name, variable) is not variable: 799 if user_map_func is None: 800 # Instead of erroring out, we could just re-try and silently use the 801 # full variable names in the checkpoint. This could be odd for deeply 802 # nested sub-Networks (since the full prefix from the nesting would 803 # get added), so for now we'll let the user deal with this case. 804 raise ValueError(_default_naming_conflict_error_message( 805 mapped_name=mapped_name, 806 first_variable=variable_map[mapped_name], 807 second_variable=variable, 808 network_name=network.name, 809 network_scope_name=network.scope_name)) 810 else: 811 # The user passed their own problematic map_func. 812 raise ValueError( 813 ("The map_func passed to save_network_checkpoint for the Network " 814 "'%s' resulted in two variables named '%s' ('%s' and '%s'). Try " 815 "stripping less from the variable names, or renaming parts of " 816 "the Network. For reference, variables created by sub-Layers of " 817 "this Network are prefixed with '%s', but if they are re-used " 818 "after being added to another Network, they will have that " 819 "Network's full variable prefix instead.") % ( 820 network.name, mapped_name, 821 variable_map[mapped_name]._shared_name, 822 variable._shared_name, 823 network.scope_name)) 824 if context.executing_eagerly(): 825 sess = None 826 else: 827 sess = ops.get_default_session() 828 return saver_lib.Saver(variable_map).save( 829 sess=sess, save_path=save_path, write_meta_graph=False, 830 global_step=global_step) 831 832 833def _add_deferred_restoration(layer, deferred_restoration): 834 """Add a deferred restoration to this Layer and all children. 835 836 Restorations which are requested later have higher priority, and the highest 837 priority matching restoration is applied to a variable when it is created. 838 839 Args: 840 layer: The Layer (may not be a Network) to operate on. 841 deferred_restoration: A _DeferredRestoration object. 842 """ 843 # Networks don't create variables at the moment, so this append isn't strictly 844 # necessary. We could get by with only adding deferred restorations to 845 # non-Network Layers. 846 if isinstance(layer, Network): 847 layer._set_scope() 848 # Make sure this Layer has a deferred restoration queue and a custom getter, 849 # then add our request to it. 850 if not hasattr(layer, "_custom_getter"): 851 assert not hasattr(layer, "_deferred_restorations") 852 layer._custom_getter, layer._deferred_restorations = ( 853 _make_custom_getter_for_deferred_restorations()) 854 # We use set_custom_getter because it avoids recursively calling up the 855 # variable_scope tree. We've done the tree traversal ourselves and have added 856 # the request to each Layer which needs it. 857 layer._scope.set_custom_getter(layer._custom_getter) 858 layer._deferred_restorations.append(deferred_restoration) 859 if isinstance(layer, Network): 860 for sublayer in layer.layers: 861 if not isinstance(sublayer, Network): 862 layer._set_scope_for_nonnetwork_sublayer(sublayer) 863 _add_deferred_restoration(sublayer, deferred_restoration) 864 865 866def _restore_existing_variables(network, save_path, map_func, user_map_func): 867 """Use a standard Saver to restore existing variables from a checkpoint. 868 869 Args: 870 network: A Network object to restore. 871 save_path: The checkpoint prefix or directory to read from. 872 map_func: The function to use when mapping from variable names to 873 checkpoint names. 874 user_map_func: The original map_func passed by the user, for error 875 checking. 876 Returns: 877 A dictionary mapping from checkpoint names to variable objects which have 878 been restored (for bookkeeping to avoid deferred restorations on these 879 variables). 880 Raises: 881 ValueError: If there is a name collision. 882 """ 883 existing_variables_by_checkpoint_name = {} 884 for variable in network.variables: 885 checkpoint_name = map_func(variable._shared_name) 886 if existing_variables_by_checkpoint_name.setdefault( 887 checkpoint_name, variable) is not variable: 888 if user_map_func is None: 889 raise ValueError(_default_naming_conflict_error_message( 890 mapped_name=checkpoint_name, 891 first_variable=existing_variables_by_checkpoint_name[ 892 checkpoint_name], 893 second_variable=variable, 894 network_name=network.name, 895 network_scope_name=network.scope_name)) 896 else: 897 raise ValueError(_restore_custom_map_func_error_message( 898 mapped_name=checkpoint_name, 899 first_variable=existing_variables_by_checkpoint_name[ 900 checkpoint_name], 901 second_variable=variable, 902 network_name=network.name, 903 network_scope_name=network.scope_name)) 904 if existing_variables_by_checkpoint_name: 905 if context.executing_eagerly(): 906 sess = None 907 else: 908 sess = ops.get_default_session() 909 saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( 910 sess=sess, save_path=save_path) 911 return existing_variables_by_checkpoint_name 912 913 914def _set_restore_on_create(network, save_path, map_func, user_map_func, 915 existing_variables_by_checkpoint_name): 916 """If necessary, request deferred restorations of variables.""" 917 checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) 918 checkpointed_variables_to_restore = {} 919 for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): 920 if checkpoint_name in existing_variables_by_checkpoint_name: 921 # This variable was already created and restored. 922 continue 923 # Save the variable for later restoration in a custom getter. 924 checkpointed_variables_to_restore[checkpoint_name] = ( 925 checkpoint_reader.get_tensor(checkpoint_name)) 926 # Only set a deferred restoration if there are checkpoint variables which 927 # have not been assigned to existing variables. Note that this loses out on 928 # some opportunity for error checking, but avoids creating 929 # _DeferredRestoration objects once a Network has been built (so that 930 # restoring in a loop does not take increasing amounts of memory). 931 if checkpointed_variables_to_restore: 932 if context.executing_eagerly(): 933 sess = None 934 else: 935 sess = ops.get_default_session() 936 # We need a name for error messages. If we haven't been added to another 937 # Network yet, we're top-level. 938 network._finalize_name(False) 939 network._set_scope() 940 # Save a record of this restoration for use in the custom getter. 941 deferred_restoration = _DeferredRestoration( 942 map_func=map_func, 943 map_func_is_user=(user_map_func is not None), 944 checkpointed_variables_to_restore=checkpointed_variables_to_restore, 945 restored_variables={}, 946 session=sess, 947 network_name=network.name, 948 network_scope_name=network.scope_name) 949 # Add the deferred registration to non-Network children, and request that 950 # Networks propagate the request to their children. 951 _add_deferred_restoration(network, deferred_restoration) 952 953 954@deprecation.deprecated(date=None, instructions=( 955 "Please inherit from tf.keras.Model instead of tfe.Network, and use " 956 "tf.keras.Model.load_weights.")) 957def restore_network_checkpoint(network, save_path, map_func=None): 958 """Restore the Network from a checkpoint. 959 960 If variables have already been created (typically when some or all of the 961 `Network` is built), they are assigned values from the checkpoint immediately, 962 overwriting any existing values (in graph mode the default session is used for 963 the assignments). 964 965 If there are checkpoint entries which do not correspond to any existing 966 variables in the `Network`, these values are saved for deferred restoration; 967 their initial values will be the checkpointed values once they are 968 created. Requests for multiple deferred restorations behave the same way as 969 immediate restorations, in that later requests will take priority over earlier 970 requests relevant to the same variable. 971 972 If this `Network` shares `Layer`s with another network, those `Layer`s will 973 also have their variables restored from the checkpoint. 974 975 Args: 976 network: A Network object to restore. 977 save_path: The return value of `tfe.save_network_checkpoint`, or a directory 978 to search for a checkpoint. 979 map_func: A function mapping fully qualified variable names 980 (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By 981 default (if `map_func=None`), the variable prefix for the network being 982 restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped 983 and all other variable names (shared with other Networks) are left 984 unchanged. Note that this is the _same_ map_func as 985 `tfe.save_network_checkpoint`, not an inverse mapping. 986 """ 987 network._finalize_name(parent_network=False) 988 network._set_scope() # scope_name should be available to map_funcs 989 if os.path.isdir(save_path): 990 # If we don't have a name yet, set no parent. 991 save_path = os.path.join(save_path, network.name.replace("/", "_")) 992 user_map_func = map_func 993 if map_func is None: 994 map_func = _make_prefix_stripping_map_fn(network.scope_name) 995 # Step one is to restore any existing variables from the checkpoint. 996 existing_variables_by_checkpoint_name = _restore_existing_variables( 997 network=network, 998 save_path=save_path, 999 map_func=map_func, 1000 user_map_func=user_map_func) 1001 # Step two is to set a custom getter which restores variables on creation, 1002 # for those variables which have not been added to sub-Layers yet. 1003 _set_restore_on_create( 1004 network=network, 1005 save_path=save_path, 1006 map_func=map_func, 1007 user_map_func=user_map_func, 1008 existing_variables_by_checkpoint_name=( 1009 existing_variables_by_checkpoint_name)) 1010