# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the AutoGraph `lists` converter module."""

from tensorflow.python.autograph.converters import directives as directives_converter
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.platform import test


class ListTest(converter_testing.TestCase):
  """Checks how the `lists` converter rewrites Python list operations.

  Each test defines a small nested function `f` and passes it to
  `self.transform` together with the converter module(s) under test. The
  source of `f` is the test input, so it should be edited with care: any
  change to it changes the code being converted.
  """

  def test_empty_list(self):
    """An empty list literal converts to a tensor of dtype `variant`."""

    def f():
      return []

    tr = self.transform(f, lists)

    tl = tr()
    # Empty tensor lists cannot be evaluated or stacked.
    # Because of that, only the type and dtype of the result are checked.
    self.assertIsInstance(tl, ops.Tensor)
    self.assertEqual(tl.dtype, dtypes.variant)

  def test_initialized_list(self):
    """A non-empty list literal still compares equal to its elements."""

    def f():
      return [1, 2, 3]

    tr = self.transform(f, lists)

    self.assertAllEqual(tr(), [1, 2, 3])

  def test_list_append(self):
    """`append` calls on a `tensor_list` are converted to list operations.

    The converted result is stacked as int32 and compared element-wise.
    """

    def f():
      l = special_functions.tensor_list([1])
      l.append(2)
      l.append(3)
      return l

    tr = self.transform(f, lists)

    tl = tr()
    # Turn the returned list value back into a dense int32 tensor so it can
    # be evaluated and compared.
    r = list_ops.tensor_list_stack(tl, dtypes.int32)
    self.assertAllEqual(self.evaluate(r), [1, 2, 3])

  def test_list_pop(self):
    """`pop` returns the last element and leaves the remainder in the list.

    The `directives` converter runs together with `lists` so that the
    `set_element_type` directive inside `f` is processed. That directive
    declares int32 elements with scalar shape `()`.
    """

    def f():
      l = special_functions.tensor_list([1, 2, 3])
      directives.set_element_type(l, dtype=dtypes.int32, shape=())
      s = l.pop()
      return s, l

    tr = self.transform(f, (directives_converter, lists))

    ts, tl = tr()
    r = list_ops.tensor_list_stack(tl, dtypes.int32)
    # Remaining list contents after the pop.
    self.assertAllEqual(self.evaluate(r), [1, 2])
    # The popped value is the former last element.
    self.assertAllEqual(self.evaluate(ts), 3)

  def test_double_list_pop(self):
    """Chained `pop()` calls run in the correct order after conversion."""

    def f(l):
      s = l.pop().pop()
      return s

    tr = self.transform(f, lists)

    # The first pop removes the inner list [1, 2, 3]. The second pop takes 3
    # from that inner list.
    test_input = [1, 2, [1, 2, 3]]
    # TODO(mdan): Pass a list of lists of tensor when we fully support that.
    # For now, we just pass a regular Python list of lists just to verify that
    # the two pop calls are sequenced properly.
    self.assertAllEqual(tr(test_input), 3)

  def test_list_stack(self):
    """A converted list literal can be passed to `array_ops.stack`."""

    def f():
      l = [1, 2, 3]
      return array_ops.stack(l)

    tr = self.transform(f, lists)

    self.assertAllEqual(self.evaluate(tr()), [1, 2, 3])


if __name__ == '__main__':
  test.main()