# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Network is a composition of Layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import weakref

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.layers import base
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils

# pylint: disable=protected-access
# Explanation for protected-access disable: Network has lots of same-class and
# parent-class references across different objects, and some to private
# functions in base.py which should be reused.


def _network_name_scope_naming(current_variable_scope):
  """Name scope naming to match operation names to variable names.

  Used in Networks and also applied to non-Network Layers which are added to
  Networks before being built.

  Args:
    current_variable_scope: A VariableScope object.
  Returns:
    A name scope name.
  """
  return current_variable_scope.name + "/"


_NETWORK_DEPRECATION_MESSAGE = (
    "Please inherit from `tf.keras.Model`, and see its documentation for "
    "details. `tf.keras.Model` should be a drop-in replacement for "
    "`tfe.Network` in most cases, but note that `track_layer` is no longer "
    "necessary or supported. Instead, `Layer` instances are tracked on "
    "attribute assignment (see the section of `tf.keras.Model`'s documentation "
    "on subclassing). Since the output of `track_layer` is often assigned to "
    "an attribute anyway, most code can be ported by simply removing the "
    "`track_layer` calls.\n\n`tf.keras.Model` works with all TensorFlow "
    "`Layer` instances, including those from `tf.layers`, but switching to "
    "the `tf.keras.layers` versions along with the migration to "
    "`tf.keras.Model` is recommended, since it will preserve variable names. "
    "Feel free to import it with an alias to avoid excess typing :)."
)


class Network(base.Layer):
  """Represents the composition of a set of Layers.

  *Deprecated*. Please inherit from `tf.keras.Model`, and see its documentation
  for details. `tf.keras.Model` should be a drop-in replacement for
  `tfe.Network` in most cases, but note that `track_layer` is no longer
  necessary or supported. Instead, `Layer` instances are tracked on attribute
  assignment (see the section of `tf.keras.Model`'s documentation on
  subclassing). Since the output of `track_layer` is often assigned to an
  attribute anyway, most code can be ported by simply removing the
  `track_layer` calls.

  `tf.keras.Model` works with all TensorFlow `Layer` instances, including those
  from `tf.layers`, but switching to the `tf.keras.layers` versions along with
  the migration to `tf.keras.Model` is recommended, since it will preserve
  variable names. Feel free to import it with an alias to avoid excess typing
  :).

  `Network` implements the `Layer` interface and adds convenience methods for
  managing sub-`Layer`s, such as listing variables.

  `Layer`s (including other `Network`s) should be added via `track_layer`. They
  can then be used when overriding the `Network.call` method:

  ```python
  class TwoLayerNetwork(tfe.Network):

    def __init__(self, name):
      super(TwoLayerNetwork, self).__init__(name=name)
      self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,)))
      self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,)))

    def call(self, inputs):
      return self.layer_two(self.layer_one(inputs))
  ```

  After constructing an object and calling the `Network`, a list of variables
  created by tracked `Layer`s is available via `Network.variables`:

  ```python
  net = TwoLayerNetwork(name="net")
  output = net(tf.ones([1, 8]))
  print([v.name for v in net.variables])
  ```

  This example prints variable names, one kernel and one bias per
  `tf.layers.Dense` layer:

  ```
  ['net/dense/kernel:0',
   'net/dense/bias:0',
   'net/dense_1/kernel:0',
   'net/dense_1/bias:0']
  ```

  These variables can be passed to a `Saver` (`tf.train.Saver`, or
  `tf.contrib.eager.Saver` when executing eagerly) to save or restore the
  `Network`, typically alongside a global step and `tf.train.Optimizer`
  variables when checkpointing during training.
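
  For example (a sketch; assumes eager execution, that `tfe` aliases
  `tf.contrib.eager`, and a hypothetical checkpoint path):

  ```python
  saver = tfe.Saver(net.variables)
  prefix = saver.save("/tmp/net-ckpt")  # Write the variable values.
  saver.restore(prefix)                 # Later, read them back.
  ```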

  Note that the semantics of calling a `Network` with graph execution (i.e. not
  executing eagerly) may change slightly in the future. Currently stateful ops
  are pruned from the graph unless they or something that depends on them is
  executed in a session, but this behavior is not consistent with eager
  execution (where stateful ops are executed eagerly). `Layer`s from `tf.layers`
  do not depend on this pruning and so will not be affected, but `Network`s
  which rely on stateful ops being added to the graph but not executed (e.g. via
  custom `Layer`s which manage stateful ops) may break with this change.
  """
  # TODO(josh11b,ashankar,allenl):
  # - Should 'trainable' be changeable on the Network object?
  # - Do we allow add_variable in Network?
  # - Detect layers used in __call__ that weren't registered with track_layer.
  # - Convert inputs to __call__ to tensors.

  @deprecation.deprecated(date=None, instructions=_NETWORK_DEPRECATION_MESSAGE)
  def __init__(self, name=None):
    """Configure the `Network`.

    Args:
      name: The name to use for this `Network`. If specified, it must be unique
        in the context where this `Network` is first
          (1) added to another `Network` (in which case it must not share a
            name with other `Layer`s added to that `Network`), or
          (2) built/called (in which case no other 'top-level' `Network`s may
            share this name).
        If unspecified or None, the `Network` will be named using its class
        name, with a number appended if necessary for uniqueness (e.g. MyNetwork
        -> 'my_network_1').

    Raises:
      ValueError: If `name` is not valid. Note that some naming errors will
        instead be raised when the `Network` is called.
    """
    if context.executing_eagerly():
      logging.warning(
          ("** tfe.Network is deprecated and will be removed in a future "
           "version.\n\n%s") % _NETWORK_DEPRECATION_MESSAGE)
    if isinstance(name, variable_scope.VariableScope):
      raise ValueError("VariableScopes are not valid Network names.")
    if name is not None and "/" in name:
      raise ValueError(
          "Forward slashes ('/') are not allowed in Network names.")
    super(Network, self).__init__(name=name)
    self._layers = []
    self._sub_layer_name_uids = collections.defaultdict(int)
    # Initially None, but set to False for networks which are first built as
    # top-level.
    self._first_parent = None  # A weak reference to our first parent.
    self._non_network_sublayers = []
    self._owned_layers = {}
    # The scope to use if we end up without a parent.
    self._default_parent_variable_scope = variable_scope.get_variable_scope()
    # Hold on to the variable scope counts from init to check whether a scope
    # with the name we want was ever created in our parent scope. Without this
    # check we might have name collisions if the parent scope on init gets
    # closed before build is called.
    self._variable_scope_counts_on_init = (
        variable_scope.get_variable_scope_store().variable_scopes_count)

  def _gather_saveables_for_checkpoint(self):
    raise NotImplementedError(
        "tfe.Network does not support object-based checkpointing.\n\n%s"
        % _NETWORK_DEPRECATION_MESSAGE)

  def _name_scope_name(self, current_variable_scope):
    """Overrides Layer op naming to match variable naming."""
    return _network_name_scope_naming(
        current_variable_scope=current_variable_scope)

  def _init_set_name(self, name):
    # Anonymous Networks (name=None) defer setting a final name until they are
    # (1) added to another Network, or (2) built/called (where (2) is only used
    # for a "top level" network).
    #
    # However, if we were provided an explicit name (name is not None), that
    # will always be the final name of the Network; if it turns out not to be
    # unique or if variable names can't be prefixed by it we will throw an
    # error.
    self._name = name
    self._base_name = None

  def _finalize_name(self, parent_network):
    if not self._name:
      # We were not passed a name explicitly (or it was blank), so this is an
      # anonymous Network. We make up a unique name.
      if parent_network:
        avoid_names = parent_network._owned_layers
        name_uid_map = parent_network._sub_layer_name_uids
      else:
        name_uid_map = base_layer_utils.get_default_graph_uid_map()
        # Figure out which names we have to avoid based on which variable scope
        # we're nested in.
        strip_name = self._default_parent_variable_scope.name
        if strip_name:
          strip_name += "/"
        def _strip_on_init_scope(name):
          if name.startswith(strip_name):
            return name[len(strip_name):]
          else:
            return None
        avoid_names = set(
            _strip_on_init_scope(name)
            for name in self._variable_scope_counts_on_init.keys() if name)
      self._name, self._base_name = self._make_unique_name(
          name_uid_map=name_uid_map, avoid_names=avoid_names,
          namespace=self._default_parent_variable_scope.name,
          zero_based=True)
    if self._first_parent is None or (self._first_parent  # False = no parent
                                      and self._first_parent() is None):
      # Save a pointer to the parent Network so that we can later check that
      # the scope name we get is correct.
      if not parent_network:
        self._first_parent = parent_network
      else:
        self._first_parent = weakref.ref(parent_network)

  def _set_scope(self, scope=None):
    if self._scope is None:
      if not self._first_parent:
        first_parent = self._first_parent
      else:
        first_parent = self._first_parent()
      if first_parent is None:
        # If we were never added to another Network, or that Network has been
        # garbage collected before being called, then we're a top-level Network.
        self._finalize_name(
            # Use False to make sure the value sticks and we don't inherit a
            # parent if we're added to a network later.
            parent_network=False)
      if scope is not None:
        raise ValueError("Networks may not be created with explicit scopes.")
      if first_parent:
        first_parent._set_scope()
        parent_scope = first_parent._scope
      else:
        parent_scope = self._default_parent_variable_scope
      with variable_scope.variable_scope(parent_scope) as parent_vs:
        expected_scope_name = parent_vs.name + "/" + self._name
        if expected_scope_name in self._variable_scope_counts_on_init:
          raise ValueError(
              ("A Network named '%s' already exists (or a variable_scope was "
               "created with this name). Names must be unique.") % (
                   self._name,))
        # Make sure variables with this prefix will be unique.
        with variable_scope.variable_scope(
            None, use_resource=True, default_name=self._name) as scope:
          self._scope = scope
          scope_name = scope.name
          suffix_start = scope_name.rfind("/") + 1
          # rfind is -1 if there is no slash in the string, in which case the
          # suffix starts at the beginning of the string (there is no prefix).
          scope_suffix = scope_name[suffix_start:]
          scope_prefix = scope_name[:suffix_start]
          if scope_suffix != self._name:
            raise ValueError(
                ("A Network named '%s' already exists (or a variable_scope was "
                 "created with this name). Names must be unique.") % (
                     self._name,))
          if (first_parent
              and scope_prefix[:-1] != first_parent.scope_name):
            raise ValueError(
                ("Network variable names must match a nesting of sub-Network "
                 "names. Expected prefix '%s' from parent network, but got "
                 "'%s' when attempting to create a variable_scope for Network "
                 "'%s'. Likely an explicit variable_scope was inserted into "
                 "the nesting.") % (
                     first_parent.scope_name,
                     scope_prefix[:-1],
                     self._name))
          elif not first_parent and scope_prefix:
            # For the case when this Network is not nested inside any other
            # Network, but is in a variable_scope. This Network's name takes on
            # the full variable scope prefix.
            self._name = scope_name

      for non_network_sublayer in self._non_network_sublayers:
        self._set_scope_for_nonnetwork_sublayer(non_network_sublayer)

  def _set_scope_for_nonnetwork_sublayer(self, sublayer):
    if sublayer._scope is None:
      if sublayer._first_parent is None:
        constituent_first_parent = None
      else:
        constituent_first_parent = sublayer._first_parent()
      if constituent_first_parent:
        constituent_first_parent._set_scope()
        parent_scope = constituent_first_parent._scope
      else:
        self._finalize_name(False)
        raise ValueError(
            ("The parent of a Layer added to Network %s was garbage collected "
             "before the Layer was built. If this limitation bothers you "
             "please file a feature request.") %
            (self.name,))
      with variable_scope.variable_scope(parent_scope):
        # Horrid hack to make Layer variable names which are direct
        # sub-layers of Networks conform to the Network variable naming
        # conventions.
        with variable_scope.variable_scope(
            None, use_resource=True,
            default_name=sublayer.name) as sub_scope:
          sublayer._scope = sub_scope
          # Also switch op naming for this Layer to match Network conventions,
          # i.e. op naming matching variable naming.
          sublayer._name_scope_name = _network_name_scope_naming

  @base.Layer.name.getter
  def name(self):
    if self._name is None:
      raise ValueError(
          "The network does not yet have a final name, but a name was "
          "requested for it. Networks get a name when they are added to "
          "another Network via track_layer, or when they are first "
          "called/built.")
    return self._name

  def track_layer(self, layer):
    """Track a Layer in this Network.

    `Network` requires that all `Layer`s used in `call()` be tracked so that
    the `Network` can export a complete list of variables.

    Args:
      layer: A `tf.layers.Layer` object.

    Returns:
      The passed in `layer`.

    Raises:
      RuntimeError: If __init__ has not been called.
      TypeError: If `layer` is the wrong type.
      ValueError: If a `Layer` with the same name has already been added.
    """
    if not hasattr(self, "_layers"):
      raise RuntimeError("Need to call Network.__init__ before adding layers")
    if not isinstance(layer, base.Layer):
      raise TypeError(
          "Network.track_layer() passed type %s, not a tf.layers.Layer" %
          (type(layer),))
    # Always use `ResourceVariable` with legacy layers.
    layer._use_resource_variables = True
    if isinstance(layer, Network):
      layer._finalize_name(parent_network=self)
    else:
      # `layer` is a non-Network, so it hasn't been named to follow Network
      # conventions for contained Layers (i.e. the same conventions as for
      # sub-Networks). This renaming is necessary to isolate Network variable
      # naming from Layers constructed outside the Network and never added to
      # it (because Layers are named globally).
      if not layer.built:
        if not hasattr(layer, "_first_parent"):
          dereferenced_layer_first_parent = None
        else:
          dereferenced_layer_first_parent = layer._first_parent()
        if dereferenced_layer_first_parent is None:
          if layer._name != layer._base_name:
            # If name and base_name do not match, then this Layer used
            # anonymous naming and we have to rename it. Otherwise there's an
            # explicit name, and we should respect it (subject to error
            # checking).
            layer._name, layer._base_name = layer._make_unique_name(
                name_uid_map=self._sub_layer_name_uids,
                avoid_names=self._owned_layers,
                zero_based=True
                # No namespace required, since we've specified our own UID map.
            )
          layer._first_parent = weakref.ref(self)
        self._non_network_sublayers.append(layer)
    if (not layer.built
        and layer._first_parent
        and self is layer._first_parent()):
      if layer.name in self._owned_layers:
        if self._owned_layers[layer.name] is layer:
          return layer
        raise ValueError(
            "Attempt to add two Layers with the name '%s' to the same Network."
            % (layer.name))
      self._owned_layers[layer.name] = layer
    self._layers.append(layer)
    return layer

  def get_layer(self, name=None, index=None):
    """Get a contained `tf.layers.Layer` either by name or index.

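    For example (a sketch, assuming `net` is the built `TwoLayerNetwork` from
    the `Network` class docstring):

    ```python
    first = net.get_layer(index=0)       # The first tracked Layer.
    dense = net.get_layer(name="dense")  # Lowest-index Layer named "dense".
    ```
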
    Args:
      name: String matching one of the names of a contained `Layer`. Note that
        the names of `Layer`s added to `Network`s may not be unique when doing
        layer sharing (i.e. adding a `Layer` to this `Network` which was
        already added to another `Network`). The lowest index `Layer` with a
        matching name will be returned.
      index: Integer in [0, number of layers). Layers are assigned an index
        by the order they are added.

    Returns:
      A `tf.layers.Layer` object.

    Raises:
      ValueError: If neither or both of 'index' or 'name' is specified, or the
        lookup failed.
    """
    if index is not None:
      if name is not None:
        raise ValueError("Exactly one of 'index' or 'name' must be provided")
      if len(self._layers) <= index:
        raise ValueError("Was asked to retrieve layer at index " + str(index) +
                         " but model only has " + str(len(self._layers)) +
                         " layers.")
      else:
        return self._layers[index]
    else:
      if not name:
        raise ValueError("Provide either a layer name or layer index.")
    for layer in self._layers:
      if layer.name == name:
        return layer
    raise ValueError("No such layer: " + name)

  # The following methods are for implementing the Layer interface.

  @property
  def weights(self):
    # TODO(josh11b): Should this return a set or perform de-duplication of
    # variables in the case of shared layers/variables that appear in
    # multiple places in the Network?
    weights = []
    for layer in self._layers:
      weights += layer.weights
    return weights

  @property
  def trainable_weights(self):
    weights = []
    for layer in self._layers:
      weights += layer.trainable_weights
    return weights

  @property
  def non_trainable_weights(self):
    weights = []
    for layer in self._layers:
      weights += layer.non_trainable_weights
    return weights

  @property
  def trainable(self):
    return True

  @trainable.setter
  def trainable(self, value):
    if not value:
      # We believe it better to decide which layers & networks are trainable
      # at the Trainer level than here. Otherwise you can run into trouble if
      # a layer/network is shared between two models, but is trainable in one
      # but not the other (like with adversarial networks).
      raise AttributeError("cannot mark Network as not trainable")

  @property
  def layers(self):
    return self._layers

  def add_variable(self, name, shape, dtype=None, initializer=None,
                   regularizer=None, trainable=True, constraint=None):
    raise RuntimeError(
        "add_variable not supported in Network class yet. Please file an issue "
        "at https://github.com/tensorflow/tensorflow/issues/new if this is "
        "important to you")

  def add_loss(self, losses, inputs=None):
    raise RuntimeError(
        "add_loss is not supported in Network class yet. Please file an issue "
        "at https://github.com/tensorflow/tensorflow/issues/new if this is "
        "important to you")

  @property
  def losses(self):
    """Gather losses from `Layer`s in the `Network`.

    Note that when executing eagerly, `Layer.losses` evaluates
    regularizers. When using graph execution, variable regularization ops have
    already been created and are simply returned here.

    Returns:
      A list of tensors.
    """
    layer_losses = []
    for layer in self.layers:
      layer_losses.extend(layer.losses)
    return layer_losses

  # TODO(allenl): Support other Layer methods needed for graph mode, such as
  # for updates


class Sequential(Network):
  """Represents a linear sequence of Layers or functions.

  The output of each layer/function is provided as the input to the next.
  Inputs passed to `__call__` are forwarded to the first Layer, and `__call__`
  returns the outputs of the last Layer.

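  For example (a minimal sketch; assumes eager execution is enabled and that
  `tfe` is an alias for `tf.contrib.eager`):

  ```python
  net = tfe.Sequential(
      [tf.layers.Dense(16, input_shape=(8,)),
       tf.nn.relu,
       tf.layers.Dense(1)],
      name="net")
  output = net(tf.ones([1, 8]))
  ```
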
  Args:
    layers_funcs: An optional sequence where each element is either a
      tf.layers.Layer object or a callable.
    name: An optional string name to use for this Network.
  """

  def __init__(self, layers_funcs=None, name=None):
    super(Sequential, self).__init__(name=name)
    self._layers_funcs = []
    if layers_funcs:
      for l in layers_funcs:
        self.add(l)

  def add(self, layer_func):
    if isinstance(layer_func, base.Layer):
      args = function_utils.fn_args(layer_func.call)
      self.track_layer(layer_func)
    elif callable(layer_func):
      args = function_utils.fn_args(layer_func)
    else:
      raise TypeError(
          "Sequential.add() takes only tf.layers.Layer objects or callables; "
          "not '%s' of type '%s'." % (layer_func, type(layer_func)))
    self._layers_funcs.append((("training" in args), layer_func))

  def call(self, inputs, training=None):
    """Call each Layer in the order they were added."""
    # TODO(josh11b): Support "mode" and maybe other arguments
    if training is None:
      for _, l in self._layers_funcs:
        inputs = l(inputs)
    else:
      for has_training_arg, l in self._layers_funcs:
        if has_training_arg:
          inputs = l(inputs, training)
        else:
          inputs = l(inputs)
    return inputs


_DeferredRestoration = collections.namedtuple(
    "_DeferredRestoration",
    [
        # The map_func to use (either user-specified or the default).
        "map_func",
        # Boolean, True if the user specified an explicit map_func, for error
        # messages.
        "map_func_is_user",
        # A mapping from checkpoint names to initial values of not-yet-created
        # variables which should be restored. These values come from parsing a
        # checkpoint.
        "checkpointed_variables_to_restore",
        # A mapping from checkpoint name to variable objects of variables which
        # have already been restored, for error checking.
        "restored_variables",
        # The session to restore with (if in graph mode).
        "session",
        # Names of the Network where the restore was requested, for error
        # messages.
        "network_name",
        "network_scope_name"
    ])


def _default_naming_conflict_error_message(
    mapped_name, first_variable, second_variable,
    network_name, network_scope_name):
  return (
      ("The default checkpoint variable name mapping strategy for Network "
       "'%s' resulted in a naming conflict. We attempted to strip off the "
       "variable prefix for the Network ('%s'), but this resulted in two "
       "variables named '%s' (originally '%s' and '%s'). This should only "
       "happen when using variable sharing (i.e. the Network contains Networks "
       "or Layers which were first added to another Network, and therefore "
       "have that Network's variable prefix). One solution is to pass "
       "`map_func=lambda n: n` to save and restore to use fully qualified "
       "variable names in the checkpoint, although this will require that the "
       "variable prefix of the Network being restored into is also '%s'. You "
       "may alternatively write an arbitrary mapping.")
      % (
          network_name, network_scope_name, mapped_name,
          first_variable._shared_name,
          second_variable._shared_name, network_scope_name
      ))


def _restore_custom_map_func_error_message(
    mapped_name, first_variable, second_variable,
    network_name, network_scope_name):
  return (
      ("The map_func passed to restore_network_checkpoint for the Network '%s' "
       "resulted in two variables named '%s' (originally '%s' and '%s'). Since "
       "this is also an error when saving, this Network was "
       "probably not saved with this map_func. Note that map_func "
       "always maps from full variable names to checkpoint names; "
       "there is no need to specify an inverse mapping.\n\n"
       "Try stripping less from the variable names, or renaming parts "
       "of the Network. For reference, variables created by sub-Layers "
       "of this Network are prefixed with '%s', but if they are "
       "re-used after being added to another Network they will have "
       "that Network's full variable prefix instead.") % (
           network_name, mapped_name,
           first_variable._shared_name,
           second_variable._shared_name,
           network_scope_name))


def _make_custom_getter_for_deferred_restorations():
  """Returns a custom getter which searches `deferred_restorations`.

  Returns: A tuple of (_custom_getter, deferred_restorations)
    _custom_getter: The getter which should be added to variable_scopes where
      variables will be created.
    deferred_restorations: A list for _DeferredRestoration objects. Typically
      empty when the getter is set, and expanded as deferred restorations are
      requested. All new deferred restorations should be appended to the end of
      the list, where they will have priority over older deferred restorations.
  """
  deferred_restorations = []

  def _custom_getter(getter, name, shape=None, dtype=None,
                     initializer=None,
                     *args, **kwargs):
    """A custom getter which processes deferred restorations."""
    # Iterate over restorations, newest first (newer restorations will take
    # precedence over older restorations, just like with immediate restorations
    # into existing variables).
    delayed_restoration = None
    found_value = False
    value_to_restore = None
    for delayed_restoration in reversed(
        deferred_restorations):
      checkpoint_name = delayed_restoration.map_func(name)
      if (checkpoint_name
          in delayed_restoration.checkpointed_variables_to_restore):
        found_value = True
        value_to_restore = (
            delayed_restoration.checkpointed_variables_to_restore[
                checkpoint_name])
      if found_value:
        break
    # value_to_restore may be False because this variable is not in any
    # checkpoint we are restoring, or None because we have explicitly set it to
    # None when it was previously fetched. In either case, we don't need to
    # set an initializer.
    if found_value and value_to_restore is not None:
      initializer = value_to_restore
      shape = None
    variable = getter(name, shape=shape, dtype=dtype, initializer=initializer,
                      *args, **kwargs)
    if found_value and value_to_restore is not None:
      # Mark as already restored from this checkpoint.
      delayed_restoration.checkpointed_variables_to_restore[
          checkpoint_name] = None
      if not context.executing_eagerly():
        delayed_restoration.session.run(variable.initializer)
    if found_value:
      # Error checking should run even if we've already restored a value.
      if delayed_restoration.restored_variables.setdefault(
          checkpoint_name, variable) is not variable:
        # Naming conflict. We've tried to initialize two variables with the
        # same value from the checkpoint.
        if delayed_restoration.map_func_is_user:
          raise ValueError(
              _restore_custom_map_func_error_message(
                  mapped_name=checkpoint_name,
                  first_variable=delayed_restoration.restored_variables[
                      checkpoint_name],
                  second_variable=variable,
                  network_name=delayed_restoration.network_name,
                  network_scope_name=delayed_restoration.network_scope_name))
        else:
          raise ValueError(
              _default_naming_conflict_error_message(
                  mapped_name=checkpoint_name,
                  first_variable=delayed_restoration.restored_variables[
                      checkpoint_name],
                  second_variable=variable,
                  network_name=delayed_restoration.network_name,
                  network_scope_name=delayed_restoration.network_scope_name))
    return variable
  return _custom_getter, deferred_restorations


def _make_prefix_stripping_map_fn(scope_name):
  """Closure for stripping the scope name of a Network.

  Implemented as a closure rather than a member function to avoid reference
  cycles in deferred restorations (this function should not have a reference to
  the Network which created it).

  Args:
    scope_name: The Network.scope_name to strip from variables.
  Returns:
    A scope_name-stripping default `map_fn` for the Network.
  """

  def _strip_variable_prefix(original_variable_name):
    """The default map_func for saving or restoring variables.

    Strips the variable prefix for the Network on which save/restore was
    called, and leaves other variable names fully qualified in the checkpoint.

    Args:
      original_variable_name: The _shared_name of the variable (no :0
        suffix) to map.
    Returns:
      The checkpoint name of the variable.
    """
    scope_name_with_slash = scope_name + "/"
    if original_variable_name.startswith(scope_name_with_slash):
      return original_variable_name[len(scope_name_with_slash):]
    else:
      return original_variable_name

  return _strip_variable_prefix

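# For illustration (names are hypothetical): the default map_func produced by
# _make_prefix_stripping_map_fn("my_network_1") maps
# "my_network_1/dense/kernel" to "dense/kernel", while a name outside the
# scope such as "other_network/dense/kernel" passes through unchanged.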

@deprecation.deprecated(date=None, instructions=(
    "Please inherit from tf.keras.Model instead of tfe.Network, and use "
    "tf.keras.Model.save_weights."))
def save_network_checkpoint(
    network, save_path, global_step=None, map_func=None):
  """Save variables from the Network to a checkpoint.

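  For example (a sketch; the directory path is hypothetical, and
  `TwoLayerNetwork` is the example class from the `Network` docstring):

  ```python
  net = TwoLayerNetwork(name="net")
  net(tf.ones([1, 8]))  # Build the Network so its variables exist.
  prefix = tfe.save_network_checkpoint(net, "/tmp/network_checkpoints")
  ```
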
  Args:
    network: A Network object to save.
    save_path: Either a checkpoint prefix or the name of a directory to save
      the checkpoint in (in which case the checkpoint will be named based on
      the Network name).
    global_step: The global step to use when naming the checkpoint. If None
      (default), we will first try to get the default global step. If that
      fails because no default global step exists, then the checkpoint is
      created without a global step suffix.
    map_func: A function mapping fully qualified variable names
      (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
      default (if `map_func=None`), the variable prefix for the network being
      saved (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
      and all other variable names (shared with other Networks) are left
      unchanged.
  Returns:
    The checkpoint prefix for the saved checkpoint, which may be passed to
    `Network.restore`.
  Raises:
    ValueError: If the Network has not yet been called, or if map_func results
      in a name collision.
  """
  if not network.built:
    raise ValueError(
        "Attempt to save the Network before it was first called. This means "
        "variables have not yet been created, so there is nothing to save.")
  network._set_scope()  # scope_name should be available to map_funcs
  if global_step is None:
    global_step = training_util.get_global_step()
  if os.path.isdir(save_path):
    # If we were passed a directory, default to naming based on the Network
    # name.
    save_path = os.path.join(save_path, network.name.replace("/", "_"))
  user_map_func = map_func
  if map_func is None:
    map_func = _make_prefix_stripping_map_fn(network.scope_name)
  variable_map = {}
  for variable in network.variables:
    mapped_name = map_func(variable._shared_name)
    if variable_map.setdefault(mapped_name, variable) is not variable:
      if user_map_func is None:
        # Instead of erroring out, we could just re-try and silently use the
        # full variable names in the checkpoint. This could be odd for deeply
        # nested sub-Networks (since the full prefix from the nesting would
        # get added), so for now we'll let the user deal with this case.
        raise ValueError(_default_naming_conflict_error_message(
            mapped_name=mapped_name,
            first_variable=variable_map[mapped_name],
            second_variable=variable,
            network_name=network.name,
            network_scope_name=network.scope_name))
      else:
        # The user passed their own problematic map_func.
        raise ValueError(
            ("The map_func passed to save_network_checkpoint for the Network "
             "'%s' resulted in two variables named '%s' ('%s' and '%s'). Try "
             "stripping less from the variable names, or renaming parts of "
             "the Network. For reference, variables created by sub-Layers of "
             "this Network are prefixed with '%s', but if they are re-used "
             "after being added to another Network, they will have that "
             "Network's full variable prefix instead.") % (
                 network.name, mapped_name,
                 variable_map[mapped_name]._shared_name,
                 variable._shared_name,
                 network.scope_name))
  if context.executing_eagerly():
    sess = None
  else:
    sess = ops.get_default_session()
  return saver_lib.Saver(variable_map).save(
      sess=sess, save_path=save_path, write_meta_graph=False,
      global_step=global_step)


def _add_deferred_restoration(layer, deferred_restoration):
  """Add a deferred restoration to this Layer and all children.

  Restorations which are requested later have higher priority, and the highest
  priority matching restoration is applied to a variable when it is created.

  Args:
    layer: The Layer (not necessarily a Network) to operate on.
    deferred_restoration: A _DeferredRestoration object.
  """
  # Networks don't create variables at the moment, so this append isn't
  # strictly necessary. We could get by with only adding deferred restorations
  # to non-Network Layers.
  if isinstance(layer, Network):
    layer._set_scope()
  # Make sure this Layer has a deferred restoration queue and a custom getter,
  # then add our request to it.
  if not hasattr(layer, "_custom_getter"):
    assert not hasattr(layer, "_deferred_restorations")
    layer._custom_getter, layer._deferred_restorations = (
        _make_custom_getter_for_deferred_restorations())
  # We use set_custom_getter because it avoids recursively calling up the
  # variable_scope tree. We've done the tree traversal ourselves and have added
  # the request to each Layer which needs it.
  layer._scope.set_custom_getter(layer._custom_getter)
  layer._deferred_restorations.append(deferred_restoration)
  if isinstance(layer, Network):
    for sublayer in layer.layers:
      if not isinstance(sublayer, Network):
        layer._set_scope_for_nonnetwork_sublayer(sublayer)
      _add_deferred_restoration(sublayer, deferred_restoration)


def _restore_existing_variables(network, save_path, map_func, user_map_func):
  """Use a standard Saver to restore existing variables from a checkpoint.

  Args:
    network: A Network object to restore.
    save_path: The checkpoint prefix or directory to read from.
    map_func: The function to use when mapping from variable names to
      checkpoint names.
    user_map_func: The original map_func passed by the user, for error
      checking.
  Returns:
    A dictionary mapping from checkpoint names to variable objects which have
    been restored (for bookkeeping to avoid deferred restorations on these
    variables).
  Raises:
    ValueError: If there is a name collision.
  """
  existing_variables_by_checkpoint_name = {}
  for variable in network.variables:
    checkpoint_name = map_func(variable._shared_name)
    if existing_variables_by_checkpoint_name.setdefault(
        checkpoint_name, variable) is not variable:
      if user_map_func is None:
        raise ValueError(_default_naming_conflict_error_message(
            mapped_name=checkpoint_name,
            first_variable=existing_variables_by_checkpoint_name[
                checkpoint_name],
            second_variable=variable,
            network_name=network.name,
            network_scope_name=network.scope_name))
      else:
        raise ValueError(_restore_custom_map_func_error_message(
            mapped_name=checkpoint_name,
            first_variable=existing_variables_by_checkpoint_name[
                checkpoint_name],
            second_variable=variable,
            network_name=network.name,
            network_scope_name=network.scope_name))
  if existing_variables_by_checkpoint_name:
    if context.executing_eagerly():
      sess = None
    else:
      sess = ops.get_default_session()
    saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore(
        sess=sess, save_path=save_path)
  return existing_variables_by_checkpoint_name


def _set_restore_on_create(network, save_path, map_func, user_map_func,
                           existing_variables_by_checkpoint_name):
  """If necessary, request deferred restorations of variables."""
  checkpoint_reader = checkpoint_utils.load_checkpoint(save_path)
  checkpointed_variables_to_restore = {}
  for checkpoint_name, _ in checkpoint_utils.list_variables(save_path):
    if checkpoint_name in existing_variables_by_checkpoint_name:
      # This variable was already created and restored.
      continue
    # Save the variable for later restoration in a custom getter.
    checkpointed_variables_to_restore[checkpoint_name] = (
        checkpoint_reader.get_tensor(checkpoint_name))
  # Only set a deferred restoration if there are checkpoint variables which
  # have not been assigned to existing variables. Note that this loses out on
  # some opportunity for error checking, but avoids creating
  # _DeferredRestoration objects once a Network has been built (so that
  # restoring in a loop does not take increasing amounts of memory).
  if checkpointed_variables_to_restore:
    if context.executing_eagerly():
      sess = None
    else:
      sess = ops.get_default_session()
    # We need a name for error messages. If we haven't been added to another
    # Network yet, we're top-level.
    network._finalize_name(False)
    network._set_scope()
    # Save a record of this restoration for use in the custom getter.
    deferred_restoration = _DeferredRestoration(
        map_func=map_func,
        map_func_is_user=(user_map_func is not None),
        checkpointed_variables_to_restore=checkpointed_variables_to_restore,
        restored_variables={},
        session=sess,
        network_name=network.name,
        network_scope_name=network.scope_name)
    # Add the deferred restoration to non-Network children, and request that
    # Networks propagate the request to their children.
    _add_deferred_restoration(network, deferred_restoration)


@deprecation.deprecated(date=None, instructions=(
    "Please inherit from tf.keras.Model instead of tfe.Network, and use "
    "tf.keras.Model.load_weights."))
def restore_network_checkpoint(network, save_path, map_func=None):
  """Restore the Network from a checkpoint.

  If variables have already been created (typically when some or all of the
  `Network` is built), they are assigned values from the checkpoint
  immediately, overwriting any existing values (in graph mode the default
  session is used for the assignments).

  If there are checkpoint entries which do not correspond to any existing
  variables in the `Network`, these values are saved for deferred restoration;
  their initial values will be the checkpointed values once they are
  created. Requests for multiple deferred restorations behave the same way as
  immediate restorations, in that later requests will take priority over
  earlier requests relevant to the same variable.

  If this `Network` shares `Layer`s with another network, those `Layer`s will
  also have their variables restored from the checkpoint.

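  For example (a sketch continuing the `save_network_checkpoint` example;
  `prefix` is the value it returned):

  ```python
  restored_net = TwoLayerNetwork(name="net")
  tfe.restore_network_checkpoint(restored_net, prefix)
  output = restored_net(tf.ones([1, 8]))  # Variables restored as created.
  ```
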
  Args:
    network: A Network object to restore.
    save_path: The return value of `tfe.save_network_checkpoint`, or a
      directory to search for a checkpoint.
    map_func: A function mapping fully qualified variable names
      (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By
      default (if `map_func=None`), the variable prefix for the network being
      restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped
      and all other variable names (shared with other Networks) are left
      unchanged. Note that this is the _same_ map_func as
      `tfe.save_network_checkpoint`, not an inverse mapping.
  """
  # If we don't have a name yet, we're top-level: finalize with no parent.
  network._finalize_name(parent_network=False)
  network._set_scope()  # scope_name should be available to map_funcs
  if os.path.isdir(save_path):
    # If we were passed a directory, default to naming based on the Network
    # name, matching save_network_checkpoint.
    save_path = os.path.join(save_path, network.name.replace("/", "_"))
  user_map_func = map_func
  if map_func is None:
    map_func = _make_prefix_stripping_map_fn(network.scope_name)
  # Step one is to restore any existing variables from the checkpoint.
  existing_variables_by_checkpoint_name = _restore_existing_variables(
      network=network,
      save_path=save_path,
      map_func=map_func,
      user_map_func=user_map_func)
  # Step two is to set a custom getter which restores variables on creation,
  # for those variables which have not been added to sub-Layers yet.
  _set_restore_on_create(
      network=network,
      save_path=save_path,
      map_func=map_func,
      user_map_func=user_map_func,
      existing_variables_by_checkpoint_name=(
          existing_variables_by_checkpoint_name))