1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Utilities for Keras classes with v1 and v2 versions.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.eager import context 22from tensorflow.python.framework import ops 23from tensorflow.python.keras.utils.generic_utils import LazyLoader 24 25# TODO(b/134426265): Switch back to single-quotes once the issue 26# with copybara is fixed. 
# pylint: disable=g-inconsistent-quotes
# Keras engine/callback modules are bound through `LazyLoader`, so each
# module is only imported when one of its attributes is first accessed
# (presumably to avoid circular imports with this file -- confirm).
training = LazyLoader(
    "training", globals(),
    "tensorflow.python.keras.engine.training")
training_v1 = LazyLoader(
    "training_v1", globals(),
    "tensorflow.python.keras.engine.training_v1")
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
base_layer_v1 = LazyLoader(
    "base_layer_v1", globals(),
    "tensorflow.python.keras.engine.base_layer_v1")
callbacks = LazyLoader(
    "callbacks", globals(),
    "tensorflow.python.keras.callbacks")
callbacks_v1 = LazyLoader(
    "callbacks_v1", globals(),
    "tensorflow.python.keras.callbacks_v1")


# pylint: enable=g-inconsistent-quotes


class ModelVersionSelector(object):
  """Chooses between Keras v1 and v2 Model class.

  On every instantiation, `__new__` rewrites the class hierarchy of the
  class being constructed (via `swap_class`) so that it derives from either
  `training.Model` or `training_v1.Model`, depending on `should_use_v2()`.
  """

  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
    use_v2 = should_use_v2()
    cls = swap_class(cls, training.Model, training_v1.Model, use_v2)  # pylint: disable=self-cls-assignment
    # Constructor arguments are not forwarded to `object.__new__`; they reach
    # `__init__` through the normal construction protocol instead.
    return super(ModelVersionSelector, cls).__new__(cls)


class LayerVersionSelector(object):
  """Chooses between Keras v1 and v2 Layer class.

  On every instantiation, `__new__` rewrites the class hierarchy of the
  class being constructed (via `swap_class`) so that it derives from either
  `base_layer.Layer` or `base_layer_v1.Layer`, depending on
  `should_use_v2()`.
  """

  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
    use_v2 = should_use_v2()
    cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, use_v2)  # pylint: disable=self-cls-assignment
    return super(LayerVersionSelector, cls).__new__(cls)


class TensorBoardVersionSelector(object):
  """Chooses between Keras v1 and v2 TensorBoard callback class.

  If `callbacks_v1.TensorBoard` itself is constructed while v2 behavior is
  selected, a fully initialized `callbacks.TensorBoard` is returned instead.
  """

  def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
    use_v2 = should_use_v2()
    start_cls = cls
    cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard,
                     use_v2)
    if start_cls is callbacks_v1.TensorBoard and cls is callbacks.TensorBoard:
      # Since the v2 class is not a subclass of the v1 class, __init__ has to
      # be called manually.
      return cls(*args, **kwargs)
    return super(TensorBoardVersionSelector, cls).__new__(cls)


def should_use_v2():
  """Determine if v1 or v2 version should be used.

  Returns:
    True when executing eagerly. Also True when eager execution is enabled
    outside of the current function graph, unless the default graph's name
    starts with "wrapped_function" (a v1 `wrap_function` graph, whose
    contents are treated as v1 code). False otherwise.
  """
  if context.executing_eagerly():
    return True
  if not ops.executing_eagerly_outside_functions():
    return False
  # Check for a v1 `wrap_function` FuncGraph.
  # Code inside a `wrap_function` is treated like v1 code.
  graph = ops.get_default_graph()
  if (getattr(graph, "name", False) and
      graph.name.startswith("wrapped_function")):
    return False
  return True


def swap_class(cls, v2_cls, v1_cls, use_v2):
  """Swaps in v2_cls or v1_cls depending on `use_v2`.

  Args:
    cls: The class whose hierarchy should be adjusted.
    v2_cls: The v2 Keras base class.
    v1_cls: The v1 Keras base class.
    use_v2: If True, `v1_cls` is replaced by `v2_cls` among the ancestors of
      `cls`; otherwise `v2_cls` is replaced by `v1_cls`.

  Returns:
    `object` unchanged if `cls` is `object`; the selected class if `cls` is
    `v2_cls` or `v1_cls`; otherwise `cls` itself.

  Note that this mutates `__bases__` of `cls` and/or its ancestors in place.
  """
  if cls is object:
    return cls
  if cls in (v2_cls, v1_cls):
    return v2_cls if use_v2 else v1_cls

  # Recursively search superclasses to swap in the right Keras class.
  new_bases = []
  for base in cls.__bases__:
    if ((use_v2 and issubclass(base, v1_cls)) or
        # `v1_cls` often extends `v2_cls`, so it may still call `swap_class`
        # even if it doesn't need to. That being said, it may be the safest
        # not to over optimize this logic for the sake of correctness,
        # especially if we swap v1 & v2 classes that don't extend each other,
        # or when the inheritance order is different.
        (not use_v2 and issubclass(base, v2_cls))):
      new_base = swap_class(base, v2_cls, v1_cls, use_v2)
    else:
      new_base = base
    new_bases.append(new_base)
  new_bases = tuple(new_bases)
  # This runs on every instantiation of a selector subclass. Assigning
  # `__bases__` makes the interpreter recompute the MRO of `cls` and its
  # subclasses, so only do it when a base actually changed. Swaps done
  # deeper in the recursion already propagate to subclasses.
  if new_bases != cls.__bases__:
    cls.__bases__ = new_bases
  return cls


def disallow_legacy_graph(cls_name, method_name):
  """Raises an error when called outside of eager-compatible execution.

  Args:
    cls_name: Name of the class, used only to build the error message.
    method_name: Name of the method being called, used only in the message.

  Raises:
    ValueError: if `ops.executing_eagerly_outside_functions()` is False,
      i.e. we are in v1-style graph mode.
  """
  if not ops.executing_eagerly_outside_functions():
    error_msg = (
        "Calling `{cls_name}.{method_name}` in graph mode is not supported "
        "when the `{cls_name}` instance was constructed with eager mode "
        "enabled. Please construct your `{cls_name}` instance in graph mode or"
        " call `{cls_name}.{method_name}` with eager mode enabled.")
    error_msg = error_msg.format(cls_name=cls_name, method_name=method_name)
    raise ValueError(error_msg)


def is_v1_layer_or_model(obj):
  """Returns True if `obj` is a v1 Keras `Layer` or v1 Keras `Model`.

  Args:
    obj: Any object.

  Returns:
    Whether `obj` is an instance of `base_layer_v1.Layer` or
    `training_v1.Model`.
  """
  return isinstance(obj, (base_layer_v1.Layer, training_v1.Model))