# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
# pylint: disable=g-classes-have-attributes
"""Contains a shim to allow using TF1 get_variable code in TF2."""
import functools

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.module import module
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator


def as_shape(shape):
  """Converts the given object to a TensorShape."""
  if isinstance(shape, tensor_shape.TensorShape):
    return shape
  else:
    return tensor_shape.TensorShape(shape)


def _is_callable_object(obj):
  return hasattr(obj, "__call__") and tf_inspect.ismethod(obj.__call__)


def _has_kwargs(fn):
  """Returns whether the passed callable has **kwargs in its signature.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `bool`: if `fn` has **kwargs in its signature.

  Raises:
    `TypeError`: If fn is not a Function, or function-like object.
  """
  if isinstance(fn, functools.partial):
    fn = fn.func
  elif _is_callable_object(fn):
    fn = fn.__call__
  elif not callable(fn):
    raise TypeError(
        "fn should be a function-like object, but is of type {}.".format(
            type(fn)))
  return tf_inspect.getfullargspec(fn).varkw is not None


def fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments
  """
  if isinstance(fn, functools.partial):
    args = fn_args(fn.func)
    args = [a for a in args[len(fn.args):] if a not in (fn.keywords or [])]
  else:
    if hasattr(fn, "__call__") and tf_inspect.ismethod(fn.__call__):
      fn = fn.__call__
    args = tf_inspect.getfullargspec(fn).args
    if _is_bound_method(fn) and args:
      # If it's a bound method, it may or may not have a self/cls first
      # argument; for example, self could be captured in *args.
      # If it does have a positional argument, it is self/cls.
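      # E.g., if `fn` is a bound method defined as
      # `def call(self, inputs, training=None)` (an illustrative signature),
      # popping the first argument leaves `("inputs", "training")`.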
      args.pop(0)
  return tuple(args)


def _is_bound_method(fn):
  _, fn = tf_decorator.unwrap(fn)
  return tf_inspect.ismethod(fn) and (fn.__self__ is not None)


def validate_synchronization_aggregation_trainable(
    synchronization, aggregation, trainable, name):
  """Given user-provided variable properties, sets defaults and validates."""
  if aggregation is None:
    aggregation = variables.VariableAggregation.NONE
  else:
    if not isinstance(aggregation,
                      (variables.VariableAggregation,
                       variables.VariableAggregationV2)):
      try:
        aggregation = variables.VariableAggregationV2(aggregation)
      except ValueError:
        raise ValueError(
            "Invalid variable aggregation mode: {} for variable: {}".format(
                aggregation, name))
  if synchronization is None:
    synchronization = variables.VariableSynchronization.AUTO
  else:
    try:
      synchronization = variables.VariableSynchronization(synchronization)
    except ValueError:
      raise ValueError(
          "Invalid variable synchronization mode: {} for variable: {}".format(
              synchronization, name))
  if trainable is None:
    trainable = synchronization != variables.VariableSynchronization.ON_READ
  return synchronization, aggregation, trainable


class _EagerVariableStore(object):
  """TF2-compatible VariableStore that avoids collections & tracks regularizers.

  New variable names and new variables can be created; if no initializer is
  passed to `get_variable`, a default initializer is used.

  All variables get created in `tf.init_scope` to avoid a bad
  interaction between `tf.function` `FuncGraph` internals, Keras
  Functional Models, and TPUStrategy variable initialization.

  Attributes:
    vars: a dictionary with string names (same as passed to `get_variable`)
      as keys and the corresponding TensorFlow Variables as values.
  """

  __slots__ = ["_vars", "_regularizers", "_store_eager_variables"]

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.
    self._regularizers = {}  # A dict mapping var names to their regularizers.
    self._store_eager_variables = True

  def get_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      reuse=None,
      trainable=None,
      collections=None,
      caching_device=None,
      partitioner=None,
      validate_shape=True,
      use_resource=None,
      custom_getter=None,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Gets an existing variable with these parameters or creates a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    Set `reuse` to None (the default) or tf.compat.v1.AUTO_REUSE when you want
    variables to be created if they don't exist or returned if they do.

    If initializer is `None` (the default), a default initializer is used
    (`glorot_uniform_initializer` for floating-point variables; see
    `_get_default_initializer`). If initializer is a Tensor, we use
    it as a value and derive the shape from the initializer.

    Note that partitioned variables are not supported by this shim; passing a
    `partitioner` raises a `ValueError`.

    Args:
      name: The name of the new or existing variable.
      shape: Shape of the new or existing variable.
      dtype: Type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: Initializer for the variable.
      regularizer: A (Tensor -> Tensor or None) function; the result of
        applying it on a newly created variable will be added to the
        collection GraphKeys.REGULARIZATION_LOSSES and can be used for
        regularization.
      reuse: a Boolean, None, or tf.AUTO_REUSE. Controls reuse or creation of
        variables. When eager execution is enabled this argument is always
        forced to be tf.AUTO_REUSE.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True`, unless `synchronization` is set to `ON_READ`, in
        which case it defaults to `False`.
      collections: List of graph collections keys to add the `Variable` to.
        Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading. Defaults to the Variable's
        device. If not `None`, caches on another device. Typical use is to
        cache on the device where the Ops using the `Variable` reside, to
        deduplicate copying through `Switch` and other conditional statements.
      partitioner: Optional callable that accepts a fully defined `TensorShape`
        and dtype of the `Variable` to be created, and returns a list of
        partitions for each axis (currently only one axis can be partitioned).
        Not supported by this shim; passing a partitioner raises a
        `ValueError`.
      validate_shape: If False, allows the variable to be initialized with a
        value of unknown shape. If True, the default, the shape of
        initial_value must be known.
      use_resource: If False, creates a regular Variable. If True, creates
        instead an experimental ResourceVariable which has well-defined
        semantics. Defaults to False (will later change to True). When eager
        execution is enabled this argument is always forced to be true.
      custom_getter: Callable that takes as a first argument the true getter,
        and allows overwriting the internal get_variable method. The signature
        of `custom_getter` should match that of this method,
        but the most future-proof version will allow for changes: `def
        custom_getter(getter, *args, **kwargs)`. Direct access to
        all `get_variable` parameters is also allowed: `def
        custom_getter(getter, name, *args, **kwargs)`. A simple identity
        custom getter that simply creates variables with modified names is:
        ```python
        def custom_getter(getter, name, *args, **kwargs):
          return getter(name + '_suffix', *args, **kwargs)
        ```
      constraint: An optional projection function to be applied to the
        variable after being updated by an `Optimizer` (e.g. used to implement
        norm constraints or value constraints for layer weights). The function
        must take as input the unprojected Tensor representing the value of
        the variable and return the Tensor for the projected value (which must
        have the same shape). Constraints are not safe to use when doing
        asynchronous distributed training.
      synchronization: Indicates when a distributed variable will be
        aggregated.
        Accepted values are constants defined in the class
        `tf.VariableSynchronization`. By default the synchronization is set to
        `AUTO` and the current `DistributionStrategy` chooses when to
        synchronize.
      aggregation: Indicates how a distributed variable will be aggregated.
        Accepted values are constants defined in the class
        `tf.VariableAggregation`.

    Returns:
      The created or existing `Variable`.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
      RuntimeError: when eager execution is enabled and not called from an
        EagerVariableStore.
    """
    if custom_getter is not None and not callable(custom_getter):
      raise ValueError("Passed a custom_getter which is not callable: %s" %
                       custom_getter)

    with ops.init_scope():
      if context.executing_eagerly():
        # Variable creation and initialization takes place in `init_scope`s;
        # as such, if an `init_scope` lifts us into the eager context, then we
        # need to use `ResourceVariable`s.
        use_resource = True

    # Note that it's fine to reuse eager variables whose initialization was
    # lifted from a function-building graph into the eager context (that's why
    # the following clause is not wrapped in an `init_scope`); lifted variables
    # are tracked by the graph's `VariableStore`.
    if context.executing_eagerly():
      reuse = vs.AUTO_REUSE

    # If a *_ref type is passed in an error would be triggered further down the
    # stack. We prevent this using base_dtype to get a non-ref version of the
    # type, before doing anything else. When _ref types are removed in favor of
    # resources, this line can be removed.
    try:
      dtype = dtype.base_dtype
    except AttributeError:
      # .base_dtype not existing means that we will try and use the raw dtype
      # which was passed in - this might be a NumPy type which is valid.
      pass

    # This is the main logic of get_variable. However, custom_getter
    # may override this logic. So we save it as a callable and pass
    # it to custom_getter.
    # Note: the parameters of _true_getter, and their documentation, match
    # *exactly* item-for-item with the docstring of this method.
    def _true_getter(  # pylint: disable=missing-docstring
        name,
        shape=None,
        dtype=dtypes.float32,
        initializer=None,
        regularizer=None,
        reuse=None,
        trainable=None,
        collections=None,  # pylint: disable=unused-argument
        caching_device=None,
        partitioner=None,
        validate_shape=True,
        use_resource=None,  # pylint: disable=unused-argument
        constraint=None,
        synchronization=vs.VariableSynchronization.AUTO,
        aggregation=vs.VariableAggregation.NONE):
      # Partitioned variables are currently unsupported with the shim.
      if partitioner is not None:
        raise ValueError(
            "`partitioner` arg for `get_variable` is unsupported in TF2. "
            "File a bug if you need help. You passed %s" % partitioner)

      # Single variable case
      if "%s/part_0" % name in self._vars:
        raise ValueError(
            "No partitioner was provided, but a partitioned version of the "
            "variable was found: %s/part_0. Perhaps a variable of the same "
            "name was already created with partitioning?"
            % name)

      return self._get_single_variable(
          name=name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          caching_device=caching_device,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    synchronization, aggregation, trainable = (
        validate_synchronization_aggregation_trainable(
            synchronization, aggregation, trainable, name))

    if custom_getter is not None:
      # Handle backwards compatibility with getter arguments that were added
      # to the API after users started writing custom getters.
      custom_getter_kwargs = {
          "getter": _true_getter,
          "name": name,
          "shape": shape,
          "dtype": dtype,
          "initializer": initializer,
          "regularizer": regularizer,
          "reuse": reuse,
          "trainable": trainable,
          "collections": collections,
          "caching_device": caching_device,
          "partitioner": partitioner,
          "validate_shape": validate_shape,
          "use_resource": use_resource,
          "synchronization": synchronization,
          "aggregation": aggregation,
      }
      # `fn_args` and `_has_kwargs` can handle functions, `functools.partial`,
      # `lambda`.
      if ("constraint" in fn_args(custom_getter) or
          _has_kwargs(custom_getter)):
        custom_getter_kwargs["constraint"] = constraint
      return custom_getter(**custom_getter_kwargs)
    else:
      return _true_getter(
          name,
          shape=shape,
          dtype=dtype,
          initializer=initializer,
          regularizer=regularizer,
          reuse=reuse,
          trainable=trainable,
          collections=collections,
          caching_device=caching_device,
          partitioner=partitioner,
          validate_shape=validate_shape,
          use_resource=use_resource,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

  def _get_single_variable(
      self,
      name,
      shape=None,
      dtype=dtypes.float32,
      initializer=None,
      regularizer=None,
      partition_info=None,
      reuse=None,
      trainable=None,
      caching_device=None,
      validate_shape=True,
      constraint=None,
      synchronization=vs.VariableSynchronization.AUTO,
      aggregation=vs.VariableAggregation.NONE):
    """Get or create a single Variable (e.g. a shard or entire variable).

    See the documentation of get_variable above (ignore partitioning
    components) for details.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.
      initializer: see get_variable.
      regularizer: see get_variable.
      partition_info: _PartitionInfo object.
      reuse: see get_variable.
      trainable: see get_variable.
      caching_device: see get_variable.
      validate_shape: see get_variable.
      constraint: see get_variable.
      synchronization: see get_variable.
      aggregation: see get_variable.

    Returns:
      A Variable. See documentation of get_variable above.

    Raises:
      ValueError: See documentation of get_variable above.
    """
    # Set to true if initializer is a constant.
    initializing_from_value = False
    if initializer is not None and not callable(initializer):
      initializing_from_value = True
    if shape is not None and initializing_from_value:
      raise ValueError("If initializer is a constant, do not specify shape.")

    dtype = dtypes.as_dtype(dtype)
    shape = as_shape(shape)

    if name in self._vars:
      # Here we handle the case when returning an existing variable.
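      # `reuse=False` means the caller explicitly asked for a brand-new
      # variable, so finding a stored variable under this name is an error.
      # Any other value (`True`, `None`, or `AUTO_REUSE`) falls through to the
      # shape/dtype compatibility checks below and returns the stored variable.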
      if reuse is False:  # pylint: disable=g-bool-id-comparison
        err_msg = ("Variable %s already exists, disallowed."
                   " Did you mean to set reuse=True or "
                   "reuse=tf.AUTO_REUSE in VarScope?" % name)
        # ResourceVariables don't have an op associated with them, so no
        # traceback.
        raise ValueError(err_msg)
      found_var = self._vars[name]
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." %
                         (name, shape, found_var.get_shape()))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." %
                         (name, dtype_str, found_type_str))
      return found_var

    # The code below handles only the case of creating a new variable.
    if reuse is True:  # pylint: disable=g-bool-id-comparison
      raise ValueError("Variable %s does not exist, or was not created with "
                       "tf.get_variable(). Did you mean to set "
                       "reuse=tf.AUTO_REUSE in VarScope?" % name)

    # Create the tensor to initialize the variable with the default value.
    if initializer is None:
      initializer, initializing_from_value = self._get_default_initializer(
          name=name, shape=shape, dtype=dtype)
    # Enter an init scope when creating the initializer.
    with ops.init_scope():
      if initializing_from_value:
        init_val = initializer
        variable_dtype = None
      else:
        # Instantiate initializer if provided initializer is a type object.
        if tf_inspect.isclass(initializer):
          initializer = initializer()
        if shape.is_fully_defined():
          if "partition_info" in tf_inspect.getargspec(initializer).args:
            init_val = functools.partial(initializer,
                                         shape.as_list(),
                                         dtype=dtype,
                                         partition_info=partition_info)
          else:
            init_val = functools.partial(initializer,
                                         shape.as_list(), dtype=dtype)
          variable_dtype = dtype.base_dtype
        else:
          init_val = initializer
          variable_dtype = None

    # Create the variable (always eagerly, as a workaround for a strange
    # TPU / FuncGraph / Keras functional model interaction).
    with ops.init_scope():
      v = variables.Variable(
          initial_value=init_val,
          name=name,
          trainable=trainable,
          caching_device=caching_device,
          dtype=variable_dtype,
          validate_shape=validate_shape,
          constraint=constraint,
          synchronization=synchronization,
          aggregation=aggregation)

    self._vars[name] = v
    logging.vlog(1, "Created variable %s with shape %s and init %s", v.name,
                 format(shape), initializer)

    # Run the regularizer if requested and save the resulting loss.
    if regularizer:
      self.add_regularizer(v, regularizer)

    return v

  def add_regularizer(self, var, regularizer):
    self._regularizers[var.name] = functools.partial(regularizer, var)

  # Initialize variable when no initializer provided
  def _get_default_initializer(self, name, shape=None, dtype=dtypes.float32):
    """Provide a default initializer and a corresponding value.

    Args:
      name: see get_variable.
      shape: see get_variable.
      dtype: see get_variable.

    Returns:
      initializer and initializing_from_value. See get_variable above.

    Raises:
      ValueError: When giving unsupported dtype.
    """
    del shape
    # If dtype is DT_FLOAT, provide a glorot uniform initializer.
    if dtype.is_floating:
      initializer = init_ops.glorot_uniform_initializer()
      initializing_from_value = False
    # If dtype is DT_INT/DT_UINT, provide a default value `zero`.
    # If dtype is DT_BOOL, provide a default value `FALSE`.
    elif (dtype.is_integer or dtype.is_unsigned or dtype.is_bool or
          dtype == dtypes.string):
      initializer = init_ops.zeros_initializer()
      initializing_from_value = False
    # NOTE: Do we need to support DT_COMPLEX here?
    else:
      raise ValueError("An initializer for variable %s of dtype %s is "
                       "required" % (name, dtype.base_dtype))

    return initializer, initializing_from_value


class VariableAndLossTracker(module.Module):
  """Module that has a scope to capture vars/losses made by `get_variable`."""

  def __init__(self):
    self._var_store = _EagerVariableStore()  # pylint: disable=protected-access
    self._variables = {}

  def _variable_creator(self, next_creator, **kwargs):
    var = next_creator(**kwargs)
    self._variables[var.name] = var

    return var

  @tf_contextlib.contextmanager
  def scope(self):
    with vs.variable_creator_scope(
        self._variable_creator), vs.with_variable_store(self._var_store):
      yield

  def get_regularization_losses(self):
    # TODO(kaftan): Consider adding a regex scope like the collection access.
    # But, < 40-50 usages of get_regularization_loss(es) with `scope`
    # & possible to do manually?
    losses = {}
    for var_name, regularizer in self._var_store._regularizers.items():  # pylint: disable=protected-access
      losses[var_name] = regularizer()
    return losses


class VariableScopeWrapperLayer(base_layer.Layer):
  """Wrapper Layer to capture `compat.v1.get_variable` and `compat.v1.layers`.

  See go/tf2-migration-model-bookkeeping for background.

  This shim layer allows using large sets of TF1 model-forward-pass code as a
  Keras layer that works in TF2 with TF2 behaviors enabled. To use it,
  override this class and put your TF1 model's forward pass inside your
  implementation for `forward_pass`.

  Below are some examples, followed by more details on how this shim layer
  wraps TF1 model forward passes.

  Example of capturing tf.compat.v1.layer-based modeling code as a Keras layer:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = tf.compat.v1.layers.dense(
          inputs, self.units, name="dense_one",
          kernel_initializer=init_ops.ones_initializer(),
          kernel_regularizer="l2")
      with tf.compat.v1.variable_scope("nested_scope"):
        out = tf.compat.v1.layers.dense(
            out, self.units, name="dense_two",
            kernel_initializer=init_ops.ones_initializer(),
            kernel_regularizer="l2")
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)
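  # e.g., with a hypothetical batch of eight 5-feature inputs:
  # layer(tf.ones([8, 5]))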

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```

  Example of capturing `tf.compat.v1.get_variable`-based modeling code as a
  Keras layer:

  ```python
  class WrappedDoubleDenseLayer(variable_scope_shim.VariableScopeWrapperLayer):

    def __init__(self, units, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self.units = units

    def forward_pass(self, inputs, training=None):
      out = inputs
      with tf.compat.v1.variable_scope("dense_one"):
        # The weights are created with a `regularizer`,
        # so the layer should track their regularization losses
        kernel = tf.compat.v1.get_variable(
            shape=[out.shape[-1], self.units],
            regularizer=regularizers.L2(),
            initializer=init_ops.ones_initializer(),
            name="kernel")
        bias = tf.compat.v1.get_variable(
            shape=[self.units,],
            initializer=init_ops.zeros_initializer(),
            name="bias")
        out = tf.compat.v1.math.matmul(out, kernel)
        out = tf.compat.v1.nn.bias_add(out, bias)
      with tf.compat.v1.variable_scope("nested_scope"):
        with tf.compat.v1.variable_scope("dense_two"):
          kernel = tf.compat.v1.get_variable(
              shape=[out.shape[-1], self.units],
              regularizer=regularizers.L2(),
              initializer=init_ops.ones_initializer(),
              name="kernel")
          bias = tf.compat.v1.get_variable(
              shape=[self.units,],
              initializer=init_ops.zeros_initializer(),
              name="bias")
          out = tf.compat.v1.math.matmul(out, kernel)
          out = tf.compat.v1.nn.bias_add(out, bias)
      return out

  # Create a layer that can be used as a standard keras layer
  layer = WrappedDoubleDenseLayer(10)

  # call the layer on inputs
  layer(...)

  # Variables created/used within the scope will be tracked by the layer
  layer.weights
  layer.trainable_variables

  # Regularization losses will be captured in layer.losses after a call,
  # just like any other Keras layer
  reg_losses = layer.losses
  ```

  Regularization losses:
    Any regularizers specified in the `get_variable` calls or `compat.v1.layer`
    creations will get captured by this wrapper layer. Regularization losses
    are accessible in `layer.losses` after a call just like in a standard
    Keras layer, and will be captured by any model that includes this layer.

  Variable scope / variable reuse:
    variable-scope based reuse in the `forward_pass` will be respected,
    and work like variable-scope based reuse in TF1.

  Variable Names/Pre-trained checkpoint loading:
    variable naming from get_variable and `compat.v1.layer` layers will match
    the TF1 names, so you should be able to re-use your old name-based
    checkpoints.

  Training Arg in `forward_pass`:
    Keras will pass a `training` arg to this layer similarly to how it
    passes `training` to other layers in TF2. See more details in the docs
    on `tf.keras.layers.Layer` to understand what will be passed and when.
    Note: tf.compat.v1.layers are usually not called with `training=None`,
    so the training arg to `forward_pass` might not feed through to them
    unless you pass it to their calls explicitly.
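
    For example (a sketch only; `tf.compat.v1.layers.dropout` is one
    `compat.v1` layer that accepts a `training` argument):

    ```python
    def forward_pass(self, inputs, training=None):
      out = tf.compat.v1.layers.dense(inputs, self.units, name="dense_one")
      # Explicitly forward the Keras-provided `training` flag:
      out = tf.compat.v1.layers.dropout(out, rate=0.5, training=training)
      return out
    ```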

  Call signature of the forward pass:
    The semantics of the forward pass signature roughly match the standard
    Keras layer `call` signature, except that a `training` arg will *always*
    be passed, so your `forward_pass` must accept it (either explicitly or
    via `**kwargs`).

  Limitations:
    * TF2 will not prune unused variable updates (or unused outputs). You may
      need to adjust your forward pass code to avoid computations or variable
      updates that you don't intend to use. (E.g. by adding a flag to the
      `forward_pass` call signature and branching on it.)
    * Avoid nesting variable creation in `tf.function` inside of
      `forward_pass`. While the layer may safely be used from inside a
      `tf.function`, using a `tf.function` inside of `forward_pass` will break
      the variable scoping.
    * TBD: Nesting keras layers/models or other `VariableScopeWrapperLayer`s
      directly in `forward_pass` may not work correctly just yet.
      Support for this / instructions for how to do this are still being
      worked on.

  Coming soon: A better guide, and a testing/verification guide.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    # Relies on keras layers tracking Modules
    self.tracker = VariableAndLossTracker()
    # May need to inspect func to see if it should pass a `training` arg or not

  def forward_pass(self, *args, **kwargs):
    raise NotImplementedError

  def call(self, *args, **kwargs):
    with self.tracker.scope():
      out = self.forward_pass(*args, **kwargs)
    if not self._eager_losses:
      # We have to record regularization losses in the call as if they
      # are activity losses.
      # So, don't double-count regularization losses if the layer is used
      # multiple times in a model.
      for loss in self.tracker.get_regularization_losses().values():
        self.add_loss(loss)
    return out
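

# A minimal usage sketch of `VariableAndLossTracker` on its own (outside of a
# `VariableScopeWrapperLayer`). The variable name, shape, and regularizer
# below are illustrative only:
#
#   tracker = VariableAndLossTracker()
#   with tracker.scope():
#     kernel = tf.compat.v1.get_variable(
#         "kernel", shape=[4, 2],
#         initializer=init_ops.ones_initializer(),
#         regularizer=regularizers.L2())
#   # Variables made by `get_variable` inside the scope are tracked, and the
#   # regularization losses can be computed afterwards, keyed by variable name:
#   tracker.get_regularization_losses()  # e.g. {"kernel:0": <L2 loss tensor>}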