# Lint as: python2, python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests to improve Keras integration with tf.function."""

from absl.testing import parameterized

import tensorflow as tf

from tensorflow.python.platform import test
from tensorflow.tools.consistency_integration_test.consistency_test_base import ConsistencyTestBase


class KerasIntegrationTest(ConsistencyTestBase, parameterized.TestCase):
  """Test cases for Keras integration with tf.function."""

  @parameterized.named_parameters([('_RunFunctionEagerly', True),
                                   ('_RunFunctionNonEagerly', False)])
  def testVariableCreationKerasLayers(self, run_eagerly):
    """Tests tf.function variable creation in Keras layers.

    Bugs: b/184210116
    Status: Known issue
      (However, moving forward, we should support re-creating
      `tf.Variables` inside tf.function for each trace. This test case
      should pass eventually.)
    Issue: `tf.Variable` creation in Keras layers causes 'non-first call
      variable creation' error in a tf.function.

    Error message:
      "tf.function-decorated function tried to create variables on non-first
      call."

    Improve error message? Needed. (b/187847612)

    Notes:
    * Inconsistent behavior between eager and non-eager mode execution of the
      tf.function.
    * In non-eager mode (graph mode), double tracing (i.e. first one during
      function tracing and second one in execution) causes variable creation in
      non-first call error.
    * This is an expected behavior as Keras's Dense layer creates variables.
    * go/tf-mutable-refs is a work-in-progress, longer term project designed to
      address this issue.

    Args:
      run_eagerly: Boolean deciding whether to run tf.function decorated
        functions eagerly or not.
    """
    # Known issue; skipped until b/184210116 is resolved.
    self.skipTest('b/184210116')

    try:
      # Save the global eager-execution setting so it can be restored in the
      # `finally` clause regardless of how the test body exits.
      original_setting = tf.config.functions_run_eagerly()
      tf.config.run_functions_eagerly(run_eagerly)

      @tf.function
      def f(x):
        # Creating the Dense layer inside the tf.function body means a fresh
        # set of variables is created on every trace — the problematic pattern
        # under test.
        layer = tf.keras.layers.Dense(2)(x)
        return layer

      if run_eagerly:
        # NOTE(review): the expected constants presumably correspond to a fixed
        # random seed / initializer state — confirm before un-skipping.
        self.assertAllEqual(
            tf.constant([[0.7891873, -0.5761101], [1.7832438, -1.6489036]],
                        dtype=tf.float32), f(tf.constant([[1., 2.], [3., 4.]])))
      else:
        f(tf.constant([[1., 2.], [3., 4.]]))

    finally:
      tf.config.run_functions_eagerly(original_setting)

  def testVariableCreationKerasLayersRecommended(self):
    """Tests the recommended way of creating Keras layers in tf.function.

    Bugs: b/184210116
    Status: Working as intended
    Issue: n/a

    Error message: n/a

    Notes:
    * The suggested way of going about the problematic case written in
      `testVariableCreationKerasLayers` test case.
    """
    layer = None

    @tf.function
    def f(x):
      # Guard layer construction so variables are only created on the first
      # call; subsequent traces reuse the same layer (and its variables).
      nonlocal layer
      if layer is None:
        layer = tf.keras.layers.Dense(2)
      return layer(x)

    self.assertAllEqual(
        f(tf.constant([[1., 2.], [3., 4.]])),
        tf.constant([[0.7891873, -0.5761101], [1.7832438, -1.6489036]]))

  @parameterized.named_parameters([('_RunFunctionEagerly', True),
                                   ('_RunFunctionNonEagerly', False)])
  def testRetracingKerasOptimAsPythonObj(self, run_eagerly):
    """Tests tf.function variable creation in Keras optimizers.

    Bugs: b/184210116
    Status: Working as intended
      (However, moving forward, we should support re-creating
      `tf.Variables` inside tf.function for each trace. This test case
      should pass eventually.)
    Issue: Passing in different Keras optimizers (Python objects) to
      tf.function is not allowed as they create
      `tf.Variable`s and will result in 'non-first call variable creation'
      error.

    Error message:
      "tf.function-decorated function tried to create variables on non-first
      call."

    Notes:
    * Inconsistent behavior between eager and non-eager mode execution of the
      tf.function.
    * go/tf-mutable-refs is a work-in-progress, longer term project designed to
      address this issue.
    * `trace` has three '#training' strings (before erroring out in the last
      `train_one_step()` call) when two is generally expected. Why?
      Answer: First `train_one_step` call traces twice because, after the first
      trace, tf.function detects `tf.Variable` creation and immediately traces a
      second time to see whether new variables are being created. (This has
      been a common source of confusion.)

    Args:
      run_eagerly: Boolean deciding whether to run tf.function decorated
        functions eagerly or not.
    """
    # Known issue; skipped until b/184210116 is resolved.
    self.skipTest('b/184210116')

    try:
      # Save the global eager-execution setting so it can be restored in the
      # `finally` clause regardless of how the test body exits.
      original_setting = tf.config.functions_run_eagerly()
      tf.config.run_functions_eagerly(run_eagerly)
      # Records one marker string per trace so the test can count retraces.
      trace = []

      @tf.function
      def train_one_step(a, x, y, optim):
        nonlocal trace
        trace.append('#tracing')
        with tf.GradientTape() as tape:
          l = tf.reduce_sum(tf.square(a * x - y))

        w = [a]
        g = tape.gradient(l, w)
        # The optimizer creates slot variables on first use — the source of
        # the 'non-first call' variable-creation error when a *different*
        # optimizer object is passed in later.
        optim.apply_gradients(zip(g, w))
        return a

      optim0 = tf.keras.optimizers.Adam()
      optim1 = tf.keras.optimizers.Adam()
      a = tf.Variable(2.)
      x = tf.Variable([-1., -1., -1.])
      y = tf.Variable([2., 2., 2.])

      tf.config.run_functions_eagerly(run_eagerly)
      train_one_step(a, x, y, optim0)  # tracing

      if run_eagerly:
        train_one_step(a, x, y, optim1)
      else:
        self.assertLen(trace, 2)  # traces two times; see "Notes" in the
        # test case docstring for more info.
        train_one_step(a, x, y, optim1)

    finally:
      tf.config.run_functions_eagerly(original_setting)

  def testCachedTensorKerasLayers(self):
    """Tests tf.function I/O behavior with cached tensors in Keras layers.

    Bugs: b/149094965
    Status: Working as intended
    Issue: When there exists a trace that has cached tensors, retracing the
      function (upon receiving new input signature) will result in an
      error as the cached tensor is from the previous trace.

    Error message:
      "The tensor 'Tensor("Placeholder:0", shape=(None, 1), dtype=float32)'
      cannot be accessed here: it is defined in another function or code block."

    Notes:
    * `self._cached_value` is already a cached tensor when the program tries to
      retrace upon calling `model.fit()`.
    * This test is equivalent to `testCachedTensor` test case but just with
      Keras layers.
    * Calling custom Keras layer initially with
      `pred_out = layer(tf.constant(1.0))` as input should cache
      `self._cached_value` as tensor, leading to an error upon calling
      `model.fit()` with a different input signature. However, commenting out
      the first step does not have any effect. Why? Left a TODO.
    """
    # Known issue; skipped until b/149094965 is resolved.
    self.skipTest('b/149094965')

    class Context(object):
      """Context class for demonstrating the issue."""

      def __init__(self):
        # Holds the tensor from the previous call; leaking a traced tensor
        # across traces is the failure mode under test.
        self._cached_value = None

      def f(self, x):
        result = x + 1
        if self._cached_value is not None:
          result += self._cached_value

        self._cached_value = x
        return result

    class CustomLayer(tf.keras.layers.Layer):
      """Keras layer that delegates its computation to a stateful Context."""

      def __init__(self, context, **kwargs):
        self.context = context
        super(CustomLayer, self).__init__(**kwargs)

      def call(self, x, training=None):
        return self.context.f(x)

    ctx = Context()
    layer = CustomLayer(ctx)
    # TODO(hyey): Investigate why the line below doesn't have any effect.
    # Commenting out the line below (tensor caching step) still works. That
    # probably means that tensors are being cached somewhere else?
    pred_out = layer(tf.constant(1.0))  # pylint:disable=unused-variable
    model = tf.keras.models.Sequential([layer])
    model.compile('sgd', 'mean_squared_error')
    model.fit(tf.constant([1., 2., 3.]), tf.constant([1., 2., 3.]))


if __name__ == '__main__':
  test.main()