# Lint as: python2, python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests to improve Keras integration with tf.function."""

from absl.testing import parameterized

import tensorflow as tf

from tensorflow.python.platform import test
from tensorflow.tools.consistency_integration_test.consistency_test_base import ConsistencyTestBase


class KerasIntegrationTest(ConsistencyTestBase, parameterized.TestCase):
  """Test cases for Keras integration with tf.function."""

  @parameterized.named_parameters([('_RunFunctionEagerly', True),
                                   ('_RunFunctionNonEagerly', False)])
  def testVariableCreationKerasLayers(self, run_eagerly):
    """Tests tf.function variable creation in Keras layers.

    Bugs:   b/184210116
    Status: Known issue
            (However, moving forward, we should support re-creating
            `tf.Variable`s inside a tf.function on each trace. This test case
            should pass eventually.)
    Issue:  `tf.Variable` creation in Keras layers causes a 'non-first call
            variable creation' error in a tf.function.

    Error message:
      "tf.function-decorated function tried to create variables on non-first
      call."

    The error message needs improvement (b/187847612).

    Notes:
    * Behavior is inconsistent between eager and non-eager mode execution of
      the tf.function.
    * In non-eager (graph) mode, the function is traced twice (first during
      function tracing and again at execution), which triggers the non-first
      call variable creation error.
    * This is expected behavior, as Keras's Dense layer creates its variables
      on the first call.
    * go/tf-mutable-refs is a work-in-progress, longer-term project designed
      to address this issue.

    Args:
      run_eagerly: Whether to run tf.function-decorated functions eagerly.
    """
    self.skipTest('b/184210116')

    original_setting = tf.config.functions_run_eagerly()
    try:
      tf.config.run_functions_eagerly(run_eagerly)

      @tf.function
      def f(x):
        layer = tf.keras.layers.Dense(2)(x)
        return layer

      if run_eagerly:
        # The Dense layer is randomly initialized, so check the output shape
        # rather than hard-coded values.
        out = f(tf.constant([[1., 2.], [3., 4.]]))
        self.assertEqual(out.shape, [2, 2])
      else:
        f(tf.constant([[1., 2.], [3., 4.]]))

    finally:
      tf.config.run_functions_eagerly(original_setting)

  def testVariableCreationKerasLayersRecommended(self):
    """Tests the recommended way of creating Keras layers in tf.function.

    Bugs:   b/184210116
    Status: Working as intended
    Issue:  n/a

    Error message: n/a

    Notes:
    * Demonstrates the recommended workaround for the problematic pattern in
      the `testVariableCreationKerasLayers` test case: create the layer only
      once, on the first call, and reuse it on subsequent calls.
    """
    layer = None

    @tf.function
    def f(x):
      nonlocal layer
      if layer is None:
        layer = tf.keras.layers.Dense(2)
      return layer(x)

    # The Dense layer is randomly initialized, so check the output shape and
    # that repeated calls reuse the same variables, rather than comparing
    # against hard-coded values.
    first = f(tf.constant([[1., 2.], [3., 4.]]))
    second = f(tf.constant([[1., 2.], [3., 4.]]))
    self.assertEqual(first.shape, [2, 2])
    self.assertAllEqual(first, second)
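
  def testVariableCreationKerasLayersOutsideFunction(self):
    """Sketch added for illustration; not part of the original suite.

    A minimal variant of the workaround above: construct the Keras layer
    eagerly, outside the tf.function, so tracing never creates variables
    at all.
    """
    layer = tf.keras.layers.Dense(2)

    @tf.function
    def f(x):
      # The layer's variables already exist; tracing only builds the graph.
      return layer(x)

    out = f(tf.constant([[1., 2.], [3., 4.]]))
    self.assertEqual(out.shape, [2, 2])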

  @parameterized.named_parameters([('_RunFunctionEagerly', True),
                                   ('_RunFunctionNonEagerly', False)])
  def testRetracingKerasOptimAsPythonObj(self, run_eagerly):
    """Tests tf.function variable creation in Keras optimizers.

    Bugs:   b/184210116
    Status: Working as intended
            (However, moving forward, we should support re-creating
            `tf.Variable`s inside a tf.function on each trace. This test case
            should pass eventually.)
    Issue:  Passing different Keras optimizers (Python objects) to a
            tf.function is not allowed: each optimizer creates its own
            `tf.Variable`s, which results in a 'non-first call variable
            creation' error.

    Error message:
      "tf.function-decorated function tried to create variables on non-first
      call."

    Notes:
    * Behavior is inconsistent between eager and non-eager mode execution of
      the tf.function.
    * go/tf-mutable-refs is a work-in-progress, longer-term project designed
      to address this issue.
    * `trace` holds three '#tracing' strings (before erroring out in the last
      `train_one_step()` call) when only two would generally be expected. Why?
      Answer: The first `train_one_step` call traces twice because, after the
      first trace, tf.function detects `tf.Variable` creation and immediately
      traces a second time to check whether new variables are being created.
      (This has been a common source of confusion.)

    Args:
      run_eagerly: Whether to run tf.function-decorated functions eagerly.
    """
    self.skipTest('b/184210116')

    original_setting = tf.config.functions_run_eagerly()
    try:
      tf.config.run_functions_eagerly(run_eagerly)
      trace = []

      @tf.function
      def train_one_step(a, x, y, optim):
        trace.append('#tracing')
        with tf.GradientTape() as tape:
          l = tf.reduce_sum(tf.square(a * x - y))

        w = [a]
        g = tape.gradient(l, w)
        optim.apply_gradients(zip(g, w))
        return a

      optim0 = tf.keras.optimizers.Adam()
      optim1 = tf.keras.optimizers.Adam()
      a = tf.Variable(2.)
      x = tf.Variable([-1., -1., -1.])
      y = tf.Variable([2., 2., 2.])

      train_one_step(a, x, y, optim0)  # tracing

      if run_eagerly:
        train_one_step(a, x, y, optim1)
      else:
        self.assertLen(trace, 2)  # Traces twice; see "Notes" in the test
                                  # case docstring for more info.
        train_one_step(a, x, y, optim1)

    finally:
      tf.config.run_functions_eagerly(original_setting)
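
  def testRetracingKerasOptimPerFunctionWorkaround(self):
    """Sketch added for illustration; not part of the original suite.

    One workaround for the case above: wrap the Python training step in a
    separate tf.function per optimizer, so each traced function only ever
    sees a single optimizer and creates that optimizer's slot variables on
    its own first call.
    """

    def train_one_step(a, x, y, optim):
      with tf.GradientTape() as tape:
        l = tf.reduce_sum(tf.square(a * x - y))
      g = tape.gradient(l, [a])
      optim.apply_gradients(zip(g, [a]))
      return a

    a = tf.Variable(2.)
    x = tf.constant([-1., -1., -1.])
    y = tf.constant([2., 2., 2.])

    # A separate tf.function per optimizer avoids non-first-call variable
    # creation: each function traces once, with its own optimizer.
    step0 = tf.function(train_one_step)
    step1 = tf.function(train_one_step)
    step0(a, x, y, tf.keras.optimizers.Adam())
    step1(a, x, y, tf.keras.optimizers.Adam())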

  def testCachedTensorKerasLayers(self):
    """Tests tf.function I/O behavior with cached tensors in Keras layers.

    Bugs:   b/149094965
    Status: Working as intended
    Issue:  When an existing trace holds a cached tensor, retracing the
            function (upon receiving a new input signature) results in an
            error, because the cached tensor belongs to the previous trace.

    Error message:
      "The tensor 'Tensor("Placeholder:0", shape=(None, 1), dtype=float32)'
      cannot be accessed here: it is defined in another function or code block."

    Notes:
    * `self._cached_value` is already a cached tensor by the time the program
      tries to retrace upon calling `model.fit()`.
    * This test is equivalent to the `testCachedTensor` test case, but with
      Keras layers.
    * Calling the custom Keras layer initially with
      `pred_out = layer(tf.constant(1.0))` should cache `self._cached_value`
      as a tensor, leading to an error when `model.fit()` is later called
      with a different input signature. However, commenting out that first
      call has no effect. Why? Left as a TODO below.
    """
    self.skipTest('b/149094965')

    class Context(object):
      """Context class for demonstrating the issue."""

      def __init__(self):
        self._cached_value = None

      def f(self, x):
        result = x + 1
        if self._cached_value is not None:
          result += self._cached_value

        self._cached_value = x
        return result

    class CustomLayer(tf.keras.layers.Layer):

      def __init__(self, context, **kwargs):
        self.context = context
        super(CustomLayer, self).__init__(**kwargs)

      def call(self, x, training=None):
        return self.context.f(x)

    ctx = Context()
    layer = CustomLayer(ctx)
    # TODO(hyey): Investigate why the line below doesn't have any effect.
    # Commenting out the line below (tensor caching step) still works. That
    # probably means that tensors are being cached somewhere else?
    pred_out = layer(tf.constant(1.0))  # pylint:disable=unused-variable
    model = tf.keras.models.Sequential([layer])
    model.compile('sgd', 'mean_squared_error')
    model.fit(tf.constant([1., 2., 3.]), tf.constant([1., 2., 3.]))
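
  def testCachedValueInVariableWorkaround(self):
    """Sketch added for illustration; not part of the original suite.

    One way to avoid the cross-trace cached-tensor problem above: keep the
    cached state in a `tf.Variable`, which remains valid across traces,
    instead of holding a Python reference to a tensor from an earlier trace.
    """

    class Context(object):

      def __init__(self):
        # A variable outlives any single trace, unlike a cached tensor.
        self._cached_value = tf.Variable(0., trainable=False)

      def f(self, x):
        result = x + 1 + self._cached_value
        self._cached_value.assign(tf.reduce_mean(x))
        return result

    ctx = Context()
    f = tf.function(ctx.f)
    f(tf.constant(1.0))
    # A new input signature triggers a retrace; the variable is still valid.
    out = f(tf.constant([1., 2., 3.]))
    self.assertEqual(out.shape, [3])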


if __name__ == '__main__':
  test.main()