# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# pylint: disable=g-classes-have-attributes
"""Contains a shim to allow using TF1 get_variable code in TF2."""
import functools

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.module import module
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator


def as_shape(shape):
  """Converts the given object to a TensorShape."""
  if isinstance(shape, tensor_shape.TensorShape):
    return shape
  else:
    return tensor_shape.TensorShape(shape)
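
# Illustrative usage (not executed): both calls below return a TensorShape;
# existing TensorShape objects are passed through unchanged.
#   as_shape([2, 3])
#   as_shape(tensor_shape.TensorShape([2, 3]))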


def _is_callable_object(obj):
  return hasattr(obj, "__call__") and tf_inspect.ismethod(obj.__call__)


def _has_kwargs(fn):
  """Returns whether the passed callable has **kwargs in its signature.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: whether `fn` has `**kwargs` in its signature.

  Raises:
    TypeError: If `fn` is not a function or function-like object.
  """
  if isinstance(fn, functools.partial):
    fn = fn.func
  elif _is_callable_object(fn):
    fn = fn.__call__
  elif not callable(fn):
    raise TypeError(
        "fn should be a function-like object, but is of type {}.".format(
            type(fn)))
  return tf_inspect.getfullargspec(fn).varkw is not None


def fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments
  """
  if isinstance(fn, functools.partial):
    args = fn_args(fn.func)
    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
  else:
    if hasattr(fn, "__call__") and tf_inspect.ismethod(fn.__call__):
      fn = fn.__call__
    args = tf_inspect.getfullargspec(fn).args
    if _is_bound_method(fn) and args:
      # If it's a bound method, it may or may not have a self/cls first
      # argument; for example, self could be captured in *args.
      # If it does have a positional argument, it is self/cls.
      args.pop(0)
  return tuple(args)
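
# Illustrative usage (not executed): `fn_args` and `_has_kwargs` are used below
# to decide which keyword arguments a `custom_getter` accepts. For example:
#   def getter_a(getter, name, *args, **kwargs): ...
#   def getter_b(getter, name, shape, dtype): ...
#   fn_args(getter_b)    -> ("getter", "name", "shape", "dtype")
#   _has_kwargs(getter_a)   -> True
#   fn_args(functools.partial(getter_b, None))  -> ("name", "shape", "dtype")
# `getter_a` and `getter_b` are hypothetical names used only for illustration.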


def _is_bound_method(fn):
  _, fn = tf_decorator.unwrap(fn)
  return tf_inspect.ismethod(fn) and (fn.__self__ is not None)


def validate_synchronization_aggregation_trainable(
    synchronization, aggregation, trainable, name):
  """Given user-provided variable properties, sets defaults and validates."""
  if aggregation is None:
    aggregation = variables.VariableAggregation.NONE
  else:
    if not isinstance(aggregation,
                      (variables.VariableAggregation,
                       variables.VariableAggregationV2)):
      try:
        aggregation = variables.VariableAggregationV2(aggregation)
      except ValueError:
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, name))
  if synchronization is None:
    synchronization = variables.VariableSynchronization.AUTO
  else:
    try:
      synchronization = variables.VariableSynchronization(synchronization)
    except ValueError:
      raise ValueError(
          "Invalid variable synchronization mode: {} for variable: {}".format(
              synchronization, name))
  if trainable is None:
    trainable = synchronization != variables.VariableSynchronization.ON_READ
  return synchronization, aggregation, trainable
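
# Illustrative behavior (not executed): with all three arguments left as None,
# the defaults are (AUTO, NONE, True); passing
# synchronization=variables.VariableSynchronization.ON_READ flips the default
# `trainable` to False, as documented in `get_variable` below.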


class _EagerVariableStore(object):
  """TF2-compatible VariableStore that avoids collections & tracks regularizers.

  New variable names and new variables can be created; stored variables are
  initialized with the initializer passed to `get_variable` (or with a
  dtype-based default initializer if none is given).

  All variables get created in `tf.init_scope` to avoid a bad
  interaction between `tf.function` `FuncGraph` internals, Keras
  Functional Models, and TPUStrategy variable initialization.

  Attributes:
    vars: a dictionary with string names (the same names passed to
      `get_variable`) as keys and the corresponding TensorFlow Variables as
      values.
  """

  __slots__ = ["_vars", "_regularizers", "_store_eager_variables"]

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._regularizers = {}  # A dict mapping var names to their regularizers.
    self._store_eager_variables = True

  def get_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      reuse=None,
      trainable=None,
      collections=None,
      caching_device=None,
      partitioner=None,
      validate_shape=True,
      use_resource=None,
      custom_getter=None,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Gets an existing variable with these parameters or creates a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), a default initializer based on the
    variable's dtype is used (floating-point variables get a new
    `glorot_uniform_initializer`). If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.
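
    For example (an illustrative sketch; `store` is an `_EagerVariableStore`
    instance and `tf.constant` comes from the public TF API):

    ```python
    store = _EagerVariableStore()
    w = store.get_variable("w", shape=[3, 4])    # created with glorot uniform
    w_again = store.get_variable("w", shape=[3, 4], reuse=True)  # same object
    b = store.get_variable("b", initializer=tf.constant([1., 2.]))
    # `b` derives its shape from the constant; passing `shape` too would raise.
    ```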

    Partitioned variables are not supported by this shim; passing a
    `partitioner` raises a `ValueError` (see the `partitioner` argument below).

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of applying
        it on a newly created variable will be tracked by the variable store
        and surfaced as a regularization loss.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
        forced to be `tf.AUTO_REUSE`.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
        Unsupported by this shim; passing a partitioner raises `ValueError`.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of initial_value
        must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be true.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. The signature
        of `custom_getter` should match that of this method, but the most
        future-proof version will allow for changes:
        `def custom_getter(getter, *args, **kwargs)`. Direct access to all
        `get_variable` parameters is also allowed:
        `def custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that simply creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the variable
        after being updated by an `Optimizer` (e.g. used to implement norm
        constraints or value constraints for layer weights). The function must
        take as input the unprojected Tensor representing the value of the
        variable and return the Tensor for the projected value (which must have
        the same shape). Constraints are not safe to use when doing asynchronous
        distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated. Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable`.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError("Passed a custom_getter which is not callable: %s" %
                       custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      reuse = vs.AUTO_REUSE

    # If a *_ref type is passed in, an error would be triggered further down
    # the stack. We prevent this using base_dtype to get a non-ref version of
    # the type, before doing anything else. When _ref types are removed in
    # favor of resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable.  However, custom_getter
    # may override this logic.  So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,  # pylint: disable=unused-argument
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,  # pylint: disable=unused-argument
        constraint=None,
        synchronization=vs.VariableSynchronization.AUTO,
        aggregation=vs.VariableAggregation.NONE):
      # Partitioned variables are currently unsupported with the shim.
      if partitioner is not None:
        raise ValueError(
            "`partitioner` arg for `get_variable` is unsupported in TF2. "
            "File a bug if you need help. You passed %s" % partitioner)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?" % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          caching_device=caching_device,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    synchronization, aggregation, trainable = (
        validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `_has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
      if ("constraint" in fn_args(custom_getter) or
          _has_kwargs(custom_getter)):
        custom_getter_kwargs["constraint"] = constraint
      return custom_getter(**custom_getter_kwargs)
    else:
      return _true_getter(
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_single_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      partition_info=None,
      reuse=None,
      trainable=None,
      caching_device=None,
      validate_shape=True,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning components)
    for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable.  See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if reuse is False:  # pylint: disable=g-bool-id-comparison
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so no
        # traceback is attached.
        raise ValueError(err_msg)
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." %
                         (name, shape, found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype_str, found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape.is_fully_defined():
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            init_val = functools.partial(initializer,
                                         shape.as_list(),
                                         dtype=dtype,
                                         partition_info=partition_info)
          else:
            init_val = functools.partial(initializer,
                                         shape.as_list(), dtype=dtype)
          variable_dtype = dtype.base_dtype
        else:
          init_val = initializer
          variable_dtype = None

    # Create the variable (always eagerly, as a workaround for a strange
    # TPU / FuncGraph / Keras functional model interaction).
    with ops.init_scope():
      v = variables.Variable(
          initial_value=init_val,
          name=name,
          trainable=trainable,
          caching_device=caching_device,
          dtype=variable_dtype,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
      self.add_regularizer(v, regularizer)

    return v

  def add_regularizer(self, var, regularizer):
    self._regularizers[var.name] = functools.partial(regularizer, var)
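
  # Illustrative behavior (not executed): if `var.name` is "kernel:0" and
  # `regularizer` is `tf.keras.regularizers.L2()`, the stored partial, when
  # called later (see `VariableAndLossTracker.get_regularization_losses`),
  # returns the scalar L2 loss for the current value of `var`.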

  # Provide a default initializer when no initializer is given.
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
    """
    del shape
    # If dtype is DT_FLOAT, provide a glorot uniform initializer
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`
    # If dtype is DT_BOOL, provide a default value `FALSE`
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
          dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: do we need to support handling DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of %s is required" %
                       (name, dtype.base_dtype))

    return initializer, initializing_from_value
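
  # Illustrative behavior (not executed):
  #   _get_default_initializer("w", dtype=dtypes.float32)
  #       -> (glorot_uniform_initializer instance, False)
  #   _get_default_initializer("i", dtype=dtypes.int32)
  #       -> (zeros_initializer instance, False)
  #   _get_default_initializer("c", dtype=dtypes.complex64)  # raises ValueError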


class VariableAndLossTracker(module.Module):
  """Module that has a scope to capture vars/losses made by `get_variable`."""

  def __init__(self):
    self._var_store = _EagerVariableStore()  # pylint: disable=protected-access
    self._variables = {}

  def _variable_creator(self, next_creator, **kwargs):
    var = next_creator(**kwargs)
    self._variables[var.name] = var

    return var

  @tf_contextlib.contextmanager
  def scope(self):
    with vs.variable_creator_scope(
        self._variable_creator), vs.with_variable_store(self._var_store):
      yield

  def get_regularization_losses(self):
    # TODO(kaftan): Consider adding a regex scope like the collection access.
    # But, < 40-50 usages of get_regularization_loss(es) with `scope`
    # & possible to do manually?
    losses = {}
    for var_name, regularizer in self._var_store._regularizers.items():  # pylint: disable=protected-access
      losses[var_name] = regularizer()
    return losses
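
  # Illustrative usage (not executed; assumes eager TF2 with the compat.v1 API):
  #
  #   tracker = VariableAndLossTracker()
  #   with tracker.scope():
  #     v = tf.compat.v1.get_variable(
  #         "w", shape=[2], regularizer=tf.keras.regularizers.L2())
  #   # `v` is captured by the tracker's variable store, and
  #   # tracker.get_regularization_losses() maps "w:0" to its L2 loss.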


class VariableScopeWrapperLayer(base_layer.Layer):
  """Wrapper Layer to capture `compat.v1.get_variable` and `compat.v1.layers`.

  See go/tf2-migration-model-bookkeeping for background.

  This shim layer allows using large sets of TF1 model-forward-pass code as a
  Keras layer that works in TF2 with TF2 behaviors enabled. To use it,
  override this class and put your TF1 model's forward pass inside your
  implementation for `forward_pass`.

  Below are some examples, and then more details on the functionality of this
  shim layer to wrap TF1 model forward passes.

  Example of capturing tf.compat.v1.layer-based modeling code as a Keras layer:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = tf.compat.v1.layers.dense(
          inputs, self.units, name="dense_one",
          kernel_initializer=init_ops.ones_initializer(),
          kernel_regularizer="l2")
      with variable_scope.variable_scope("nested_scope"):
        out = tf.compat.v1.layers.dense(
            out, self.units, name="dense_two",
            kernel_initializer=init_ops.ones_initializer(),
            kernel_regularizer="l2")
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```

  Example of capturing `tf.compat.v1.get_variable`-based modeling code as a
  Keras layer:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = inputs
      with tf.compat.v1.variable_scope("dense_one"):
        # The weights are created with a `regularizer`,
        # so the layer should track their regularization losses
        kernel = tf.compat.v1.get_variable(
            shape=[out.shape[-1], self.units],
            regularizer=regularizers.L2(),
            initializer=init_ops.ones_initializer(),
            name="kernel")
        bias = tf.compat.v1.get_variable(
            shape=[self.units,],
            initializer=init_ops.zeros_initializer(),
            name="bias")
        out = tf.compat.v1.math.matmul(out, kernel)
        out = tf.compat.v1.nn.bias_add(out, bias)
      with tf.compat.v1.variable_scope("nested_scope"):
        with tf.compat.v1.variable_scope("dense_two"):
          kernel = tf.compat.v1.get_variable(
              shape=[out.shape[-1], self.units],
              regularizer=regularizers.L2(),
              initializer=init_ops.ones_initializer(),
              name="kernel")
          bias = tf.compat.v1.get_variable(
              shape=[self.units,],
              initializer=init_ops.zeros_initializer(),
              name="bias")
          out = tf.compat.v1.math.matmul(out, kernel)
          out = tf.compat.v1.nn.bias_add(out, bias)
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```

  Regularization losses:
    Any regularizers specified in the `get_variable` calls or `compat.v1.layer`
    creations will get captured by this wrapper layer. Regularization losses
    are accessible in `layer.losses` after a call just like in a standard
    Keras layer, and will be captured by any model that includes this layer.

  Variable scope / variable reuse:
    variable-scope based reuse in the `forward_pass` will be respected,
    and work like variable-scope based reuse in TF1.
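
    For example (an illustrative sketch, assuming a subclass like the ones
    above with a `self.units` attribute):

    ```python
    def forward_pass(self, inputs):
      with tf.compat.v1.variable_scope("shared", reuse=tf.compat.v1.AUTO_REUSE):
        kernel = tf.compat.v1.get_variable(
            name="kernel", shape=[inputs.shape[-1], self.units])
        # A second `get_variable` call with the same name returns the same
        # variable rather than creating a new one.
        kernel_again = tf.compat.v1.get_variable(
            name="kernel", shape=[inputs.shape[-1], self.units])
      return tf.matmul(inputs, kernel)
    ```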

  Variable Names/Pre-trained checkpoint loading:
    Variable naming from `get_variable` and `compat.v1.layer` layers will match
    the TF1 names, so you should be able to re-use your old name-based
    checkpoints.

  Training Arg in `forward_pass`:
    Keras will pass a `training` arg to this layer similarly to how it
    passes `training` to other layers in TF2. See more details in the docs
    on `tf.keras.layers.Layer` to understand what will be passed and when.
    Note: tf.compat.v1.layers are usually not called with `training=None`,
    so the training arg to `forward_pass` might not feed through to them
    unless you pass it to their calls explicitly.

  Call signature of the forward pass:
    The semantics of the forward pass signature roughly match the standard
    Keras layer `call` signature, except that a `training` arg will *always*
    be passed, so your `forward_pass` must accept either a `training` argument
    or `**kwargs`.

  Limitations:
    * TF2 will not prune unused variable updates (or unused outputs). You may
      need to adjust your forward pass code to avoid computations or variable
      updates that you don't intend to use. (E.g. by adding a flag to the
      `forward_pass` call signature and branching on it).
    * Avoid nesting variable creation in `tf.function` inside of
      `forward_pass`. While the layer may safely be used from inside a
      `tf.function`, using a `tf.function` inside of `forward_pass` will break
      the variable scoping.
    * TBD: Nesting Keras layers/models or other `VariableScopeWrapperLayer`s
      directly in `forward_pass` may not work correctly just yet.
      Support for this, and instructions for how to do it, are still being
      worked on.

  Coming soon: A better guide and a testing/verification guide.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Relies on keras layers tracking Modules
    self.tracker = VariableAndLossTracker()
    # May need to inspect func to see if it should pass a `training` arg or not

  def forward_pass(self, *args, **kwargs):
    raise NotImplementedError

  def call(self, *args, **kwargs):
    with self.tracker.scope():
      out = self.forward_pass(*args, **kwargs)
    if not self._eager_losses:
      # We have to record regularization losses in the call as if they
      # are activity losses.
      # So, don't double-count regularization losses if the layer is used
      # multiple times in a model
      for loss in self.tracker.get_regularization_losses().values():
        self.add_loss(loss)
    return out