# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains private utilities used mainly by the base Layer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import threading

from tensorflow.python import tf2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as tracking
from tensorflow.python.util import keras_deps
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export

_call_context = threading.local()


def create_mean_metric(value, name=None):
  # Importing keras imports base_layer and then this module, and the metrics
  # module relies on base_layer, which would result in a cyclic dependency.
  from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
  metric_obj = metrics_module.Mean(name=name, dtype=value.dtype)
  return metric_obj, metric_obj(value)


def make_variable(name,
                  shape=None,
                  dtype=dtypes.float32,
                  initializer=None,
                  trainable=None,
                  caching_device=None,
                  validate_shape=True,
                  constraint=None,
                  use_resource=None,
                  collections=None,
                  synchronization=tf_variables.VariableSynchronization.AUTO,
                  aggregation=tf_variables.VariableAggregation.NONE,
                  partitioner=None):  # pylint: disable=unused-argument
  """Temporary util to create a variable (relies on `variable_scope.variable`).

  Some reuse-related technicalities prevent us from using
  `variable_scope.get_variable()` directly, so we use a subcomponent
  that has fewer constraints (`variable_scope.variable()`).

  In the longer term, it seems like a similar "default variable creator" method
  should exist in `Trackable` instead. When this happens, we can get
  rid of this temporary solution.

  TODO(fchollet): remove this method when no longer needed.

  Args:
    name: Variable name.
    shape: Variable shape.
    dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
    initializer: Initializer instance (callable).
    trainable: Whether the variable should be part of the layer's
      "trainable_variables" (e.g. weights, biases)
      or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
      Note, if the current variable scope is marked as non-trainable
      then this parameter is ignored and any added variables are also
      marked as non-trainable. `trainable` defaults to `True` unless
      `synchronization` is set to `ON_READ`.
    caching_device: Passed to `tf.Variable`.
    validate_shape: Passed to `tf.Variable`.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    synchronization: Indicates when a distributed variable will be
      aggregated. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.
    partitioner: Not handled at this time.

  Returns:
    Variable instance.
  """
  initializing_from_value = False
  if initializer is not None and not callable(initializer):
    initializing_from_value = True

  if initializing_from_value:
    init_val = initializer
    variable_dtype = None
  else:
    # Instantiate initializer if provided initializer is a type object.
    if tf_inspect.isclass(initializer):
      initializer = initializer()
    init_val = functools.partial(initializer, shape, dtype=dtype)
    variable_dtype = dtype.base_dtype
  if use_resource is None:
    use_resource = True

  # TODO(apassos,rohanj) figure out how to remove collections from here so we
  # can remove the V1.
  variable_shape = tensor_shape.TensorShape(shape)
  return tf_variables.VariableV1(
      initial_value=init_val,
      name=name,
      trainable=trainable,
      caching_device=caching_device,
      dtype=variable_dtype,
      validate_shape=validate_shape,
      constraint=constraint,
      use_resource=use_resource,
      collections=collections,
      synchronization=synchronization,
      aggregation=aggregation,
      shape=variable_shape if variable_shape else None)


def collect_previous_mask(input_tensors):
  """Retrieves the output mask(s) of the previous node.

  Args:
    input_tensors: An arbitrary structure of Tensors.

  Returns:
    A mask tensor or list of mask tensors.
  """

  def _collect_previous_mask(x):
    return getattr(x, '_keras_mask', None)

  return nest.map_structure(_collect_previous_mask, input_tensors)


def have_all_keras_metadata(tensors):
  return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))


def generate_placeholders_from_shape(shape):
  return array_ops.placeholder(shape=shape, dtype=backend.floatx())


def create_keras_history(tensors):
  """Wraps TensorFlow Operations for compatibility with the Functional API.

  This method checks to see if a Tensor in `tensors` is missing Keras metadata
  and has its origin in a Keras `Input` Layer. If so, this method will replace
  the raw TensorFlow Operations that created this tensor with
  `TensorFlowOpLayer` instances that create identical operations.
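
  For example (an illustrative sketch only; `tf.abs` stands in for any raw op):

  ```python
  inp = tf.keras.Input(shape=(4,))
  out = tf.abs(inp)  # Raw TF op output; `out` has no `_keras_history` yet.
  # `create_keras_history([out])` wraps the `Abs` op in a `TensorFlowOpLayer`
  # and assigns Keras metadata to `out`.
  ```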

  Any Tensors not originating from a Keras `Input` Layer will be treated as
  constants when constructing `TensorFlowOpLayer` instances.

  Args:
    tensors: A structure of Tensors, some of which come from raw TensorFlow
      operations and need to have Keras metadata assigned to them.

  Returns:
    created_layers: List. The `TensorFlowOpLayer` instances created to wrap
      the raw TensorFlow operations.
  """
  _, created_layers = _create_keras_history_helper(tensors, set(), [])
  return created_layers


# Unsafe internal attribute.
# If True, Keras will not evaluate the constant-foldable inputs to tf op
# layers in TF1 graphs. This *might* speed up model construction time in
# certain settings, but it means the models will not be
# serializable/deserializable via get_config (only via SavedModel). It may
# also change the semantics of whether generated random numbers are generated
# once and re-used, or recomputed each time.
# Note: This path triggers for TPUEstimators / xla compiled graphs regardless
# of this setting.
_UNSAFE_GRAPH_OP_LAYER_CREATION = False


def _create_keras_history_helper(tensors, processed_ops, created_layers):
  """Helper method for `create_keras_history`.

  Args:
    tensors: A structure of Tensors for which to create Keras metadata.
    processed_ops: Set. TensorFlow operations that have already been wrapped
      in `TensorFlowOpLayer` instances.
    created_layers: List. The `TensorFlowOpLayer` instances created.

  Returns:
    Tuple. First element is the updated set of TensorFlow Operations that
    have been wrapped in `TensorFlowOpLayer` instances. Second element is
    a list of the `TensorFlowOpLayer` instances created.
  """
  # Import of `base_layer` needed in order to create `TensorFlowOpLayer`.
  # Cannot be imported at top because of circular dependencies.
  # TODO(omalleyt): Resolve circular dependency.
  from tensorflow.python.keras.engine import base_layer  # pylint: disable=g-import-not-at-top
  tensor_list = nest.flatten(tensors)
  sparse_ops = []
  ragged_tensors = []
  for tensor in tensor_list:
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    if isinstance(
        tensor, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
      sparse_ops.append(tensor.op)
      continue
    if tf_utils.is_ragged(tensor):
      # Ragged tensors don't have an op property.
      ragged_tensors.append(tensor)
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      # Recursively set `_keras_history`.
      op_inputs = list(op.inputs)
      constants = {}
      layer_inputs = []
      for i, op_input in enumerate(op_inputs):
        if uses_keras_history(op_input):
          layer_inputs.append(op_input)
        else:
          # Treat any value not originating from a `keras.Input` as
          # a constant. Variables cannot be supported.
          ds_with_session = (
              distribution_strategy_context.in_cross_replica_context() and
              not ops.executing_eagerly_outside_functions())
          using_xla = control_flow_util.GraphOrParentsInXlaContext(
              ops.get_default_graph())
          if ds_with_session or using_xla or _UNSAFE_GRAPH_OP_LAYER_CREATION:
            # In Legacy Graph mode, evaluating here makes Session be
            # configured improperly. The downside of this is that saving
            # via `get_config` breaks, but SavedModel still works.
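            # Store the symbolic tensor itself instead of a concrete value,
            # which avoids evaluating it here (see the note above).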
            constants[i] = op_input
          else:
            with ops.init_scope():
              if ops.executing_eagerly_outside_functions():
                constants[i] = backend.eval_in_eager_or_function(op_input)
              else:
                constants[i] = backend.function([], op_input)([])
      layer_inputs = unnest_if_single_tensor(layer_inputs)
      processed_ops, created_layers = _create_keras_history_helper(
          layer_inputs, processed_ops, created_layers)
      name = op.name
      node_def = op.node_def.SerializeToString()
      op_layer = base_layer.TensorFlowOpLayer(
          node_def, constants=constants, name=name)
      created_layers.append(op_layer)
      op_layer._set_connectivity_metadata(  # pylint: disable=protected-access
          args=(layer_inputs,),
          kwargs={},
          outputs=op.outputs)
      processed_ops.update([op])
  if sparse_ops or ragged_tensors:
    lambda_example = """
    weights_mult = lambda x: tf.sparse.sparse_dense_matmul(x, weights)
    output = tf.keras.layers.Lambda(weights_mult)(input)
    """
    raise ValueError(
        'TensorFlow ops that generate ragged or sparse tensor '
        'outputs are currently not supported by Keras automatic '
        'op wrapping. Please wrap these ops in a Lambda layer: '
        '\n\n```\n{example}\n```\n'
        'Sparse ops encountered: {sparse_ops}\n'
        'Ragged tensors encountered: {ragged_tensors}\n'.format(
            example=lambda_example,
            sparse_ops=str(sparse_ops),
            ragged_tensors=str(ragged_tensors)))
  return processed_ops, created_layers


def unnest_if_single_tensor(input_tensors):
  # Preserve compatibility with older configs.
  flat_input_tensors = nest.flatten(input_tensors)
  # If this is a single element but not a dict, unwrap. If this is a dict,
  # assume the first layer expects a dict (as is the case with a
  # DenseFeatures layer); pass through.
  if not isinstance(input_tensors, dict) and len(flat_input_tensors) == 1:
    input_tensors = flat_input_tensors[0]
  return input_tensors


def needs_keras_history(tensors, ignore_call_context=False):
  """Check if any Tensors need to be wrapped in TensorFlowOpLayers.

  This will never return True inside a sublayer, because sublayers
  do not need to create Keras History. Otherwise, this returns True
  if one or more of `tensors` originates from a `keras.Input` and
  does not have `_keras_history` set.

  Args:
    tensors: An arbitrary nested structure of Tensors.
    ignore_call_context: Whether to ignore the check of whether we are
      currently inside a `call` context. This is `True` when creating
      KerasHistory inside `Node`, where we always know that Tensors
      are being used with the Functional API.

  Returns:
    Bool, whether at least one Tensor needs to be wrapped.
  """
  input_tensors = nest.flatten(tensors)
  if call_context().in_call and not ignore_call_context:
    return False
  if all(
      getattr(tensor, '_keras_history', None) is not None
      for tensor in input_tensors):
    # KerasHistory already set.
    return False
  return uses_keras_history(tensors)


def is_in_keras_graph():
  """Returns whether currently executing inside of a Keras graph."""
  return call_context().in_keras_graph


def is_in_eager_or_tf_function():
  """Returns whether in eager mode or inside of a tf.function."""
  return context.executing_eagerly() or is_in_tf_function()


def is_in_tf_function():
  """Returns whether inside of a tf.function."""
  # Check if running in V1 graph mode.
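  # (`executing_eagerly_outside_functions()` is False only when eager
  # execution has not been enabled at the outermost level, i.e. TF1 graphs.)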
  if not ops.executing_eagerly_outside_functions():
    return False
  if not ops.inside_function():
    return False
  # Check if inside Keras FuncGraph.
  if is_in_keras_graph():
    return False
  # Check for a v1 `wrap_function` FuncGraph.
  graph = ops.get_default_graph()
  if (getattr(graph, 'name', False) and
      graph.name.startswith('wrapped_function')):
    return False
  return True


def uses_keras_history(tensors):
  """Check if at least one Tensor originates from a `keras.Input`.

  This is `True` if at least one Tensor has its origin in a `keras.Input`.
  Any Tensor that originates from a `keras.Input` will have a dependency
  Tensor with a `_keras_history` attribute attached. Tensors that have
  already been checked to not originate from a `keras.Input`
  are marked as `_keras_history_checked`.

  Args:
    tensors: An arbitrary nested structure of Tensors.

  Returns:
    Bool, whether at least one Tensor originates from a `keras.Input`.
  """
  checked_tensors = set()
  tensors_to_check = nest.flatten(tensors)

  while tensors_to_check:
    new_tensors_to_check = []
    for tensor in tensors_to_check:
      if id(tensor) in checked_tensors:
        continue

      checked_tensors.add(id(tensor))

      if getattr(tensor, '_keras_history_checked', None) is not None:
        continue
      if getattr(tensor, '_keras_history', None) is not None:
        return True

      try:
        new_tensors_to_check.extend(tensor.op.inputs)
      except AttributeError:
        # In case `tensor` is a Variable created in an Eager context.
        pass

    tensors_to_check = new_tensors_to_check

  # Mark that these Tensors have been checked once for `_keras_history`,
  # and should not be checked again for performance reasons.
  mark_checked(tensors)
  return False


def mark_checked(tensors):
  """Marks that these Tensors should not be tracked.

  This prevents Layers from attempting to create TensorFlowOpLayers
  for these Tensors.

  Args:
    tensors: An arbitrary structure of Tensors.
  """

  def _mark_checked(tensor):
    tensor._keras_history_checked = True  # pylint: disable=protected-access

  nest.map_structure(_mark_checked, tensors)


def call_context():
  """Returns currently active `CallContext`."""
  call_ctx = getattr(_call_context, 'call_context', None)
  if call_ctx is None:
    call_ctx = CallContext()
    _call_context.call_context = call_ctx
  return call_ctx


# Inject the call_context function to keras_deps to remove the dependency
# from TFLite to Keras.
keras_deps.register_call_context_function(call_context)


class CallContext(object):
  """Keeps track of properties currently inside a Layer/Model's `call`.

  Attributes:
    in_call: Whether currently inside the `call` of a Layer.
    layer: The `Layer` whose `call` is currently active.
    inputs: The inputs to the currently active `Layer`.
    build_graph: Whether currently inside a Graph or FuncGraph.
    training: Whether currently executing in training or inference mode.
    saving: Whether currently saving to SavedModel.
    frozen: Whether currently executing inside a `Layer` with `trainable` set
      to `False`.
    in_keras_graph: Whether executing inside the Keras Graph.
  """

  def __init__(self):
    # Handle `in_call` separately as it is the most-read attr and reading it is
    # on the hot path.
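    # A plain attribute read here is cheaper than a `self._state` dict lookup.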
    self.in_call = False
    self._state = {
        'layer': None,
        'inputs': None,
        'build_graph': False,
        'training': None,
        'saving': None
    }
    # TODO(b/150169018): This logic can be replaced after the Functional API
    # refactor.
    self._in_keras_graph = False

  def enter(self, layer, inputs, build_graph, training, saving=None):
    """Push a Layer and its inputs and state onto the current call context.

    Args:
      layer: The `Layer` whose `call` is currently active.
      inputs: The inputs to the currently active `Layer`.
      build_graph: Whether currently inside a Graph or FuncGraph.
      training: Whether currently executing in training or inference mode.
      saving: Whether currently saving to SavedModel.

    Returns:
      Context manager.
    """
    state = {
        'layer': layer,
        'inputs': inputs,
        'build_graph': build_graph,
        'training': training,
        'saving': saving
    }
    return CallContextManager(self, state)

  @property
  def layer(self):
    return self._state['layer']

  @property
  def inputs(self):
    return self._state['inputs']

  @property
  def build_graph(self):
    return self._state['build_graph']

  @property
  def training(self):
    return self._state['training']

  @property
  def saving(self):
    return self._state['saving']

  @property
  def frozen(self):
    layer = self._state['layer']
    if not layer:
      return False
    return not layer.trainable

  @property
  def in_keras_graph(self):
    # Returns True even if in a subgraph of the Keras graph, such as those
    # created by control flow ops.
    if context.executing_eagerly():
      return False
    return (self._in_keras_graph or
            getattr(backend.get_graph(), 'name', None) == 'keras_graph')


class CallContextManager(object):
  """Context manager for `CallContext`."""

  def __init__(self, call_ctx, state):
    self._call_ctx = call_ctx
    self._state = state
    self._build_graph = state['build_graph']

  def __enter__(self):
    call_ctx = self._call_ctx
    self._prev_in_call = call_ctx.in_call
    self._prev_state = call_ctx._state

    call_ctx.in_call = True
    call_ctx._state = self._state

    # TODO(b/150169018): This logic can be removed after the Functional API
    # refactor.
    if self._build_graph:
      self._prev_in_keras_graph = call_ctx._in_keras_graph
      call_ctx._in_keras_graph = (
          call_ctx._in_keras_graph or
          getattr(backend.get_graph(), 'name', None) == 'keras_graph')

  def __exit__(self, *exc_info):
    call_ctx = self._call_ctx
    call_ctx.in_call = self._prev_in_call
    call_ctx._state = self._prev_state

    if self._build_graph:
      call_ctx._in_keras_graph = self._prev_in_keras_graph


def training_arg_passed_to_call(argspec, args, kwargs):
  """Returns whether a user passed the `training` argument in `__call__`."""
  # `argspec.args` starts with ['self', 'inputs'].
  full_args = dict(zip(argspec.args[2:], args))
  full_args.update(kwargs)
  return 'training' in full_args and full_args['training'] is not None


def is_subclassed(layer):
  """Returns True if the object is a subclassed layer or subclassed model."""
  return (layer.__module__.find('keras.engine') == -1 and
          layer.__module__.find('keras.layers') == -1)


def from_saved_model(layer):
  """Returns whether the layer is loaded from a SavedModel."""
  return layer.__module__.find('keras.saving.saved_model') != -1


def check_graph_consistency(tensor=None, method='add_loss', force_raise=False):
  """Checks that tensors passed to `add_*` methods match the Keras graph.

  When one of the `add_*` methods is called inside a V2 conditional branch,
  the underlying tensor gets created in a FuncGraph managed by control_flow_v2.
  We need to raise clear error messages in such cases.

  Args:
    tensor: Tensor to check, or `False` if it is known that an error
      should be raised.
    method: Caller method, one of {'add_metric', 'add_loss', 'add_update'}.
    force_raise: If an error should be raised regardless of `tensor`.

  Raises:
    RuntimeError: In case of an out-of-graph tensor.
  """
  if (force_raise or
      (ops.executing_eagerly_outside_functions() and
       hasattr(tensor, 'graph') and tensor.graph.is_control_flow_graph)):
    if method == 'activity_regularizer':
      bad_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          if training:
            return self.dense(x)
          else:
            return self.dense(x)
      """
      correct_example = """
      class TestModel(tf.keras.Model):

        def __init__(self):
          super(TestModel, self).__init__(name='test_model')
          self.dense = tf.keras.layers.Dense(2, activity_regularizer='l2')

        def call(self, x, training=None):
          return self.dense(x)
      """
      raise RuntimeError(
          'You are using a layer with `activity_regularizer` in a control '
          'flow branch, e.g.:\n{bad_example}\nThis is currently not '
          'supported. '
          'Please move your call to the layer with `activity_regularizer` out '
          'of the control flow branch, e.g.:\n{correct_example}\n'
          'You can also resolve this by marking your outer model/layer '
          'dynamic (eager-only) by passing `dynamic=True` to the layer '
          'constructor. '
          'Any kind of control flow is supported with dynamic layers. '
          'Note that using `dynamic=True` requires you to implement static '
          'shape inference in the `compute_output_shape(input_shape)` '
          'method.'.format(
              bad_example=bad_example, correct_example=correct_example))

    if method == 'add_metric':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
          self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          metric = compute_metric(inputs)
        else:
          metric = 0.
        self.add_metric(metric, name='my_metric', aggregation='mean')
        return inputs
      """
    elif method == 'add_loss':
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
          self.add_loss(loss)
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          loss = compute_loss(inputs)
        else:
          loss = 0.
        self.add_loss(loss)
        return inputs
      """
    else:
      bad_example = """
      def call(self, inputs, training=None):
        if training:
          self.add_update(self.w.assign_add(1))
        return inputs
      """
      correct_example = """
      def call(self, inputs, training=None):
        if training:
          increment = 1
        else:
          increment = 0
        self.add_update(self.w.assign_add(increment))
        return inputs
      """
    raise RuntimeError(
        'You are using the method `{method}` in a control flow branch '
        'in your layer, e.g.:\n{bad_example}\n'
        'This is not currently supported. '
        'Please move your call to {method} out of the control flow branch, '
        'e.g.:\n{correct_example}\n'
        'You can also resolve this by marking your layer '
        'as dynamic (eager-only) by passing '
        '`dynamic=True` to the layer constructor. '
        'Any kind of control flow is supported with dynamic layers. '
        'Note that using `dynamic=True` requires you '
        'to implement static shape inference '
        'in the `compute_output_shape(input_shape)` method.'.format(
            method=method,
            bad_example=bad_example,
            correct_example=correct_example))


def mark_as_return(outputs, acd):
  """Marks `outputs` as the return values for automatic control deps."""

  def _mark_as_return(tensor):
    """Marks `tensor` as the return value for automatic control deps."""
    if not tensor_util.is_tf_type(tensor):
      return tensor

    # pylint: disable=protected-access
    return_tensor = acd.mark_as_return(tensor)
    if getattr(tensor, '_keras_mask', None) is not None:
      return_tensor._keras_mask = acd.mark_as_return(tensor._keras_mask)
    else:
      return_tensor._keras_mask = None

    # Handle TensorFlow Probability attached metadata.
    # TODO(b/132076537): Remove this once TFP uses `CompositeTensor`.
    if getattr(tensor, '_tfp_distribution', None) is not None:
      return_tensor._tfp_distribution = tensor._tfp_distribution

    return return_tensor
    # pylint: enable=protected-access

  return nest.map_structure(_mark_as_return, outputs)


V2_DTYPE_BEHAVIOR = None


@keras_export(v1=['keras.layers.enable_v2_dtype_behavior'])
def enable_v2_dtype_behavior():
  """Enable the V2 dtype behavior for Keras layers.

  By default, the V2 dtype behavior is enabled in TensorFlow 2, so this
  function is only useful if `tf.compat.v1.disable_v2_behavior` has been
  called.
  Since mixed precision requires V2 dtype behavior to be enabled, this
  function allows you to use mixed precision in Keras layers if
  `disable_v2_behavior` has been called.

  When enabled, the dtype of Keras layers defaults to floatx (which is
  typically float32) instead of None. In addition, layers will automatically
  cast floating-point inputs to the layer's dtype.

  >>> x = tf.ones((4, 4, 4, 4), dtype='float64')
  >>> layer = tf.keras.layers.Conv2D(filters=4, kernel_size=2)
  >>> print(layer.dtype)  # float32 since V2 dtype behavior is enabled
  float32
  >>> y = layer(x)  # Layer casts inputs since V2 dtype behavior is enabled
  >>> print(y.dtype.name)
  float32

  A layer author can opt their layer out of the automatic input casting by
  passing `autocast=False` to the base Layer's constructor. This disables the
  autocasting part of the V2 behavior for that layer, but not the defaulting
  to floatx part of the V2 behavior.

  When a global `tf.keras.mixed_precision.Policy` is set, a Keras layer's
  dtype will default to the global policy instead of floatx. Layers will
  automatically cast inputs to the policy's compute_dtype.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = True


@keras_export(v1=['keras.layers.disable_v2_dtype_behavior'])
def disable_v2_dtype_behavior():
  """Disables the V2 dtype behavior for Keras layers.

  See `tf.compat.v1.keras.layers.enable_v2_dtype_behavior`.
  """
  global V2_DTYPE_BEHAVIOR
  V2_DTYPE_BEHAVIOR = False


def v2_dtype_behavior_enabled():
  """Returns True if the V2 dtype behavior is enabled."""
  if V2_DTYPE_BEHAVIOR is None:
    return tf2.enabled()
  return V2_DTYPE_BEHAVIOR


class TrackableWeightHandler(object):
  """Keras wrapper for handling tracking.Trackable object saving and restoring.

  This class handles Trackables in both V1 and V2 modes, ensuring that they
  can be saved and restored with the correct data and without adding
  additional ops on every save.

  Attributes:
    trackable: The trackable to wrap.
    num_tensors: The number of tensors that this trackable requires for
      saving.
  """

  def __init__(self, trackable):
    if not isinstance(trackable, tracking.Trackable):
      raise ValueError('%s is not a Trackable object.' % (trackable,))
    self._trackable = trackable
    self._distribute_strategy = distribution_strategy_context.get_strategy()

    # TODO(b/141682913): Figure out why this is private and fix it.
    saveables = trackable._gather_saveables_for_checkpoint().values()  # pylint: disable=protected-access
    # 'Saveables' won't exist when we're passed a legacy TF1 table like
    # a StaticHashTable.
    if not saveables:
      self._num_tensors = 0
      self._setter = lambda weights: None
      self._getter = lambda: []

    elif len(saveables) == 1:
      saveable = list(saveables)[0]

      if ops.executing_eagerly_outside_functions():
        # If we're in eager mode, we need to defer calling the Trackable's
        # saveable() callable until data export time.
        # However, it is safe to call the saveable as many times as we want,
        # so we will call it now to figure out how many tensors this Trackable
        # will produce.
        self._saveable = saveable
        self._num_tensors = len(self._saveable().specs)
        self._setter = lambda weights: self._saveable().restore(weights, None)
        self._getter = lambda: [spec.tensor for spec in self._saveable().specs]
      else:
        # If we're in Graph mode, we need to evaluate the Saveable only once
        # and cache the resulting restore graph. Failing to do this will
        # result in new assignment ops being added to the graph each time
        # set_weights() is called.
        self._placeholder_tensors = []
        self._saveable = saveable()
        self._num_tensors = len(self._saveable.specs)
        for spec in self._saveable.specs:
          tensor = spec.tensor
          self._placeholder_tensors.append(
              array_ops.placeholder(tensor.dtype, tensor.shape))
        self._assign_op = self._saveable.restore(self._placeholder_tensors,
                                                 None)
        self._setter = self._set_weights_v1
        self._getter = lambda: [spec.tensor for spec in self._saveable.specs]
    else:
      raise ValueError('Only Trackables with one Saveable are supported. '
                       'The Trackable %s has %d Saveables.' %
                       (trackable, len(saveables)))

  @property
  def num_tensors(self):
    return self._num_tensors

  def set_weights(self, weights):
    if len(weights) != self._num_tensors:
      raise ValueError(
          ('Weight handler for trackable %s received the wrong number of '
           'weights: expected %s, got %s.') %
          (self._trackable, self._num_tensors, len(weights)))
    self._setter(weights)

  def get_tensors(self):
    return self._getter()

  def _set_weights_v1(self, weights):
    feed_dict = {}
    for idx, tensor in enumerate(weights):
      feed_dict[self._placeholder_tensors[idx]] = tensor
    backend.get_session().run(self._assign_op, feed_dict)


def no_ragged_support(inputs, layer_name):
  input_list = nest.flatten(inputs)
  if any(isinstance(x, ragged_tensor.RaggedTensor) for x in input_list):
    raise ValueError('Layer %s does not support RaggedTensors as input. '
                     'Inputs received: %s. You can try converting your '
                     'input to a uniform tensor.' % (layer_name, inputs))


def is_split_variable(v):
  """Returns True if `v` is a PartitionedVariable or a ShardedVariable."""
  return hasattr(v, '_variable_list') or hasattr(v, '_variables')


def has_weights(obj):
  obj_type = type(obj)
  return (hasattr(obj_type, 'trainable_weights') and
          hasattr(obj_type, 'non_trainable_weights') and
          not isinstance(obj, type))


# TODO(kathywu): This is a temporary hack. When a network of layers is revived
# from SavedModel, only the top-level layer will have losses. This causes
# issues in eager mode because the child layers may have graph losses (thus
# model.losses returns a mix of Eager and graph tensors). To fix this,
# whenever eager losses are added to one layer, add eager losses to all
# child layers. This causes `.losses` to only return eager losses.
REVIVED_LOSS_PLACEHOLDER = (
    'This layer\'s losses have been added to the parent layer.')