• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for lists module."""
16
17from tensorflow.python.autograph.converters import directives as directives_converter
18from tensorflow.python.autograph.converters import lists
19from tensorflow.python.autograph.core import converter_testing
20from tensorflow.python.autograph.lang import directives
21from tensorflow.python.autograph.lang import special_functions
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import list_ops
26from tensorflow.python.platform import test
27
28
class ListTest(converter_testing.TestCase):
  """Tests for the AutoGraph `lists` converter."""

  def test_empty_list(self):
    """An empty list literal converts to a variant-dtype tensor."""

    def f():
      return []

    converted = self.transform(f, lists)
    result = converted()

    # An empty tensor list can be neither evaluated nor stacked, so only its
    # Python type and dtype are checked here.
    self.assertIsInstance(result, ops.Tensor)
    self.assertEqual(result.dtype, dtypes.variant)

  def test_initialized_list(self):
    """A non-empty list literal keeps its values after conversion."""

    def f():
      return [1, 2, 3]

    converted = self.transform(f, lists)

    expected = [1, 2, 3]
    self.assertAllEqual(converted(), expected)

  def test_list_append(self):
    """Successive append() calls on a tensor_list land in call order."""

    def f():
      l = special_functions.tensor_list([1])
      l.append(2)
      l.append(3)
      return l

    converted = self.transform(f, lists)

    stacked = list_ops.tensor_list_stack(converted(), dtypes.int32)
    self.assertAllEqual(self.evaluate(stacked), [1, 2, 3])

  def test_list_pop(self):
    """pop() returns the last element and removes it from the list.

    The element type is supplied through `directives.set_element_type`, so
    the directives converter runs before the lists converter.
    """

    def f():
      l = special_functions.tensor_list([1, 2, 3])
      directives.set_element_type(l, dtype=dtypes.int32, shape=())
      s = l.pop()
      return s, l

    converted = self.transform(f, (directives_converter, lists))

    popped, remaining = converted()
    stacked = list_ops.tensor_list_stack(remaining, dtypes.int32)
    self.assertAllEqual(self.evaluate(stacked), [1, 2])
    self.assertAllEqual(self.evaluate(popped), 3)

  def test_double_list_pop(self):
    """Chained pop() calls are applied one after the other."""

    def f(l):
      s = l.pop().pop()
      return s

    converted = self.transform(f, lists)

    # TODO(mdan): Pass a list of lists of tensor when we fully support that.
    # For now, we just pass a regular Python list of lists just to verify that
    # the two pop calls are sequenced properly.
    nested = [1, 2, [1, 2, 3]]
    self.assertAllEqual(converted(nested), 3)

  def test_list_stack(self):
    """array_ops.stack over a converted list literal yields its values."""

    def f():
      l = [1, 2, 3]
      return array_ops.stack(l)

    converted = self.transform(f, lists)

    self.assertAllEqual(self.evaluate(converted()), [1, 2, 3])
104
105
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner.
  test.main()
108