• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Nested loops and conditional statements (e.g. while, for, if).
16
17Meant to verify that arbitrarily nested statements are processed correctly.
18"""
19
20import itertools
21
22from absl.testing import parameterized
23import tensorflow as tf
24
25from tensorflow.python.autograph.tests import reference_test_base
26
27
def independent_ifs(x, y):
  """Returns `x + y` when both inputs are positive, otherwise 0.

  The two tests are kept as separate, nested `if` statements on purpose:
  this module exists to exercise nested statements, so they must not be
  merged into a single boolean expression.
  """
  total = 0
  if x > 0:
    # The inner test does not depend on anything computed by the outer one.
    if y > 0:
      total = x + y
  return total
34
35
def dependent_inner_if(x):
  """Nested `if` whose inner test reads a value set by the outer branch.

  Returns:
    A pair `(x, derived)`. `derived` is `-2 * x` when `x > 0` and `4 * x`
    otherwise. On the positive branch `x` is additionally replaced by
    `-3 * x` if `derived` turned out positive.
  """
  derived = 0
  if x > 0:
    derived = -2 * x
    # This condition depends on the assignment made just above.
    if derived > 0:
      x = -3 * x
  else:
    derived = 4 * x
  return x, derived
45
46
def dependent_imbalanced_inner_if(x):
  """Nested `if` statements where neither level has an `else` branch.

  Returns:
    A pair `(x, derived)`. When `0 < x < 3` the result is
    `(-3 * x, -2 * x)` computed from the incoming `x`; otherwise `x` is
    returned unchanged together with 0.
  """
  derived = 0
  if x > 0:
    if x < 3:
      # `derived` must be computed before `x` is overwritten.
      derived = -2 * x
      x = -3 * x
  return x, derived
54
55
def _hidden_raise():
  """Unconditionally raises a `ValueError`.

  Keeping the `raise` in a separate function means callers in this module
  contain no `raise` statement of their own.
  """
  control_flow_error = ValueError('exception used for control flow')
  raise control_flow_error
58
59
def if_with_local_modification_masked_by_exception(x):
  """Conditional whose local assignment can be skipped by an exception.

  Returns:
    0 when `x <= 0`; 1 when `x > 0` and no exception is raised; 2 when
    `_hidden_raise()` fires (for `x > 1`) before the assignment is reached.
  """
  outcome = 0
  if x > 0:
    try:
      if x > 1:
        _hidden_raise()
      # Only reached when the call above did not raise.
      outcome = 1
    except ValueError:
      pass
    # Still 0 here means the exception skipped the assignment in the `try`.
    if outcome == 0:
      outcome = 2
  return outcome
72
73
# Module-level state read and updated by
# `if_nested_with_modification_of_global`; starts out as None.
_test_global = None
75
76
def if_nested_with_modification_of_global(x):
  """Updates the module-level `_test_global` from inside nested `if`s.

  When `x > 0`, `_test_global` is set to 1 if it was None and incremented
  otherwise, and its new value is returned. When `x <= 0` the global is
  left untouched and 0 is returned.
  """
  result = 0
  if x > 0:
    if x > 0:
      # The `global` declaration intentionally lives inside the nested block.
      global _test_global
      if _test_global is None:
        _test_global = 1
      else:
        _test_global += 1
      result += _test_global
  return result
88
89
def independent_inner_for(a, b):
  """Sums the elements of `b` once for every element of `a`.

  The inner loop neither reads nor writes anything the outer loop carries
  besides the running total.
  """
  total = 0
  for _ in a:
    inner_items = b
    for item in inner_items:
      total += item
  return total
97
98
def independent_inner_while(a, b):
  """Counts inner-loop iterations: `b` per outer pass, `a` outer passes.

  The inner loop counts a fresh copy of `b` down to zero on every outer
  iteration and never modifies the outer loop variable `a`.
  """
  count = 0
  while a > 0:
    remaining = b
    while remaining > 0:
      count += 1
      remaining -= 1
    a -= 1
  return count
108
109
110def dependent_inner_for(a, b):
111  r = 1
112  s = 0
113  for _ in a:
114    r += s
115    tmp = b
116    for j in tmp:
117      s += j
118  return r
119
120
def dependent_inner_while(a, b):
  """Nested `while` loops where the inner loop mutates the outer condition.

  Every outer pass decrements `a` once per inner iteration (`b` of them,
  when `b` is positive) and once more afterwards. The return value is 1
  plus the number of outer passes executed.
  """
  passes = 1
  while a > 0:
    passes += 1
    countdown = b
    while countdown > 0:
      # The inner loop changes the variable tested by the outer loop.
      a -= 1
      countdown -= 1
    a -= 1
  return passes
131
132
def if_in_for(a):
  """Sums `item // 2` over the elements of `a` for which `item % 2 > 0`.

  The temporary holding the halved value is only assigned inside the
  conditional nested in the loop.
  """
  total = 0
  for item in a:
    if item % 2 > 0:
      half = item // 2
      total += half
  return total
140
141
def while_with_continue_in_context_manager(x):
  """`while` loop with a `continue` issued from inside a `with` block.

  Each iteration decrements `x` within a `tf.name_scope('')` context. The
  counter is incremented only on iterations where the decremented `x` is
  at least 5; the others skip ahead with `continue`.
  """
  count = 0
  while x > 0:
    with tf.name_scope(''):
      x = x - 1
      if x < 5:
        # Leaves the `with` block and moves on to the next iteration.
        continue
      count = count + 1
  return count
151
152
def while_continue_in_try(x):
  """`while` loop with a `continue` issued from inside `try`/`finally`.

  Each iteration decrements `x`. The accumulator gains 1 when the
  decremented `x` is at least 5, and the `finally` clause adds 10 on every
  iteration, including the ones that `continue`.
  """
  acc = 0
  while x > 0:
    x = x - 1
    try:
      if x < 5:
        continue
      acc = acc + 1
    finally:
      # Runs on both the `continue` path and the fall-through path.
      acc = acc + 10
  return acc
164
165
def while_break_in_context_manager(x):
  """`while` loop with a `break` issued from inside a `with` block.

  Each iteration decrements `x` within a `tf.name_scope('')` context. The
  loop stops as soon as the decremented `x` drops below 5; until then the
  counter is incremented once per iteration.
  """
  count = 0
  while x > 0:
    with tf.name_scope(''):
      x = x - 1
      if x < 5:
        # Leaves both the `with` block and the loop.
        break
      count = count + 1
  return count
175
176
def while_break_in_try(x):
  """`while` loop with a `break` issued from inside `try`/`finally`.

  Each iteration decrements `x`. The accumulator gains 1 while the
  decremented `x` is at least 5; the loop ends the first time it is below
  5. The `finally` clause adds 10 on every iteration, including the one
  that breaks.
  """
  acc = 0
  while x > 0:
    x = x - 1
    try:
      if x < 5:
        break
      acc = acc + 1
    finally:
      # Runs even on the iteration that exits through `break`.
      acc = acc + 10
  return acc
188
189
def loop_initializing_invariant_variable(n):
  """Returns a value that is first assigned inside the loop body.

  The variable is set on every iteration without reading its previous
  value: 1 on the first pass, 2 on later ones. If the loop body never runs
  (`n` is 0 in plain Python), the variable is never bound and the `return`
  fails with `UnboundLocalError`.
  """
  for step in range(n):
    if step == 0:
      value = 1
    else:
      value = 2
  return value
197
198
def loop_initializing_variant_variable(n):
  """Returns a loop-carried value that is first assigned inside the loop.

  The variable is set to 1 on the first pass and incremented from its
  previous value on later passes, so in plain Python the result equals `n`
  for `n >= 1`. If the loop body never runs, the variable is never bound
  and the `return` fails with `UnboundLocalError`.
  """
  for step in range(n):
    if step == 0:
      value = 1
    else:
      # Reads the value written by the previous iteration.
      value = value + 1
  return value
206
207
def _int_tensor(x):
  """Wraps `x` in a `tf.constant` with dtype `tf.int32`."""
  int32_constant = tf.constant(x, dtype=tf.int32)
  return int32_constant
210
211
class NestedControlFlowTest(
    reference_test_base.TestCase, parameterized.TestCase):
  """Test cases for the nested control flow functions in this module.

  Most cases pass a module-level function and its inputs to
  `assertFunctionMatchesEager`; the remaining ones wrap the function in
  `tf.function` directly and check the result or the raised error. Inputs
  are parameterized over plain Python values and their Tensor counterparts
  (`int` vs. `_int_tensor`, `range` vs. `tf.range`).
  """

  @parameterized.parameters(*itertools.product(
      (-1, 1),
      (-1, 1),
      (int, _int_tensor),
      (int, _int_tensor),
  ))
  def test_independent_ifs(self, x, y, type_x, type_y):
    x = type_x(x)
    # Convert y with its own converter so that the mixed int/Tensor
    # combinations generated by the parameter product are really exercised.
    y = type_y(y)
    self.assertFunctionMatchesEager(independent_ifs, x, y)

  @parameterized.parameters(*itertools.product(
      (-1, 1),
      (int, _int_tensor),
  ))
  def test_dependent_inner_if(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(dependent_inner_if, x)

  @parameterized.parameters(*itertools.product(
      (-1, 1),
      (int, _int_tensor),
  ))
  def test_dependent_imbalanced_inner_if(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(dependent_imbalanced_inner_if, x)

  @parameterized.parameters(
      (-1,),
      (0,),
      (1,),
      (2,),
  )
  def test_if_with_local_modification_masked_by_exception(self, x):
    # Note: If the input is a Tensor, the behavior is undefined.
    self.assertFunctionMatchesEager(
        if_with_local_modification_masked_by_exception, x)

  def test_if_nested_with_modification_of_global(self):
    global _test_global
    # Reset the shared module state so the expected values do not depend on
    # which other tests ran before this one.
    _test_global = None
    self.assertEqual(tf.function(if_nested_with_modification_of_global)(1), 1)
    self.assertEqual(_test_global, 1)

  def test_if_nested_with_modification_of_global_not_executed(self):
    global _test_global
    _test_global = None
    self.assertEqual(tf.function(if_nested_with_modification_of_global)(0), 0)
    self.assertIsNone(_test_global)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2),
      (0, 1, 2),
      (range, tf.range),
      (range, tf.range),
  ))
  def test_independent_inner_for(self, a, b, type_a, type_b):
    a = type_a(a)
    b = type_b(b)
    self.assertFunctionMatchesEager(independent_inner_for, a, b)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2),
      (0, 1, 2),
      (int, _int_tensor),
      (int, _int_tensor),
  ))
  def test_independent_inner_while(self, a, b, type_a, type_b):
    a = type_a(a)
    b = type_b(b)
    self.assertFunctionMatchesEager(independent_inner_while, a, b)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2),
      (0, 1, 2),
      (range, tf.range),
      (range, tf.range),
  ))
  def test_dependent_inner_for(self, a, b, type_a, type_b):
    a = type_a(a)
    b = type_b(b)
    self.assertFunctionMatchesEager(dependent_inner_for, a, b)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2, 3, 4),
      (0, 1, 2, 3, 4),
      (int, _int_tensor),
      (int, _int_tensor),
  ))
  def test_dependent_inner_while(self, a, b, type_a, type_b):
    # The Python-int `a` / Tensor `b` combination is skipped; see the
    # referenced bug.
    if (type_a is int) and (type_b is _int_tensor):
      self.skipTest('b/124378596')
    a = type_a(a)
    b = type_b(b)
    self.assertFunctionMatchesEager(dependent_inner_while, a, b)

  @parameterized.parameters(*itertools.product(
      (0, 1, 2),
      (range, tf.range),
  ))
  def test_if_in_for(self, a, type_):
    a = type_(a)
    self.assertFunctionMatchesEager(if_in_for, a)

  @parameterized.parameters(*itertools.product(
      (0, 4, 10),
      (int, _int_tensor),
  ))
  def test_while_continue_in_context_manager(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(while_with_continue_in_context_manager, x)

  @parameterized.parameters(*itertools.product(
      (0, 4, 10),
      (int, _int_tensor),
  ))
  def test_while_continue_in_try(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(while_continue_in_try, x)

  @parameterized.parameters(*itertools.product(
      (0, 4, 10),
      (int, _int_tensor),
  ))
  def test_while_break_in_context_manager(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(while_break_in_context_manager, x)

  @parameterized.parameters(*itertools.product(
      (0, 4, 10),
      (int, _int_tensor),
  ))
  def test_while_break_in_try(self, x, type_):
    x = type_(x)
    self.assertFunctionMatchesEager(while_break_in_try, x)

  @parameterized.parameters(*itertools.product(
      (1, 2),
      (int, _int_tensor),
  ))
  def test_loop_initializing_invariant_variable_legal(self, n, type_):
    n = type_(n)
    self.assertFunctionMatchesEager(loop_initializing_invariant_variable, n)

  def test_loop_initializing_invariant_variable_illegal(self):
    # With a Python zero the loop body never runs, so the returned variable
    # is never bound.
    with self.assertRaises(UnboundLocalError):
      tf.function(loop_initializing_invariant_variable)(0)
    # With a Tensor zero the failure is reported as a TensorFlow error.
    with self.assertRaisesRegex(
        tf.errors.InvalidArgumentError, 'loop must iterate at least once'):
      tf.function(loop_initializing_invariant_variable)(tf.constant(0))

  @parameterized.parameters(
      (1,),
      (2,),
  )
  def test_loop_initializing_variant_variable_legal(self, n):
    # Only checks that the call completes without raising; the returned
    # value is not inspected.
    tf.function(loop_initializing_variant_variable)(n)

  @parameterized.parameters(
      (0,),
      (1,),
      (2,),
  )
  def test_loop_initializing_variant_variable_illegal(self, n):
    with self.assertRaisesRegex(ValueError, 'must be defined before the loop'):
      tf.function(loop_initializing_variant_variable)(tf.constant(n))
471
472
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
475