1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains private utilities used mainly by the base Layer class.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections as collections_lib 21import threading 22import enum 23 24from tensorflow.python.eager import context 25from tensorflow.python.framework import auto_control_deps 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.keras import backend 29from tensorflow.python.keras.utils import tf_utils 30from tensorflow.python.ops import array_ops 31from tensorflow.python.ops import control_flow_util 32from tensorflow.python.ops import init_ops 33from tensorflow.python.ops import init_ops_v2 34from tensorflow.python.ops import variables as tf_variables 35from tensorflow.python.util import nest 36from tensorflow.python.util import tf_contextlib 37 38_call_context = threading.local() 39 40 41class CallConvention(enum.Enum): 42 """Calling conventions for passing `Layer` inputs to `Layer.call`.""" 43 # The Layer takes inputs as its first argument, named "inputs" for 44 # compatibility with the signature of Layer.__call__. This is the mode assumed 45 # for Layers which are not subclassed Models. 46 EXPLICIT_INPUTS_ARGUMENT = 1 47 # The Layer takes a single positional argument, not named "inputs". It's 48 # treated like an "inputs" argument. 49 SINGLE_POSITIONAL_ARGUMENT = 2 50 # The Layer has multiple positional arguments to which its inputs should be 51 # bound. 52 POSITIONAL_ARGUMENTS_ARE_INPUTS = 3 53 54 55def create_mean_metric(value, name=None): 56 # TODO(psv): Remove this import when b/110718070 is fixed. 57 from tensorflow.python.keras import metrics as metrics_module # pylint: disable=g-import-not-at-top 58 metric_obj = metrics_module.Mean(name=name) 59 result = metric_obj(value) 60 return metric_obj, result 61 62 63def make_variable(name, 64 shape=None, 65 dtype=dtypes.float32, 66 initializer=None, 67 trainable=None, 68 caching_device=None, 69 validate_shape=True, 70 constraint=None, 71 use_resource=None, 72 collections=None, 73 synchronization=tf_variables.VariableSynchronization.AUTO, 74 aggregation=tf_variables.VariableAggregation.NONE, 75 partitioner=None): # pylint: disable=unused-argument 76 """Temporary util to create a variable (relies on `variable_scope.variable`). 77 78 Some reuse-related technicalities prevent us from using 79 `variable_scope.get_variable()` directly, so we use a subcomponent 80 that has fewer constraints (`variable_scope.variable()`). 81 82 In the longer term, it seems like a similar "default variable creator" method 83 should exist in `Trackable` instead. When this happens, we can get 84 rid of this temporary solution. 85 86 TODO(fchollet): remove this method when no longer needed. 87 88 Arguments: 89 name: Variable name. 90 shape: Variable shape. 91 dtype: The type of the variable. Defaults to `self.dtype` or `float32`. 92 initializer: Initializer instance (callable). 93 trainable: Whether the variable should be part of the layer's 94 "trainable_variables" (e.g. variables, biases) 95 or "non_trainable_variables" (e.g. BatchNorm mean, stddev). 96 Note, if the current variable scope is marked as non-trainable 97 then this parameter is ignored and any added variables are also 98 marked as non-trainable. `trainable` defaults to `True` unless 99 `synchronization` is set to `ON_READ`. 100 caching_device: Passed to `tf.Variable`. 101 validate_shape: Passed to `tf.Variable`. 102 constraint: Constraint instance (callable). 103 use_resource: Whether to use a `ResourceVariable`. 104 collections: List of graph collections keys. The new variable is added to 105 these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 106 synchronization: Indicates when a distributed a variable will be 107 aggregated. Accepted values are constants defined in the class 108 `tf.VariableSynchronization`. By default the synchronization is set to 109 `AUTO` and the current `DistributionStrategy` chooses 110 when to synchronize. If `synchronization` is set to `ON_READ`, 111 `trainable` must not be set to `True`. 112 aggregation: Indicates how a distributed variable will be aggregated. 113 Accepted values are constants defined in the class 114 `tf.VariableAggregation`. 115 partitioner: Not handled at this time. 116 117 Returns: 118 Variable instance. 119 """ 120 initializing_from_value = False 121 if initializer is not None and not callable(initializer): 122 initializing_from_value = True 123 124 with ops.init_scope(): 125 if initializing_from_value: 126 init_val = initializer 127 variable_dtype = None 128 else: 129 # Instantiate initializer if provided initializer is a type object. 130 if isinstance( 131 initializer, 132 (type(init_ops.Initializer), type(init_ops_v2.Initializer))): 133 initializer = initializer() 134 init_val = lambda: initializer(shape, dtype=dtype) 135 variable_dtype = dtype.base_dtype 136 if use_resource is None: 137 use_resource = True 138 139 # TODO(apassos,rohanj) figure out how to remove collections from here so we 140 # can remove the V1. 141 v = tf_variables.VariableV1( 142 initial_value=init_val, 143 name=name, 144 trainable=trainable, 145 caching_device=caching_device, 146 dtype=variable_dtype, 147 validate_shape=validate_shape, 148 constraint=constraint, 149 use_resource=use_resource, 150 collections=collections, 151 synchronization=synchronization, 152 aggregation=aggregation) 153 return v 154 155 156def get_default_graph_uid_map(): 157 # TODO(fchollet): refactor this into backend. 158 graph = ops.get_default_graph() 159 name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS.get(graph, None) 160 if name_uid_map is None: 161 name_uid_map = collections_lib.defaultdict(int) 162 backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map 163 return name_uid_map 164 165 166def unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace='', 167 zero_based=False): 168 """Makes a layer name (or arbitrary string) unique within a TensorFlow graph. 169 170 Arguments: 171 name: String name to make unique. 172 name_uid_map: An optional defaultdict(int) to use when creating unique 173 names. If None (default), uses a per-Graph dictionary. 174 avoid_names: An optional set or dict with names which should not be used. If 175 None (default) does not avoid any names. 176 namespace: Gets a name which is unique within the (graph, namespace). Layers 177 which are not Networks use a blank namespace and so get graph-global 178 names. 179 zero_based: If True, name sequences start with no suffix (e.g. "dense", 180 "dense_1"). If False, naming is one-based ("dense_1", "dense_2"). 181 182 Returns: 183 Unique string name. 184 185 Example: 186 187 ```python 188 _unique_layer_name('dense') # dense_1 189 _unique_layer_name('dense') # dense_2 190 ``` 191 """ 192 if name_uid_map is None: 193 name_uid_map = get_default_graph_uid_map() 194 if avoid_names is None: 195 avoid_names = set() 196 proposed_name = None 197 while proposed_name is None or proposed_name in avoid_names: 198 name_key = (namespace, name) 199 if zero_based: 200 number = name_uid_map[name_key] 201 if number: 202 proposed_name = name + '_' + str(number) 203 else: 204 proposed_name = name 205 name_uid_map[name_key] += 1 206 else: 207 name_uid_map[name_key] += 1 208 proposed_name = name + '_' + str(name_uid_map[name_key]) 209 return proposed_name 210 211 212def collect_previous_mask(input_tensors): 213 """Retrieves the output mask(s) of the previous node. 214 215 Arguments: 216 input_tensors: An arbitrary structure of Tensors. 217 218 Returns: 219 A mask tensor or list of mask tensors. 220 """ 221 222 def _collect_previous_mask(x): 223 return getattr(x, '_keras_mask', None) 224 225 return nest.map_structure(_collect_previous_mask, input_tensors) 226 227 228def have_all_keras_metadata(tensors): 229 return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors)) 230 231 232def generate_placeholders_from_shape(shape): 233 return array_ops.placeholder(shape=shape, dtype=backend.floatx()) 234 235 236def create_keras_history(tensors): 237 """Wraps TensorFlow Operations for compatibility with the Functional API. 238 239 This method checks to see if a Tensor in `tensors` is missing Keras metadata 240 and has its origin in a Keras `Input` Layer. If so, this method will replace 241 the raw TensorFlow Operations that created this tensor with 242 `TensorFlowOpLayer` instances that create identical operations. 243 244 Any Tensors not originating from a Keras `Input` Layer will be treated as 245 constants when constructing `TensorFlowOpLayer` instances. 246 247 Arguments: 248 tensors: A structure of Tensors, some of which come from raw TensorFlow 249 operations and need to have Keras metadata assigned to them. 250 """ 251 _create_keras_history_helper(tensors, set()) 252 253 254def _create_keras_history_helper(tensors, processed_ops=None): 255 """Helper method for `create_keras_history`. 256 257 Arguments: 258 tensors: A structure of Tensors for which to create Keras metadata. 259 processed_ops: Set. TensorFlow operations that have already been wrapped 260 in `TensorFlowOpLayer` instances. 261 262 Returns: 263 The updated set of TensorFlow Operations that have been wrapped 264 in `TensorFlowOpLayer` instances. 265 """ 266 # Import of `base_layer` needed in order to create `TensorFlowOpLayer`. 267 # Cannot be imported at top because of circular dependencies. 268 # TODO(omalleyt): Resolve circular dependency. 269 from tensorflow.python.keras.engine import base_layer # pylint: disable=g-import-not-at-top 270 tensor_list = nest.flatten(tensors) 271 for tensor in tensor_list: 272 if getattr(tensor, '_keras_history', None) is not None: 273 continue 274 op = tensor.op # The Op that created this Tensor. 275 if op not in processed_ops: 276 # Recursively set `_keras_history`. 277 op_inputs = list(op.inputs) 278 constants = {} 279 layer_inputs = [] 280 for i, op_input in enumerate(op_inputs): 281 if uses_keras_history(op_input): 282 layer_inputs.append(op_input) 283 else: 284 # Treat any value not originating from a `keras.Input` as 285 # a constant (Variables currently have `Placeholder` op type 286 # when originating from an eager context 287 # so can't be supported. 288 constants[i] = backend.function([], op_input)([]) 289 processed_ops = _create_keras_history_helper(layer_inputs, processed_ops) 290 name = op.name 291 node_def = op.node_def.SerializeToString() 292 op_layer = base_layer.TensorFlowOpLayer( 293 node_def, constants=constants, name=name) 294 op_layer._add_inbound_node( # pylint: disable=protected-access 295 layer_inputs, op.outputs) 296 processed_ops.update([op]) 297 return processed_ops 298 299 300def needs_keras_history(tensors): 301 """Check if any Tensors need to be wrapped in TensorFlowOpLayers. 302 303 This will never return True inside a sublayer, because sublayers 304 do not need to create Keras History. Otherwise, this returns True 305 if one or more of `tensors` originates from a `keras.Input` and 306 does not have `_keras_history` set. 307 308 Arguments: 309 tensors: An arbitrary nested structure of Tensors. 310 311 Returns: 312 Bool, whether at least one Tensor needs to be wrapped. 313 """ 314 input_tensors = nest.flatten(tensors) 315 if is_in_call_context() or all( 316 getattr(tensor, '_keras_history', None) is not None 317 for tensor in input_tensors): 318 # KerasHistory already set. 319 return False 320 return uses_keras_history(tensors) 321 322 323def is_in_call_context(): 324 """Returns true if inside of a model/layer '__call__'.""" 325 return getattr(_call_context, 'in_call', False) 326 327 328def uses_keras_history(tensors): 329 """Check if at least one Tensor originates from a `keras.Input`. 330 331 This is `True` if at least one Tensor has its origin in a `keras.Input`. 332 Any Tensor that originates from a `keras.Input` will have a dependency 333 Tensor with a `_keras_history` attribute attached. Tensors that have 334 already been checked to not originate from a `keras.Input` 335 are marked as `_keras_history_checked`. 336 337 Arguments: 338 tensors: An arbitrary nested structure of Tensors. 339 340 Returns: 341 Bool, whether at least one Tensor originates from a `keras.Input`. 342 """ 343 checked_tensors = set() 344 tensors_to_check = nest.flatten(tensors) 345 346 while tensors_to_check: 347 new_tensors_to_check = set() 348 for tensor in tensors_to_check: 349 if getattr(tensor, '_keras_history_checked', None) is not None: 350 continue 351 if getattr(tensor, '_keras_history', None) is not None: 352 return True 353 354 try: 355 new_tensors_to_check.update(tensor.op.inputs) 356 except AttributeError: 357 # In case `tensor` is a Variable created in an Eager context. 358 pass 359 360 checked_tensors.update(tensors_to_check) 361 tensors_to_check = list(new_tensors_to_check - checked_tensors) 362 363 # Mark that these Tensors have been checked once for `_keras_history`, 364 # and should not be checked again for performance reasons. 365 mark_checked(tensors) 366 return False 367 368 369def mark_checked(tensors): 370 """Marks that these Tensors should not be tracked. 371 372 This prevents Layers from attempting to create TensorFlowOpLayers 373 for these Tensors. 374 375 Arguments: 376 tensors: An arbitrary structure of Tensors. 377 """ 378 379 def _mark_checked(tensor): 380 tensor._keras_history_checked = True # pylint: disable=protected-access 381 382 nest.map_structure(_mark_checked, tensors) 383 384 385@tf_contextlib.contextmanager 386def call_context(): 387 """Scope that marks when we are currently inside a Layer/Model's `call`.""" 388 was_in_call = is_in_call_context() 389 _call_context.in_call = True 390 try: 391 yield 392 finally: 393 _call_context.in_call = was_in_call 394 395 396def training_arg_passed_to_call(argspec, args, kwargs): 397 """Returns whether a user passed the `training` argument in `__call__`.""" 398 # `argspec.args` starts with ['self', 'inputs'] 399 full_args = dict(zip(argspec.args[2:], args)) 400 full_args.update(kwargs) 401 return 'training' in full_args 402 403 404class AutoAddUpdates(object): 405 """Automatically track stateful ops with `add_update`. 406 407 This context manager is used to automatically add stateful ops to a Layer 408 or Model's `.updates`. This ensures that stateful ops are run in the Keras 409 training loop. It also allows for these stateful ops to be disabled by 410 setting `trainable=False`. 411 412 Example: 413 414 ``` 415 with AutoAddUpdates(layer, inputs) as auto_updates: 416 outputs = layer.call(inputs) 417 auto_updates.set_outputs(outputs) 418 ``` 419 420 Attributes: 421 layer: Layer or Model instance to add the updates to. 422 inputs: The inputs to this Layer or Model, to be used for input-conditional 423 updates. 424 outputs: The outputs of this Layer or Model. 425 """ 426 427 def __init__(self, layer, inputs): 428 self.layer = layer 429 self.inputs = inputs 430 self.outputs = [] 431 432 def set_outputs(self, outputs): 433 if self.outputs: 434 raise RuntimeError('`set_outputs` should only be called once on an' 435 '`AutoAddUpdates` instance.') 436 self.outputs = outputs 437 438 def __enter__(self): 439 # Only run in V2 Function mode. 440 if (context.executing_eagerly() or 441 not ops.executing_eagerly_outside_functions()): 442 return self 443 444 self._graph = ops.get_default_graph() 445 self._num_operations = len(self._graph.get_operations()) 446 return self 447 448 def __exit__(self, error_type, unused_value, unused_traceback): 449 if error_type: 450 # Allow errors that occurred inside this context manager to pass through 451 # normally. 452 return 453 454 # Only run in V2 Function mode. 455 if (context.executing_eagerly() or 456 not ops.executing_eagerly_outside_functions()): 457 return 458 459 if (self._graph is not ops.get_default_graph() or 460 self._graph.name != 'keras_graph'): 461 # Only auto-track updates when the Keras Graph is the only one used. 462 return 463 464 new_operations = self._graph.get_operations()[self._num_operations:] 465 new_stateful_ops = set() 466 467 # pylint: disable=protected-access 468 for op in new_operations: 469 # While loop is not supported in general for automatic control 470 # dependencies. 471 if control_flow_util.IsInWhileLoop(op): 472 continue 473 474 # Track stateful ops via `add_update`. 475 is_stateful_op = ( 476 op.type not in self._graph._registered_ops or 477 auto_control_deps.op_is_stateful( 478 self._graph._registered_ops[op.type])) 479 480 # Ignore ReadVariableOps as they are not needed to be run separately. 481 # This ensures existing Layers don't get extra updates. 482 if is_stateful_op and op.type != 'ReadVariableOp': 483 new_stateful_ops.add(op) 484 485 explicit_updates = set([ 486 u for u in self.layer._get_unfiltered_updates(check_trainable=False) 487 if not isinstance(u, tuple) 488 ]) 489 # pylint: enable=protected-access 490 491 # Don't add updates that will already be run by virtue of being consumed by 492 # other stateful ops or by the Layer's outputs. This ensures that existing 493 # Layers like `BatchNormalization` continue to return the same values for 494 # `.update` calls. 495 minimum_ops = set() 496 targets = new_stateful_ops.union( 497 set(nest.flatten(self.outputs)), explicit_updates) 498 for op in new_stateful_ops: 499 # Scrub any ops that are consumed by the outputs or other stateful ops. 500 reachable = tf_utils.get_reachable_from_inputs(op) 501 if not (targets - {op}).intersection(reachable): 502 minimum_ops.add(op) 503 new_stateful_ops = minimum_ops 504 505 # Don't double-track updates added via explicitly calling `add_update`. 506 # Also don't double-track updates already tracked in sublayers. 507 new_stateful_ops = new_stateful_ops - explicit_updates 508 509 # Decide whether to track as input-conditional or unconditional. 510 input_reachable_ops = tf_utils.get_reachable_from_inputs( 511 self.inputs, targets=new_stateful_ops) 512 unconditional_updates = new_stateful_ops - input_reachable_ops 513 conditional_updates = new_stateful_ops - unconditional_updates 514 515 if unconditional_updates: 516 self.layer.add_update(list(unconditional_updates)) 517 if conditional_updates: 518 self.layer.add_update(list(conditional_updates), inputs=self.inputs) 519 520 521def _get_var_read_dtype(input_list, should_cast): 522 """Gets the dtype that AutoCastVariables should be read in.""" 523 if should_cast and input_list and input_list[0].dtype.is_floating: 524 return input_list[0].dtype.base_dtype 525 else: 526 return None 527 528 529def autocast_context_manager(input_list, should_cast): 530 """Returns a context manager to autocast AutoCastVariables. 531 532 Under this context manager, if `should_cast` is True, AutoCastVariables will 533 be casted. If `should_cast` is False, AutoCastVariables will not be casted, 534 which can be used to disable autocasting if nested under another 535 call to `autocast_context_manager`. 536 537 Args: 538 input_list: The inputs to the layer with the AutoCastVariables. 539 should_cast: Whether AutoCastVariables should be casted. 540 541 Returns: 542 A context manager to automatically cast AutoCastVariables. 543 """ 544 var_read_dtype = _get_var_read_dtype(input_list, should_cast) 545 return ops.get_default_graph()._enable_auto_casting_variables( # pylint: disable=protected-access 546 var_read_dtype) 547