• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for list_comprehensions module."""
16
17from tensorflow.python.autograph.converters import list_comprehensions
18from tensorflow.python.autograph.core import converter_testing
19from tensorflow.python.platform import test
20
21
class ListCompTest(converter_testing.TestCase):
  """Tests for the `list_comprehensions` converter.

  Each test defines a small function containing a list comprehension, runs it
  through `self.transform` with the `list_comprehensions` converter, and checks
  that the converted function returns the same value as the original on a set
  of inputs.
  """

  def assertTransformedEquivalent(self, f, *inputs):
    """Asserts that `f` and its converted form return equal results.

    Args:
      f: the function to convert via `self.transform(f, list_comprehensions)`.
      *inputs: positional arguments passed unchanged to both the original
        function and the converted one.
    """
    tr = self.transform(f, list_comprehensions)
    self.assertEqual(f(*inputs), tr(*inputs))

  def test_basic(self):

    def f(l):
      s = [e * e for e in l]
      return s

    self.assertTransformedEquivalent(f, [])
    self.assertTransformedEquivalent(f, [1, 2, 3])

  def test_multiple_generators(self):

    def f(l):
      s = [e * e for sublist in l for e in sublist]  # pylint:disable=g-complex-comprehension
      return s

    self.assertTransformedEquivalent(f, [])
    self.assertTransformedEquivalent(f, [[1], [2], [3]])
    # With single-element sublists the inner generator runs exactly once per
    # outer element; also cover inner lists with several elements and an
    # empty inner list.
    self.assertTransformedEquivalent(f, [[1, 2], [], [3, 4, 5]])

  def test_cond(self):

    def f(l):
      s = [e * e for e in l if e > 1]
      return s

    self.assertTransformedEquivalent(f, [])
    self.assertTransformedEquivalent(f, [1, 2, 3])
    # An input where the condition rejects every element.
    self.assertTransformedEquivalent(f, [0, 1])
54
55
if __name__ == '__main__':
  # When this module is executed directly (not imported), hand control to the
  # test runner imported as `test`.
  test.main()
58