• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for control_flow module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22
23from tensorflow.python.autograph.converters import control_flow
24from tensorflow.python.autograph.core import converter_testing
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import test_util
28from tensorflow.python.platform import test
29
30
class ControlFlowTest(converter_testing.TestCase):
  """Tests for the `control_flow` converter applied to while/if/for code.

  Each test defines a small `test_fn`, runs it through the converter and
  compares the evaluated output of the converted function with the expected
  value. The bodies of the `test_fn` functions are the inputs under test, so
  their statements are intentionally left exactly as written.
  """

  def assertTransformedResult(self, test_fn, inputs, expected, symbols=None):
    """Converts `test_fn` with `control_flow` and checks its evaluated result.

    Args:
      test_fn: The function to convert and call.
      inputs: The argument(s) for the converted function. Anything that is not
        a tuple (including a list) is wrapped into a 1-tuple and passed as a
        single positional argument.
      expected: The value the evaluated result is compared against with
        `assertEqual`.
      symbols: Optional dict forwarded to `self.converted`; a falsy value is
        replaced with a fresh empty dict.
    """
    if not isinstance(inputs, tuple):
      inputs = (inputs,)
    if not symbols:
      symbols = {}
    # NOTE(review): `constant_op.constant` is forwarded as an extra symbol;
    # presumably this is what lets converted code resolve `tf.constant` (see
    # test_for_tuple_unpacking) - confirm against converter_testing.
    with self.converted(test_fn, control_flow, symbols,
                        constant_op.constant) as result:
      self.assertEqual(self.evaluate(result.test_fn(*inputs)), expected)

  @test_util.run_deprecated_v1
  def test_while_basic(self):
    """A while loop on a tensor bound that updates two accumulators."""

    def test_fn(n):
      i = 0
      s = 0
      while i < n:
        s += i
        i += 1
      return s, i, n

    self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))

  @test_util.run_deprecated_v1
  def test_while_nested(self):
    """A while loop nested in another, with a symbol local to the outer body."""

    def test_fn(n):
      i = 0
      j = 0
      s = 0
      while i < n:
        while j < i:
          j += 3
        u = i + j  # 'u' is not defined within the inner loop
        s += u
        i += 1
        j = 0
      return s, i, j, n

    self.assertTransformedResult(test_fn, constant_op.constant(5),
                                 (25, 5, 0, 5))

  @test_util.run_deprecated_v1
  def test_while_single_output(self):
    """A while loop whose only modified symbol is its own argument."""

    def test_fn(n):
      while n > 0:
        n -= 1
      return n

    self.assertTransformedResult(test_fn, constant_op.constant(5), 0)

  def test_while_local_composite(self):
    """An object created and mutated inside the loop body only."""

    class TestClass(object):

      def __init__(self):
        self.x = constant_op.constant(3)

    def test_fn(n):
      while n > 0:
        tc = TestClass()
        tc.x = tc.x
        n -= 1
      return n

    self.assertTransformedResult(
        test_fn, constant_op.constant(5), 0, symbols={'TestClass': TestClass})

  # TODO(b/127642077): Add tests for x.y.z = 2*x.y.z and x.y[z] = 2*x.y[z].
  def test_while_local_composite_complex_nestable(self):
    """A namedtuple-based object is read after the loop that rebinds it."""

    # This class is ok to be in a tf.while_loop's state.
    # The field list is spelled as a real 1-tuple; the previous `('x')` was
    # just the string 'x' in parentheses, which namedtuple happens to accept.
    class TestClass(collections.namedtuple('TestClass', ('x',))):
      pass

    def test_fn(n):
      tc = TestClass([constant_op.constant(0)])
      while n > 0:
        tc = TestClass([constant_op.constant(3)])
        tc.x[0] = tc.x[0] + 1
        n -= 1
      return tc.x[0]

    ns = {'TestClass': TestClass, 'constant_op': constant_op}
    self.assertTransformedResult(
        test_fn, constant_op.constant(5), 4, symbols=ns)

  def test_while_local_composite_complex_illegal(self):
    """A plain object used after the loop must be rejected with ValueError."""

    class TestClass(object):

      def __init__(self):
        self.x = [constant_op.constant(3)]

    def test_fn(n):
      while n > 0:
        tc = TestClass()
        tc.x[0] = tc.x[0] + 1
        n -= 1
      return tc.x[0]

    with self.converted(
        test_fn, control_flow, {'TestClass': TestClass}) as result:
      # The tested function would require `tc` to become part of the while loop
      # state, but TensorFlow doesn't support classes at the moment.
      # The regex below depends on the local being named `tc`.
      with self.assertRaisesRegexp(ValueError, 'must.*initialize.*Tensor.*tc'):
        result.test_fn(constant_op.constant(5))

  @test_util.run_deprecated_v1
  def test_while_dispatches_by_cond_only(self):
    """Only the loop condition decides whether the while loop is staged."""

    class TensorIncompatibleNumeric(object):
      """Works in arithmetic expression, but errors out with TF ops."""

      def __init__(self, val):
        self.val = val

      def __add__(self, other):
        return TensorIncompatibleNumeric(self.val + other)

    def test_fn(n, s):
      while n > 0:
        n -= 1
        s += n
      return s

    self.assertTransformedResult(test_fn, (constant_op.constant(5), 0), 10)
    with self.converted(test_fn, control_flow, {}) as result:
      # n alone controls the staging. When the loop is not staged, Python
      # knows how to add the two objects. But when staged, tf.while_loop will
      # not know how to deal with the TensorIncompatibleNumeric object.
      self.assertEqual(result.test_fn(5, TensorIncompatibleNumeric(0)).val, 10)
      with self.assertRaises(TypeError):
        result.test_fn(constant_op.constant(5), TensorIncompatibleNumeric(0))

  @test_util.run_deprecated_v1
  def test_if_basic(self):
    """An if/else where each branch assigns a different symbol."""

    def test_fn(n):
      a = 0
      b = 0
      if n > 0:
        a = -n
      else:
        b = 2 * n
      return a, b

    self.assertTransformedResult(test_fn, constant_op.constant(1), (-1, 0))
    self.assertTransformedResult(test_fn, constant_op.constant(-1), (0, -2))

  @test_util.run_deprecated_v1
  def test_if_complex_outputs(self):
    """An if/else whose branches assign attributes of an argument object."""

    class TestClass(object):

      def __init__(self, a, b):
        self.a = a
        self.b = b

    def test_fn(n, obj):
      obj.a = 0
      obj.b = 0
      if n > 0:
        obj.a = -n
      else:
        obj.b = 2 * n
      return obj

    with self.converted(test_fn, control_flow, {}) as result:
      # A fresh TestClass is used per call because test_fn mutates its argument.
      res_obj = result.test_fn(constant_op.constant(1), TestClass(0, 0))
      self.assertEqual(self.evaluate((res_obj.a, res_obj.b)), (-1, 0))
      res_obj = result.test_fn(constant_op.constant(-1), TestClass(0, 0))
      self.assertEqual(self.evaluate((res_obj.a, res_obj.b)), (0, -2))

  @test_util.run_deprecated_v1
  def test_if_single_output(self):
    """An if without else that reassigns a single symbol from itself."""

    def test_fn(n):
      if n > 0:
        n = -n
      return n

    self.assertTransformedResult(test_fn, constant_op.constant(1), -1)

  @test_util.run_deprecated_v1
  def test_if_semi(self):
    """An if without else that assigns a constant; both outcomes checked."""

    def test_fn(n):
      if n > 0:
        n = 3
      return n

    self.assertTransformedResult(test_fn, constant_op.constant(2), 3)
    self.assertTransformedResult(test_fn, constant_op.constant(-3), -3)

  @test_util.run_deprecated_v1
  def test_if_local_var(self):
    """An if body that introduces a local used only inside that body."""

    def test_fn(n):
      if n > 0:
        b = 4
        n = b + 1
      return n

    self.assertTransformedResult(test_fn, constant_op.constant(1), 5)
    self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)

  @test_util.run_deprecated_v1
  def test_if_no_outputs(self):
    """An if body whose only assignment is never read afterwards."""

    def test_fn(n):
      if n > 0:
        b = 4  # pylint:disable=unused-variable
      return n

    # Without side effect guards, the if statement will stage a cond,
    # but that will be pruned at execution.
    self.assertTransformedResult(test_fn, constant_op.constant(1), 1)
    self.assertTransformedResult(test_fn, constant_op.constant(-1), -1)

  @test_util.run_deprecated_v1
  def test_if_unbalanced_multiple_composites(self):
    """An if without else assigning two attributes and one plain symbol."""

    class Foo(object):

      def __init__(self):
        self.b = 2
        self.c = 3

    def test_fn(x, condition):

      z = 5
      if condition:
        x.b = 7
        x.c = 11
        z = 13

      return x.b, x.c, z

    self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)),
                                 (7, 11, 13))
    self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)),
                                 (2, 3, 5))

  @test_util.run_deprecated_v1
  def test_if_unbalanced_composite(self):
    """An if without else assigning one attribute and one plain symbol."""

    class Foo(object):

      def __init__(self):
        self.b = 2

    def test_fn(x, condition):

      z = 5
      if condition:
        x.b = 7
        z = 13

      return x.b, z

    self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(True)),
                                 (7, 13))
    self.assertTransformedResult(test_fn, (Foo(), constant_op.constant(False)),
                                 (2, 5))

  @test_util.run_deprecated_v1
  def test_simple_for(self):
    """A for loop over a tensor with two accumulators, incl. empty input."""

    def test_fn(l):
      s1 = 0
      s2 = 0
      for e in l:
        s1 += e
        s2 += e * e
      return s1, s2

    self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), (4, 10))
    empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
    self.assertTransformedResult(test_fn, empty_vector, (0, 0))

  @test_util.run_deprecated_v1
  def test_for_single_output(self):
    """A for loop over a tensor with one accumulator, incl. empty input."""

    def test_fn(l):
      s = 0
      for e in l:
        s += e
      return s

    self.assertTransformedResult(test_fn, constant_op.constant([1, 3]), 4)
    empty_vector = constant_op.constant([], shape=(0,), dtype=dtypes.int32)
    self.assertTransformedResult(test_fn, empty_vector, 0)

  def test_for_iterated_expression(self):
    """The expression iterated by a for loop is evaluated exactly once."""

    # A one-element list so the nested function can update the counter.
    eval_count = [0]

    def count_evals(x):
      eval_count[0] += 1
      return x

    def test_fn(n):
      s = 0
      for e in count_evals(range(n)):
        s += e
      return s

    ns = {'count_evals': count_evals}
    node, ctx = self.prepare(test_fn, ns)
    node = control_flow.transform(node, ctx)

    with self.compiled(node, ns) as result:
      self.assertEqual(result.test_fn(5), 10)
      self.assertEqual(eval_count[0], 1)

  @test_util.run_deprecated_v1
  def test_for_tuple_unpacking(self):
    """A for loop whose target is a tuple, iterating over `enumerate`."""

    def test_fn(x_list):
      z = tf.constant(0)  # pylint:disable=undefined-variable
      for i, x in enumerate(x_list):
        z = z + x + i
      return z

    # The list is not a tuple, so it is passed as the single argument x_list.
    self.assertTransformedResult(test_fn, [3, 3], 7)
358
359
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
  test.main()