# Lint as: python2, python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests to improve the consistency of tf.function I/O."""

from absl.testing import parameterized

import tensorflow as tf

from tensorflow.python.platform import test
from tensorflow.tools.consistency_integration_test.consistency_test_base import ConsistencyTestBase


class TfFunctionIOConsistencyTests(ConsistencyTestBase, parameterized.TestCase):
  """Test cases for known issues or bugs related to tf.function I/O."""

  def testDynamicIndirectVariableCreation(self):
    """Tests tf.function that tries to re-create `tf.Variable`s.

    Bugs:   b/147231209
    Status: Known issue
            (In the short term, we should allow `tf.Variable`s to be lifted
            out of each trace, rather than only one per `tf.function`. In the
            long term, we could allow `tf.Variable`s to be created
            arbitrarily (go/tf-mutable-refs).)
    Issue:  Re-creating `tf.Variable`s inside tf.function is not allowed, and
            the error message thrown is ambiguous (i.e. it is missing
            information about which variable is causing the failure and where
            it happened).

    Error message:
      "Creating variables on a non-first call to a function decorated with
      tf.function."

    Improve error message? Needed. (b/187847612)

    Notes:
    * If `tf.Variable` creation is detected in the initial trace, tf.function
      will retrace the function. For example:
      ```
      class Foo:

        def __init__(self):
          self.var = None

        @tf.function
        def __call__(self, x):
          print("#tracing")
          if self.var is None:
            self.var = tf.Variable(x)
          return self.var

      foo = Foo()
      foo(True)   # traced twice instead of once; tracing + variable lifting.
                  # '#tracing' prints 2 times.
      foo(True)   # not traced; '#tracing' doesn't get printed.
      foo(False)  # retraced once; '#tracing' prints once since `self.var`
                  # is not None.
      ```
      If `tf.Variable` creation is detected in a different trace for the same
      tf.function, it will fail during the retrace's variable lifting stage.
      (This is a simpler example of the test case below.)
      ```
      class Baz:

        def __init__(self):
          self.cnt = 0

        @tf.function
        def __call__(self, x):
          print("#tracing")
          if self.cnt == 0:
            self._var = tf.Variable(x)
          elif self.cnt > 1:
            self._var = tf.Variable(x)
          self.cnt += 1

      baz = Baz()
      baz(True)   # traced twice instead of once; tracing + variable lifting.
                  # '#tracing' prints 2 times.
      baz(True)   # not traced; no `tf.Variable` creation when
                  # `self.cnt == 1`, so '#tracing' doesn't get printed.
      baz(False)  # retraced twice; retracing + variable lifting.
                  # '#tracing' prints once since it fails at the variable
                  # lifting stage.
      ```
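    * One known workaround (a sketch using a hypothetical `Qux` class, not
      part of the original bug report) is to create the `tf.Variable`s
      eagerly, outside the traced function, so each trace only reads
      existing variables:
      ```
      class Qux:

        def __init__(self):
          self._flag_keyed_vars = {}

        def __call__(self, flag):
          # Variable creation happens eagerly, before any tracing.
          if flag not in self._flag_keyed_vars:
            self._flag_keyed_vars[flag] = tf.Variable(1.0)
          return self.compute(flag)

        @tf.function
        def compute(self, flag):
          # Each trace only reads an already-existing variable, so no
          # variable lifting is needed on retraces.
          return self._flag_keyed_vars[flag] + 1.0

      qux = Qux()
      qux(True)   # traced once; the variable already exists.
      qux(False)  # retraced (new Python value); no error, since the second
                  # variable was also created eagerly.
      ```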
    * The issue is prevalent when working with `tf.metrics.Mean` inside a
      tf.function (b/187445546):
      ```
      class Foo:

        def __init__(self):
          self._metrics = collections.defaultdict(tf.metrics.Mean)

        def __call__(self, is_training):
          self.compute(is_training)

        @tf.function
        def compute(self, is_training):
          if is_training:
            self._metrics['test'].update_state([1., 2.])

      foo = Foo()

      # Calling `foo` here with `False` triggers tracing; re-triggering the
      # tracing with `True` causes the error.
      foo(False)  # tracing
      foo(True)   # error
      ```
    * Improve the error message. It should mention the variable name and
      which function tried to re-create the `tf.Variable`s.
    * go/tf-mutable-refs is a work-in-progress, longer term project designed
      to address this issue.
    """
    self.skipTest('b/147231209')

    class Foo:
      """Foo class for demonstrating the issue."""

      def __init__(self):
        self._flag_keyed_vars = {}

      def __call__(self, var_creation_flag):
        self.compute(var_creation_flag)

      @tf.function
      def compute(self, var_creation_flag):
        if var_creation_flag not in self._flag_keyed_vars:
          self._flag_keyed_vars[var_creation_flag] = tf.Variable(1.0)

    foo = Foo()
    foo(True)   # traced twice, with variable lifting.
    foo(True)   # not traced; reuses variables from the first trace.
    foo(False)  # retraced twice; variable lifting raises an error. But we
                # don't need to raise here; we could just lift like in the
                # first trace.

  @parameterized.named_parameters([('_RunFunctionEagerly', True),
                                   ('_RunFunctionNonEagerly', False)])
  def testVariableCreationCustomModule(self, run_eagerly):
    """Tests tf.function variable creation with custom objects (`tf.Module`).

    Bugs:   b/184210116
    Status: Working as intended
            (However, moving forward, we should support re-creating
            `tf.Variable`s inside tf.function for each trace. This test case
            should pass eventually.)
    Issue:  `tf.Variable` creation in a custom module causes a 'non-first
            call variable creation' error in a tf.function.

    Error message:
      "tf.function-decorated function tried to create variables on non-first
      call."

    Notes:
    * This is a simplified version of the `testVariableCreationKerasLayers`
      test in
      //tensorflow/tools/consistency_integration_test/keras_integration_tests.py
      without involving Keras.
    * The tf.function behaves inconsistently between eager and non-eager
      mode execution.
    * In non-eager (graph) mode, double tracing (i.e. the first during
      function tracing and the second at execution) causes the 'variable
      creation on non-first call' error.
    * go/tf-mutable-refs is a work-in-progress, longer term project designed
      to address this issue.
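    * A common workaround (a sketch, not part of the original test, assuming
      the same `Dense` module defined in the test body below) is to
      construct the module once, outside the tf.function, so its variables
      are created eagerly and each trace only reads them:
      ```
      dense = Dense(3, 3)  # variables are created eagerly, exactly once

      @tf.function
      def f(x):
        return dense(x)  # the trace only reads `dense.w` and `dense.b`

      f(tf.constant([[1., 2., 3.]]))  # no variable-creation error
      ```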
178 """ 179 self.skipTest('b/184210116') 180 181 try: 182 original_setting = tf.config.functions_run_eagerly() 183 tf.config.run_functions_eagerly(run_eagerly) 184 185 class Dense(tf.Module): 186 """Custom Dense class for demonstration.""" 187 188 def __init__(self, in_features, out_features): 189 super().__init__() 190 self.w = tf.Variable(tf.random.normal([in_features, out_features])) 191 self.b = tf.Variable(tf.zeros([out_features])) 192 193 def __call__(self, x): 194 y = tf.matmul(x, self.w) + self.b 195 return tf.nn.relu(y) 196 197 @tf.function 198 def f(x): 199 layer = Dense(3, 3)(x) 200 return layer 201 202 in_val = tf.constant([[1., 2., 3]]) 203 204 if run_eagerly: 205 self.assertAllEqual( 206 tf.constant([[0., 2.037801, 0.]], dtype=tf.float32), f(in_val)) 207 else: 208 f(in_val) 209 210 finally: 211 tf.config.run_functions_eagerly(original_setting) 212 213 def testRetraceOnObjectPropertyChange(self): 214 """Tests retracing behavior of tf.function when object property has changed. 215 216 Bugs: b/162221622 217 Status: Broken 218 (When the property of an object has changed, tf.function should 219 detect the update and retrace.) 220 Issue: Changing the property of an object does not trigger retracing and 221 outputs wrong results. 222 223 Error message: 224 There isn't an error message thrown out; things work but wrongly because 225 the correct conditional branch didn't get traced initially and because 226 retracing doesn't take place. 227 """ 228 self.skipTest('b/162221622') 229 trace = [] 230 231 class Foo: 232 """Foo class for demonstration.""" 233 234 def __init__(self): 235 self.condition = True 236 self.n = 1.0 237 238 @tf.function 239 def f(self, x): 240 """Function `f` for demonstration.""" 241 nonlocal trace 242 trace.append('#tracing') 243 244 if not self.condition: 245 trace.append('#retracing') 246 self.n = x 247 248 return self.n 249 250 foo = Foo() 251 a = 2.0 252 253 out0 = foo.f(a) 254 self.assertEqual(out0, tf.constant(1.)) 255 self.assertEqual(trace, ['#tracing']) 256 257 trace = [] 258 foo.condition = False 259 260 out1 = foo.f(a) 261 # `out1` is 1.0 and `trace` is `[]` because tf.function did not retrace 262 # despite that `foo`'s property has changed. 263 self.assertEqual(out1, tf.constant(2.)) 264 self.assertEqual(trace, ['#tracing', '#retracing']) 265 266 def testRetraceOnObjectPropertyChangeOneWorkaround(self): 267 """Tests a possible workaround for handling changes in object property. 268 269 Bugs: b/162221622 270 Status: Broken 271 (The workaround demonstrated in this test case, however, works. 272 The eventual goal though should be to improve the behavior by 273 allowing retracing upon object property changes.) 274 Issue: n/a 275 276 Error message: n/a 277 278 Notes: 279 * This is a workaround for issue demonstrated in 280 `testRetraceOnObjectPropertyChange` test case. We are explicitly 281 passing in the conditional variable in order to trigger retracing. 
282 """ 283 trace = [] 284 285 class Foo: 286 """Foo class for demonstration.""" 287 288 def __init__(self): 289 self.condition = True 290 self.n = 1.0 291 self.var = None 292 293 @tf.function 294 def f(self, x, condition): 295 """Function `f` for demonstration.""" 296 nonlocal trace 297 trace.append('#tracing') 298 299 self.condition = condition 300 301 if self.var is None: 302 self.var = tf.Variable(x) 303 304 if not self.condition: 305 trace.append('#retracing') 306 self.n = 5.0 307 308 return self.var.assign_add(self.n) 309 310 foo = Foo() 311 a = 2.0 312 313 out0 = foo.f(a, True) 314 self.assertEqual(out0, tf.constant(3.)) 315 self.assertEqual(trace, ['#tracing', '#tracing']) 316 317 trace = [] 318 319 out1 = foo.f(a, False) 320 self.assertEqual(out1, tf.constant(8.)) 321 self.assertEqual(trace, ['#tracing', '#retracing']) 322 323 def testDataResourcesIO(self): 324 """Tests returning iterators from tf.function. 325 326 Bugs: b/170436338, b/170497789 (feature request) 327 Status: Broken 328 Issue: Unable to return iterators from tf.function. 329 330 Error message: 331 "InvalidArgumentError: 6 nodes in a cycle [Op:__inference_f_11]" 332 333 Improve error message? Needed. (b/187850865) 334 335 Notes: 336 * Current error message is not helpful; we need to improve it to explain 337 what is causing the error where and suggest the known workaround. 338 * One workaround is to keep the iterator as a global variable: 339 ``` 340 its = [] 341 342 class Model(tf.Module): 343 344 @tf.function 345 def train(self): 346 global its 347 it = iter(tf.data.Dataset.from_tensors([0.0]).repeat()) 348 its.append(it) 349 return it 350 351 model = Model() 352 model.train() 353 ``` 354 * Another workaround is to create it upon `Model` initialization. 355 ``` 356 class Model(tf.Module): 357 358 def __init__(self): 359 self.traced = False 360 self.dataset = tf.data.Dataset.from_tensor_slices([1., 2.]) 361 self.iterator = iter(self.dataset) 362 363 def create_variables(self): 364 self.w = tf.Variable(0.0) 365 366 @tf.function 367 def train(self): 368 if not self.traced: 369 self.traced = True 370 self.create_variables() 371 return next(self.iterator) 372 373 model = Model() 374 model.train() 375 ``` 376 """ 377 self.skipTest('b/170436338') 378 379 class Model(tf.Module): 380 """Model class for demonstrating the issue.""" 381 382 @tf.function 383 def f(self): 384 dataset = iter(tf.data.Dataset.from_tensors([0.0]).repeat()) 385 iterator = iter(dataset) 386 return iterator 387 388 m = Model() 389 it0 = m.f() 390 it1 = iter(tf.data.Dataset.from_tensors([0.0]).repeat()) 391 self.assertEqual(type(it0), type(it1)) 392 393 def testCachedTensor(self): 394 """Tests tf.function behavior with cached tensors (side I/O). 395 396 Bugs: b/149094965 397 Status: Working as intended 398 Issue: When there exists a trace that has cached tensors, retracing the 399 function (upon receiving new input signature) will result in an 400 error as the cached tensor is from the previous trace. 401 402 Error message: 403 "An op outside of the function building code is being passed a "Graph" 404 tensor." 405 406 Improve error message? Needed. (b/187850615) 407 408 Notes: 409 * `self._cached_value` is already a cached tensor when the program tries to 410 retrace upon receiving `tf.constant([1, 2])` as input. 411 * Error message mentions about "Graph" tensor being passed in. Is this the 412 most informative message? Left a TODO. 
413 """ 414 self.skipTest('b/149094965') 415 416 class Context(object): 417 """Context class for demonstrating the issue.""" 418 419 def __init__(self): 420 self._cached_value = None 421 422 def f(self, x): 423 result = x + 1 424 if self._cached_value is not None: 425 result += self._cached_value 426 427 self._cached_value = x 428 return result 429 430 @tf.function 431 def some_func(ctx, x): 432 return ctx.f(x + 1) 433 434 ctx = Context() 435 some_func(ctx, tf.constant(1)) 436 some_func(ctx, tf.constant(2)) 437 self.assertAllEqual( 438 some_func(ctx, tf.constant([1, 2])), tf.constant([6, 7])) 439 440 441if __name__ == '__main__': 442 test.main() 443