# Lint as: python2, python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests to improve the consistency of tf.function I/O."""

from absl.testing import parameterized

import tensorflow as tf

from tensorflow.python.platform import test
from tensorflow.tools.consistency_integration_test.consistency_test_base import ConsistencyTestBase

25
class TfFunctionIOConsistencyTests(ConsistencyTestBase, parameterized.TestCase):
  """Test cases for known issues or bugs related to tf.function I/O."""

  def testDynamicIndirectVariableCreation(self):
    """Tests tf.function that tries to re-create `tf.Variable`s.

    Bugs:   b/147231209
    Status: Known issue
            (In the short term, we should allow `tf.Variable`s to be lifted out
            of each trace, rather than only one per `tf.function`.
            In the long term, we could allow `tf.Variable`s to be created
            arbitrarily (go/tf-mutable-refs).)
    Issue:  Re-creating `tf.Variables` inside tf.function is not allowed and
            the error message thrown is ambiguous (i.e. missing information
            about which variable is causing the failure and where it happened).

    Error message:
      "Creating variables on a non-first call to a function decorated with
      tf.function."

    Improve error message? Needed. (b/187847612)

    Notes:
    * If `tf.Variable` creation is detected in the initial trace, tf.function
      will retrace the function. For example:
      ```
      class Foo:
        def __init__(self):
          self.var = None

        @tf.function
        def __call__(self, x):
          print("#tracing")
          if self.var is None:
            self.var = tf.Variable(x)
          return self.var

      foo = Foo()
      foo(True)  # traced twice instead of once; tracing + variable lifting
                 # '#tracing' prints 2 times.
      foo(True)  # not traced; `#tracing` doesn't get printed.
      foo(False)  # retraced once; '#tracing' prints once since `self.var` is
                  # not None
      ```
      If `tf.Variable` creation is detected in a different trace for the same
      tf.function, it will fail during the retrace's variable lifting stage.
      (This is a simpler example of the test case.)
      ```
      class Baz:
        def __init__(self):
          self.cnt = 0

        @tf.function
        def __call__(self, x):
          print("#tracing")
          if self.cnt == 0:
            self._var = tf.Variable(x)
          elif self.cnt > 1:
            self._var = tf.Variable(x)
          self.cnt += 1

      baz = Baz()
      baz(True)  # traced twice instead of once; tracing + variable lifting
                 # '#tracing' prints 2 times.
      baz(True)  # not traced; no `tf.Variable` creation when `self.cnt == 1`.
                 # `#tracing` doesn't get printed.
      baz(False)  # retraced twice; retracing + variable lifting
                  # '#tracing' prints once since it fails at variable lifting
                  # stage.
      ```
    * The issue is prevalent when working with `tf.metrics.Mean` inside a
      tf.function (b/187445546):
      ```
      class Foo:

        def __init__(self):
          self._metrics = collections.defaultdict(tf.metrics.Mean)

        def __call__(self, is_training):
          self.compute(is_training)

        @tf.function
        def compute(self, is_training):
          if is_training:
            self._metrics['test'].update_state([1., 2.])

      foo = Foo()

      # Calling `foo` here with `False` will trigger tracing; retriggering
      # the tracing with `True` will cause the error.
      foo(False)  # tracing
      foo(True)  # error
      ```
    * Improve error message. It should mention the variable name and which
      function tried to re-create `tf.Variable`s
    * go/tf-mutable-refs is a work-in-progress, longer term project designed to
      address this issue.
    """
    self.skipTest('b/147231209')

    class Foo:
      """Foo class for demonstrating the issue."""

      def __init__(self):
        # Maps a boolean flag to the `tf.Variable` created under that flag;
        # a second trace with an unseen key triggers variable re-creation.
        self._flag_keyed_vars = {}

      def __call__(self, var_creation_flag):
        self.compute(var_creation_flag)

      @tf.function
      def compute(self, var_creation_flag):
        if var_creation_flag not in self._flag_keyed_vars:
          self._flag_keyed_vars[var_creation_flag] = tf.Variable(1.0)

    foo = Foo()
    foo(True)  # traced twice, with variable lifting
    foo(True)  # not traced, reuses variables from first trace
    foo(False)  # re-traced twice, variable lifting raises error; but we don't
                # need to raise, we can just lift like in the first trace.

  @parameterized.named_parameters([('_RunFunctionEagerly', True),
                                   ('_RunFunctionNonEagerly', False)])
  def testVariableCreationCustomModule(self, run_eagerly):
    """Tests tf.function variable creation with custom objects (`tf.Module`).

    Bugs:   b/184210116
    Status: Working as intended
            (However, moving forward, we should support re-creating
            `tf.Variables` inside tf.function for each trace. This test case
            should pass eventually.)
    Issue:  `tf.Variable` creation in a custom module causes 'non-first call
            variable creation' error in a tf.function.

    Error message:
      "tf.function-decorated function tried to create variables on non-first
      call."

    Notes:
    * This is a simplified version of `testVariableCreationKerasLayers` test in
      //tensorflow/tools/consistency_integration_test/keras_integration_tests.py
      without involving Keras.
    * Inconsistent behavior between eager and non-eager mode execution of the
      tf.function.
    * In non-eager mode (graph mode), double tracing (i.e. first one during
      function tracing and second one in execution) causes variable creation in
      non-first call error.
    * go/tf-mutable-refs is a work-in-progress, longer term project designed to
      address this issue.

    Args:
      run_eagerly: Boolean deciding whether to run tf.function decorated
        functions eagerly or not.
    """
    self.skipTest('b/184210116')

    # Read the current setting *before* entering the `try` block so that the
    # `finally` clause never references an unbound name if the read fails.
    original_setting = tf.config.functions_run_eagerly()
    try:
      tf.config.run_functions_eagerly(run_eagerly)

      class Dense(tf.Module):
        """Custom Dense class for demonstration."""

        def __init__(self, in_features, out_features):
          super().__init__()
          self.w = tf.Variable(tf.random.normal([in_features, out_features]))
          self.b = tf.Variable(tf.zeros([out_features]))

        def __call__(self, x):
          y = tf.matmul(x, self.w) + self.b
          return tf.nn.relu(y)

      @tf.function
      def f(x):
        # Creating the `Dense` module (and its variables) inside the traced
        # function is what triggers the non-first-call creation error.
        layer = Dense(3, 3)(x)
        return layer

      in_val = tf.constant([[1., 2., 3]])

      if run_eagerly:
        self.assertAllEqual(
            tf.constant([[0., 2.037801, 0.]], dtype=tf.float32), f(in_val))
      else:
        f(in_val)

    finally:
      # Restore the global eager-execution setting for subsequent tests.
      tf.config.run_functions_eagerly(original_setting)

  def testRetraceOnObjectPropertyChange(self):
    """Tests retracing behavior of tf.function when object property has changed.

    Bugs:   b/162221622
    Status: Broken
            (When the property of an object has changed, tf.function should
            detect the update and retrace.)
    Issue:  Changing the property of an object does not trigger retracing and
            outputs wrong results.

    Error message:
      There isn't an error message thrown out; things work but wrongly because
      the correct conditional branch didn't get traced initially and because
      retracing doesn't take place.
    """
    self.skipTest('b/162221622')
    trace = []

    class Foo:
      """Foo class for demonstration."""

      def __init__(self):
        self.condition = True
        self.n = 1.0

      @tf.function
      def f(self, x):
        """Function `f` for demonstration."""
        nonlocal trace
        trace.append('#tracing')

        if not self.condition:
          trace.append('#retracing')
          self.n = x

        return self.n

    foo = Foo()
    a = 2.0

    out0 = foo.f(a)
    self.assertEqual(out0, tf.constant(1.))
    self.assertEqual(trace, ['#tracing'])

    trace = []
    foo.condition = False

    out1 = foo.f(a)
    # `out1` is 1.0 and `trace` is `[]` because tf.function did not retrace
    # despite that `foo`'s property has changed.
    self.assertEqual(out1, tf.constant(2.))
    self.assertEqual(trace, ['#tracing', '#retracing'])

  def testRetraceOnObjectPropertyChangeOneWorkaround(self):
    """Tests a possible workaround for handling changes in object property.

    Bugs:   b/162221622
    Status: Broken
            (The workaround demonstrated in this test case, however, works.
            The eventual goal though should be to improve the behavior by
            allowing retracing upon object property changes.)
    Issue:  n/a

    Error message: n/a

    Notes:
    * This is a workaround for issue demonstrated in
      `testRetraceOnObjectPropertyChange` test case. We are explicitly
      passing in the conditional variable in order to trigger retracing.
    """
    trace = []

    class Foo:
      """Foo class for demonstration."""

      def __init__(self):
        self.condition = True
        self.n = 1.0
        self.var = None

      @tf.function
      def f(self, x, condition):
        """Function `f` for demonstration."""
        nonlocal trace
        trace.append('#tracing')

        # `condition` is an explicit argument, so a change in its (Python)
        # value is part of the input signature and forces a retrace.
        self.condition = condition

        if self.var is None:
          self.var = tf.Variable(x)

        if not self.condition:
          trace.append('#retracing')
          self.n = 5.0

        return self.var.assign_add(self.n)

    foo = Foo()
    a = 2.0

    out0 = foo.f(a, True)
    self.assertEqual(out0, tf.constant(3.))
    # Two '#tracing' entries: initial trace plus the variable-lifting retrace.
    self.assertEqual(trace, ['#tracing', '#tracing'])

    trace = []

    out1 = foo.f(a, False)
    self.assertEqual(out1, tf.constant(8.))
    self.assertEqual(trace, ['#tracing', '#retracing'])

  def testDataResourcesIO(self):
    """Tests returning iterators from tf.function.

    Bugs:   b/170436338, b/170497789 (feature request)
    Status: Broken
    Issue:  Unable to return iterators from tf.function.

    Error message:
      "InvalidArgumentError: 6 nodes in a cycle [Op:__inference_f_11]"

    Improve error message? Needed. (b/187850865)

    Notes:
    * Current error message is not helpful; we need to improve it to explain
      what is causing the error where and suggest the known workaround.
    * One workaround is to keep the iterator as a global variable:
        ```
        its = []

        class Model(tf.Module):

          @tf.function
          def train(self):
            global its
            it = iter(tf.data.Dataset.from_tensors([0.0]).repeat())
            its.append(it)
            return it

        model = Model()
        model.train()
        ```
    * Another workaround is to create it upon `Model` initialization.
        ```
        class Model(tf.Module):

          def __init__(self):
            self.traced = False
            self.dataset = tf.data.Dataset.from_tensor_slices([1., 2.])
            self.iterator = iter(self.dataset)

          def create_variables(self):
            self.w = tf.Variable(0.0)

          @tf.function
          def train(self):
            if not self.traced:
              self.traced = True
              self.create_variables()
            return next(self.iterator)

        model = Model()
        model.train()
        ```
    """
    self.skipTest('b/170436338')

    class Model(tf.Module):
      """Model class for demonstrating the issue."""

      @tf.function
      def f(self):
        # Build the dataset, then take a single iterator from it. (Previously
        # the dataset variable already held an iterator, so `iterator` was a
        # redundant iter-of-an-iterator.)
        dataset = tf.data.Dataset.from_tensors([0.0]).repeat()
        iterator = iter(dataset)
        return iterator

    m = Model()
    it0 = m.f()
    it1 = iter(tf.data.Dataset.from_tensors([0.0]).repeat())
    self.assertEqual(type(it0), type(it1))

  def testCachedTensor(self):
    """Tests tf.function behavior with cached tensors (side I/O).

    Bugs:   b/149094965
    Status: Working as intended
    Issue:  When there exists a trace that has cached tensors, retracing the
            function (upon receiving new input signature) will result in an
            error as the cached tensor is from the previous trace.

    Error message:
      "An op outside of the function building code is being passed a "Graph"
      tensor."

    Improve error message? Needed. (b/187850615)

    Notes:
    * `self._cached_value` is already a cached tensor when the program tries to
      retrace upon receiving `tf.constant([1, 2])` as input.
    * Error message mentions about "Graph" tensor being passed in. Is this the
      most informative message? Left a TODO.
    """
    self.skipTest('b/149094965')

    class Context(object):
      """Context class for demonstrating the issue."""

      def __init__(self):
        self._cached_value = None

      def f(self, x):
        result = x + 1
        if self._cached_value is not None:
          # On every call after the first, a tensor captured from a *previous*
          # trace's graph leaks into the current computation.
          result += self._cached_value

        self._cached_value = x
        return result

    @tf.function
    def some_func(ctx, x):
      return ctx.f(x + 1)

    ctx = Context()
    some_func(ctx, tf.constant(1))
    some_func(ctx, tf.constant(2))
    self.assertAllEqual(
        some_func(ctx, tf.constant([1, 2])), tf.constant([6, 7]))
439
440
if __name__ == '__main__':
  # Run via TensorFlow's test runner so absl flags/parameterization work.
  test.main()