• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests to improve the consistency with tf.TensorArray."""
17
18import io
19import logging as std_logging
20
21import tensorflow as tf
22from tensorflow.python.platform import test
23from tensorflow.tools.consistency_integration_test.consistency_test_base import ConsistencyTestBase
24
25
class TensorArrayConsistencyTests(ConsistencyTestBase):
  """Test cases for known issues or bugs related to tf.TensorArray."""

  def testConcat(self):
    """Tests inconsistent behavior with `.concat()`.

    Bugs:   b/180921284
    Status: Missing error message
    Issue:  When functions are run eagerly, calling `.concat()` on a
            `tf.TensorArray` of scalars should raise an error but it does not.

    Error message:
      Expected error message is "Concatenating scalars in `tf.TensorArray` is
      unsupported in eager mode. Please use `.stack()` instead".

    Improve error message? Needed. (b/187851559)

    Notes:
    * Inconsistent behavior between eager and non-eager mode execution of the
      `tf.function` decorated function. In graph mode, the error is thrown.
    * The graph mode error message also needs improvement. The error thrown
      is "Shapes must be equal rank, but are 1 and 0", which is hard to grasp.
    * Using `.stack()` as a workaround works as intended:
      ```
      def f(x):
        return x.write(1, tf.constant([4, 5, 6]))

      ta = tf.TensorArray(dtype=tf.int32, dynamic_size=True, size=0)
      ta = ta.write(0, tf.constant([1, 2, 3]))
      f(ta).stack()  # <tf.Tensor: shape=(2, 3), dtype=int32,
                     # numpy=array([[1, 2, 3], [4, 5, 6]], dtype=int32)>
      ```
    """
    self.skipTest('b/180921284')
    # Remember the current global setting so it can be restored afterwards,
    # rather than unconditionally forcing it off (which would leak state into
    # other tests if eager function execution had been enabled before).
    previously_run_eagerly = tf.config.functions_run_eagerly()
    try:
      tf.config.run_functions_eagerly(True)

      @tf.function
      def f(x, y, z):
        ta = tf.TensorArray(size=3, dtype=tf.int32, element_shape=())
        ta = ta.write(0, x)
        ta = ta.write(1, y)
        ta = ta.write(2, z)
        return ta.concat()

      with self.assertRaisesWithPredicateMatch(
          BaseException,
          # TODO(hyey): Below is a placeholder error message of what we
          # probably want but it needs to be updated to specify what caused
          # the error and where.
          'Concatenating scalars in `tf.TensorArray` is unsupported in eager '
          'mode. Please use `.stack()` instead'):
        f(1, 2, 3)

    finally:
      tf.config.run_functions_eagerly(previously_run_eagerly)

  def testArrayReturnedFromTfFunction(self):
    """Tests bad handling of tf.TensorArray returned from tf.function.

    Bugs:   b/147450234
    Status: Broken
    Issue:  `tf.TensorArray` returned from tf.function is a `tf.variant` tensor
            (i.e. `tf.Tensor(<unprintable>, shape=(), dtype=variant)`). Calling
            `stack()` on it causes an AttributeError.

    Error message:
      "AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has"
      " no attribute 'stack'"

    Notes:
    * Note that XLA fails with a different error that is equally confusing:
      "Support for TensorList crossing the XLA/TF boundary is not implemented."
    """
    self.skipTest('b/147450234')
    num_rows = 2

    @tf.function
    def f(x):
      ta = tf.TensorArray(tf.float32, num_rows)
      # Python loop: unrolled at trace time into `num_rows` writes.
      for i in range(num_rows):
        ta = ta.write(i, x[i])

      return ta

    n = tf.constant([[1., 2.], [3., 4.]])
    ta0 = f(n)
    # Reference array built eagerly with the same rows.
    ta1 = tf.TensorArray(tf.float32, num_rows)
    ta1 = ta1.write(0, n[0])
    ta1 = ta1.write(1, n[1])

    # Output of `f(n)` is `tf.Tensor(<unprintable>, shape=(), dtype=variant)`.
    self.assertAllEqual(ta0.stack(), ta1.stack())

  def testTensorArraySpec(self):
    """Tests tf.TensorArray behavior with `TensorArraySpec` as input signature.

    Bugs:   b/162452468, b/187114287
    Status: Broken
    Issue:  Using `tf.TensorArraySpec` as the input signature to tf.function
            does not work. This is not documented anywhere.

    Error message:
      "If shallow structure is a sequence, input must also be a sequence."

    Notes:
    * Documentation for `tf.TensorArraySpec` appears to be minimal. Need to
      update it.
    """
    self.skipTest('b/187114287')
    input_signature = [
        tf.TensorArraySpec(
            element_shape=None, dtype=tf.float32, dynamic_size=True)
    ]

    @tf.function(input_signature=input_signature)
    def f(ta):
      return ta.stack()

    ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta = ta.write(0, tf.constant([1.0, 2.0]))
    ta = ta.write(1, tf.constant([3.0, 4.0]))

    out_t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    self.assertAllEqual(f(ta), out_t)

  def testTensorArrayConcreteFunction(self):
    """Tests ConcreteFunction retrieval of a tf.function with a tf.TensorArray.

    Bugs:   b/162452468, b/187114664
    Status: Broken
    Issue:  Calling the ConcreteFunction with a proper argument (i.e. the
            traced input) fails. More specifically, calling `cf(ta)` should
            work but doesn't, and calling `cf()` works when it should fail.
    """
    self.skipTest('b/187114664')

    @tf.function
    def fun(x):
      return x.stack()

    ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta = ta.write(0, tf.constant([1.0, 2.0]))
    ta = ta.write(1, tf.constant([3.0, 4.0]))

    cf = fun.get_concrete_function(ta)
    t0 = cf(ta)
    t1 = ta.stack()
    self.assertAllEqual(t0, t1)

  def testVariantTensorAsOutput(self):
    """Tests the type of a tf.TensorArray returned from tf.function.

    Bugs:   b/162452468, b/187115938
    Status: Broken
    Issue:  `tf.TensorArray` returned from tf.function is a tf.variant tensor
            and is limited in functionality. For example, even calling
            `print()` or `.numpy()` on it does not work.

    Notes:
    * When tf.function returns a `tf.TensorArray`, the output returned should
      be a `tf.TensorArray`.
    """
    self.skipTest('b/187115938')

    @tf.function
    def f():
      ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
      ta = ta.write(0, tf.constant([1.0, 2.0]))
      ta = ta.write(1, tf.constant([3.0, 4.0]))
      return ta

    rtn_ta = f()
    # Check the type directly; comparing `__module__` strings would also accept
    # any other class that happens to live in the same module.
    self.assertIsInstance(rtn_ta, tf.TensorArray)

  def testTensorArrayPassedInAndReturnedFromTfFunction(self):
    """Tests tf.TensorArray passed in as input and returned as output.

    Bugs:   b/162452468, b/187115435, b/147450234
    Status: Broken
    Issue:  Returning a `tf.TensorArray` from a tf.function does not work,
            even though passing one in as an input does. This is not
            documented anywhere.

    Error message:
      "Attempting to build a graph-mode TF2-style TensorArray from either an
      eager-mode TensorArray or a TF1-style TensorArray."
    """
    self.skipTest('b/187115435')

    @tf.function
    def f(ta):
      ta = ta.write(1, tf.constant([3.0, 4.0]))
      return ta

    ta0 = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta0 = ta0.write(0, tf.constant([1.0, 2.0]))
    ta0 = f(ta0)

    # Reference array with both rows written eagerly.
    ta1 = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta1 = ta1.write(0, tf.constant([1.0, 2.0]))
    ta1 = ta1.write(1, tf.constant([3.0, 4.0]))

    self.assertAllEqual(ta0.stack(), ta1.stack())

  def testMissingWarning(self):
    """Tests warnings when the output of tf.TensorArray methods is unused.

    Bugs:   b/150784251
    Status: Broken
    Issue:  The tf.TensorArray API doc specifies that a warning should be
            emitted when the output of tf.TensorArray methods is unused, but
            no warning is emitted for tf.function decorated functions.
            https://www.tensorflow.org/api_docs/python/tf/TensorArray

    Error message:
      'Object was never used ... If you want to mark it as used call its
      "mark_used()" method.'

    Improve error message? Needed. (b/187852489)

    Notes:
    * Inconsistent behavior between when a function is decorated with
      tf.function and not. For example, if `f()` is tf.function-decorated, then
      it will NOT print the warning. If `f()` is NOT tf.function-decorated, then
      it will print the warning.
        ```
        @tf.function
        def f(x):
          ta = tf.TensorArray(x.dtype, tf.shape(x)[0])
          ta.write(0, x[0])

        f(tf.constant([1, 2, 3, 4]))
        ```
    * A simple assignment of the result is enough to avoid the warning.
        ```
        @tf.function
        def f(x):
          ta = tf.TensorArray(x.dtype, tf.shape(x)[0])
          ta = ta.write(0, x[0])

        f(tf.constant([1, 2, 3, 4]))
        ```
    """
    self.skipTest('b/150784251')

    # Capture everything that reaches the root logger in an in-memory buffer.
    log = io.StringIO()
    handler = std_logging.StreamHandler(log)
    std_logging.root.addHandler(handler)
    # Detach the handler even if the assertion below fails, so it does not
    # leak into (and keep capturing output from) subsequent tests.
    self.addCleanup(std_logging.root.removeHandler, handler)

    @tf.function
    def f(x):
      ta = tf.TensorArray(x.dtype, tf.shape(x)[0])
      # A warning should be thrown with the line below. This is the case only
      # when `f()` is not decorated with tf.function.
      ta.write(0, x[0])

    f(tf.constant([1, 2, 3, 4]))

    self.assertIn('Object was never used', log.getvalue())
290
291
if __name__ == '__main__':
  # Executed only when this file is run directly: hand control to the
  # TensorFlow test runner imported above.
  test.main()
294