# Lint as: python2, python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests to improve the consistency with tf.TensorArray."""

import io
import logging as std_logging

import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.tools.consistency_integration_test.consistency_test_base import ConsistencyTestBase


class TensorArrayConsistencyTests(ConsistencyTestBase):
  """Test cases for known issues or bugs related to tf.TensorArray."""

  def testConcat(self):
    """Tests inconsistent behavior with `.concat()`.

    Bugs: b/180921284
    Status: Missing error message
    Issue: Running functions eagerly, calling `.concat` on a `tf.TensorArray`
      should raise an error but it does not.

    Error message:
      Expected error message is "Concatenating scalars in `tf.TensorArray` is
      unsupported in eager mode. Please use `.stack()` instead".

    Improve error message? Needed. (b/187851559)

    Notes:
    * Inconsistent behavior between eager and non-eager mode execution of the
      `tf.function` decorated function. In graph mode, the error is thrown.
    * We need to improve graph mode error message though. The error gets thrown
      is "Shapes must be equal rank, but are 1 and 0" and is hard to grasp.
    * Using `.stack()` as a workaround is working as intended:
      ```
      def f(x):
        return x.write(1, tf.constant([4, 5, 6]))

      ta = tf.TensorArray(dtype=tf.int32, dynamic_size=True, size=0)
      ta = ta.write(0, tf.constant([1, 2, 3]))
      f(ta).stack()  # <tf.Tensor: shape=(2, 3), dtype=int32,
                     #             numpy=array([[1, 2, 3], [4, 5, 6]],
                     #                         dtype=int32)>
      ```
    """
    self.skipTest('b/180921284')
    # The try/finally guarantees the process-global eager-execution flag is
    # restored even if the assertions below fail.
    try:
      tf.config.run_functions_eagerly(True)

      @tf.function
      def f(x, y, z):
        # Scalar elements (element_shape=()) are what make `.concat()`
        # problematic; see the docstring above.
        ta = tf.TensorArray(size=3, dtype=tf.int32, element_shape=())
        ta = ta.write(0, x)
        ta = ta.write(1, y)
        ta = ta.write(2, z)
        return ta.concat()

      with self.assertRaisesWithPredicateMatch(
          BaseException,
          # TODO(hyey): Below is a placeholder error message of what we
          # probably want but it needs to be updated to specify what caused
          # the error and where.
          'Concatenating scalars in `tf.TensorArray` is unsupported in eager '
          'mode. Please use `.stack()` instead'):
        f(1, 2, 3)

    finally:
      tf.config.run_functions_eagerly(False)

  def testArrayReturnedFromTfFunction(self):
    """Tests bad handling of tf.TensorArray returned from tf.function.

    Bugs: b/147450234
    Status: Broken
    Issue: `tf.TensorArray` returned from tf.function is a `tf.variant` tensor
      (i.e. `tf.Tensor(<unprintable>, shape=(), dtype=variant)`). Calling
      `stack()` on it causes an AttributeError.

    Error message:
      "AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object
      has no attribute 'stack'"

    Notes:
    * Note that XLA fails with a different error that is equally confusing:
      "Support for TensorList crossing the XLA/TF boundary is not
      implemented."
    """
    self.skipTest('b/147450234')
    num_rows = 2

    @tf.function
    def f(x):
      ta = tf.TensorArray(tf.float32, num_rows)
      for i in range(num_rows):
        ta = ta.write(i, x[i])

      return ta

    n = tf.constant([[1., 2.], [3., 4.]])
    ta0 = f(n)
    # `ta1` is built eagerly, row by row, as the expected value to compare
    # against the tf.function-returned array.
    ta1 = tf.TensorArray(tf.float32, num_rows)
    ta1 = ta1.write(0, n[0])
    ta1 = ta1.write(1, n[1])

    # Output of `f(n)` is `tf.Tensor(<unprintable>, shape=(), dtype=variant)`.
    self.assertAllEqual(ta0.stack(), ta1.stack())

  def testTensorArraySpec(self):
    """Tests tf.TensorArray behavior with `TensorArraySpec` as input signature.

    Bugs: b/162452468, b/187114287
    Status: Broken
    Issue: Using `tf.TensorArraySpec` as the input signature to tf.function
      does not work. This is not documented anywhere.

    Error message:
      "If shallow structure is a sequence, input must also be a sequence."

    Notes:
    * Documentation for `tf.TensorArraySpec` appears to be minimal. Need to
      update it.
    """
    self.skipTest('b/187114287')
    input_signature = [
        tf.TensorArraySpec(
            element_shape=None, dtype=tf.float32, dynamic_size=True)
    ]

    @tf.function(input_signature=input_signature)
    def f(ta):
      return ta.stack()

    ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta = ta.write(0, tf.constant([1.0, 2.0]))
    ta = ta.write(1, tf.constant([3.0, 4.0]))

    out_t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    self.assertAllEqual(f(ta), out_t)

  def testTensorArrayConcreteFunction(self):
    """Tests ConcreteFunction retrieval of a tf.function with a tf.TensorArray.

    Bugs: b/162452468, b/187114664
    Status: Broken
    Issue: Calling tf.function with a proper argument (i.e. traced input)
      fails. More specifically, calling `cf(arr)` should work but doesn't
      and calling `cf()` works rather when it should fail.
    """
    self.skipTest('b/187114664')

    @tf.function
    def fun(x):
      return x.stack()

    ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta = ta.write(0, tf.constant([1.0, 2.0]))
    ta = ta.write(1, tf.constant([3.0, 4.0]))

    # `cf` was traced with `ta`, so calling it with the same argument should
    # succeed (but currently does not; see the docstring above).
    cf = fun.get_concrete_function(ta)
    t0 = cf(ta)
    t1 = ta.stack()
    self.assertAllEqual(t0, t1)

  def testVariantTensorAsOutput(self):
    """Tests that tf.variant tensor returns from tf.function for tf.TensorArray.

    Bugs: b/162452468, b/187115938
    Status: Broken
    Issue: `tf.TensorArray` returned from tf.function is a tf.variant tensor
      and is limited in functionality. For e.g., as simple as trying to
      `print()` or call `.numpy()` on it does not work (see
      `testArrayReturnedFromTfFunction` test case above).

    Notes:
    * When tf.function returns a `tf.TensorArray`, output returned should be a
      `tf.TensorArray`.
    """
    self.skipTest('b/187115938')

    @tf.function
    def f():
      ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
      ta = ta.write(0, tf.constant([1.0, 2.0]))
      ta = ta.write(1, tf.constant([3.0, 4.0]))
      return ta

    rtn_ta = f()
    # Initialize a `tf.TensorArray` to check against `rtn_ta` that it is a
    # `tf.TensorArray`.
    a_ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    self.assertEqual(rtn_ta.__module__, a_ta.__module__)

  def testTensorArrayPassedInAndReturnedFromTfFunction(self):
    """Tests tf.TensorArray passed in as input and returned as output.

    Bugs: b/162452468, b/187115435, b/147450234
    Status: Broken
    Issue: Returning `tf.TensorArray` from a tf.function does not work when
      passing it in as an input works. This is not documented anywhere.

    Error message:
      "Attempting to build a graph-mode TF2-style TensorArray from either an
      eager-mode TensorArray or a TF1-style TensorArray."
    """
    self.skipTest('b/187115435')

    @tf.function
    def f(ta):
      ta = ta.write(1, tf.constant([3.0, 4.0]))
      return ta

    ta0 = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta0 = ta0.write(0, tf.constant([1.0, 2.0]))
    ta0 = f(ta0)

    # `ta1` is the fully eagerly-built expected value for comparison.
    ta1 = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta1 = ta1.write(0, tf.constant([1.0, 2.0]))
    ta1 = ta1.write(1, tf.constant([3.0, 4.0]))

    self.assertAllEqual(ta0.stack(), ta1.stack())

  def testMissingWarning(self):
    """Tests warnings when the output of tf.TensorArray methods is unused.

    Bugs: b/150784251
    Status: Broken
    Issue: tf.TensorArray API doc specifies that a warning should be present
      when the output of tf.TensorArray methods is unused but no warning
      is present for tf.function decorated functions.
      https://www.tensorflow.org/api_docs/python/tf/TensorArray

    Error message:
      'Object was never used ... If you want to mark it as used call its
      "mark_used()" method.'

    Improve error message? Needed. (b/187852489)

    Notes:
    * Inconsistent behavior between when a function is decorated with
      tf.function and not. For example, if `f()` is tf.function-decorated,
      then it will NOT print the warning. If `f()` is NOT
      tf.function-decorated, then it will print the warning.
      ```
      @tf.function
      def f(x):
        ta = tf.TensorArray(x.dtype, tf.shape(x)[0])
        ta.write(0, x[0])

      f(tf.constant([1, 2, 3, 4]))
      ```
    * As simple as assignment operation is enough to avoid the warning case.
      ```
      @tf.function
      def f(x):
        ta = tf.TensorArray(x.dtype, tf.shape(x)[0])
        ta = ta.write(0, x[0])

      f(tf.constant([1, 2, 3, 4]))
      ```
    """
    self.skipTest('b/150784251')

    # Capture root-logger output into an in-memory buffer so the test can
    # assert on the warning text; the handler is removed at the end.
    log = io.StringIO()
    handler = std_logging.StreamHandler(log)
    std_logging.root.addHandler(handler)

    @tf.function
    def f(x):
      ta = tf.TensorArray(x.dtype, tf.shape(x)[0])
      # A warning should be thrown with the line below. This is the case only
      # when `f()` is not decorated with tf.function.
      ta.write(0, x[0])

    f(tf.constant([1, 2, 3, 4]))

    self.assertIn('Object was never used', log.getvalue())
    std_logging.root.removeHandler(handler)


if __name__ == '__main__':
  test.main()