• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Loops with type changing variables."""
16
17import re
18
19from absl.testing import parameterized
20import tensorflow as tf
21
22from tensorflow.python.autograph.tests import reference_test_base
23
24
def while_with_variable_py_type():
  """While loop whose condition starts as a Python bool, then becomes a Tensor.

  `c` is the Python value `True` on entry, but the loop body rebinds it to a
  boolean Tensor. `ReferenceTest.test_while_with_variable_py_type` expects
  calling this through `tf.function` to raise NotImplementedError ("condition
  of while loop started as non-Tensor, then changed to Tensor").

  Returns:
    The int32 scalar Tensor `n`, which the loop never modifies.
  """
  n = tf.constant(0, dtype=tf.int32)
  c = True
  while c:
    # Rebinding the Python-bool condition to a Tensor is the case under test.
    c = tf.constant(True)
  return n
31
32
def while_with_variable_dtype():
  """Tensor-conditioned while loop that changes a loop variable's dtype.

  `n` enters the loop as an int32 Tensor and is rebound to a float32 Tensor in
  the body. `ReferenceTest.test_while_with_variable_dtype` expects a TypeError
  reading "'n' has dtype int32 before the loop, but dtype float32 after" when
  this is called through `tf.function`.

  Returns:
    The loop variable `n`.
  """
  # The name `n` appears in the expected error text; do not rename it.
  n = tf.constant(0, dtype=tf.int32)
  while tf.constant(True):
    n = tf.constant(0, dtype=tf.float32)
  return n
38
39
def while_with_variable_dtype_and_early_stopping():
  """Like `while_with_variable_dtype`, but the body ends with `break`.

  The `break` exits after the first pass. The test still expects the same
  dtype-change TypeError for `n` (int32 before the loop, float32 after), so
  early exit does not bypass the check.

  Returns:
    The loop variable `n`.
  """
  # The name `n` appears in the expected error text; do not rename it.
  n = tf.constant(0, dtype=tf.int32)
  while tf.constant(True):
    n = tf.constant(0, dtype=tf.float32)
    break
  return n
46
47
def for_with_variable_dtype(l):
  """For loop over `l` that changes a loop variable's dtype.

  Args:
    l: The iterable to loop over. `ReferenceTest` passes versions of
      `[1, 2, 3]` built with `tf.constant`, `_tf_range`, `_dataset`,
      `_dataset_iterator` and `_distributed_dataset`.

  Returns:
    The loop variable `n`. It starts as int32 and is rebound to float32 in
    the body; the test expects a TypeError reporting that change for `'n'`.
  """
  # The name `n` appears in the expected error text; do not rename it.
  n = tf.constant(0, dtype=tf.int32)
  for _ in l:
    n = tf.constant(0, dtype=tf.float32)
  return n
53
54
def for_with_variable_dtype_and_early_stopping(l):
  """Like `for_with_variable_dtype`, but the body ends with `break`.

  Args:
    l: The iterable to loop over. The corresponding test does not use
      distributed datasets, because they don't allow early stopping.

  Returns:
    The loop variable `n`. The test expects the same int32 -> float32
    dtype-change TypeError despite the `break`.
  """
  # The name `n` appears in the expected error text; do not rename it.
  n = tf.constant(0, dtype=tf.int32)
  for _ in l:
    n = tf.constant(0, dtype=tf.float32)
    break
  return n
61
62
def while_with_variable_shape():
  """Tensor-conditioned while loop that changes a loop variable's static shape.

  `t` starts with shape (1,) and is rebound to a shape (2,) constant. The test
  expects a ValueError reporting "'t' has shape (1,) before the loop, but
  shape (2,) after".

  Returns:
    The loop variable `t`.
  """
  # The name `t` appears in the expected error text; do not rename it.
  t = tf.constant([1])
  while tf.constant(True):
    t = tf.constant([1, 1])
  return t
68
69
def for_with_variable_shape(l):
  """For loop over `l` that changes a loop variable's shape from (1,) to (2,).

  Args:
    l: The iterable to loop over. The corresponding test omits plain
      `tf.data.Dataset` inputs, because dataset loops allow variable shape.

  Returns:
    The loop variable `t`. For the other inputs, the test expects a
    ValueError naming `'t'`.
  """
  # The name `t` appears in the expected error text; do not rename it.
  t = tf.constant([1])
  for _ in l:
    t = tf.constant([1, 1])
  return t
75
76
def while_with_shape_erasure():
  """While loop whose body replaces `t` with a Tensor of unknown static length.

  The new `t` is `tf.range` of a randomly drawn scalar, so its length is only
  known at runtime. The test expects a ValueError reporting shape (1,) before
  the loop and shape (None,) after.

  Returns:
    The loop variable `t`.
  """
  # The name `t` appears in the expected error text; do not rename it.
  t = tf.constant([1])
  while tf.constant(True):
    # A runtime-valued length hides the static size of the result.
    t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32))
  return t
82
83
def for_with_shape_erasure(l):
  """For loop over `l` that replaces `t` with a Tensor of unknown length.

  Args:
    l: The iterable to loop over. The corresponding test omits plain
      `tf.data.Dataset` inputs, because dataset loops allow variable shape.

  Returns:
    The loop variable `t`. The test expects a ValueError reporting shape
    (1,) before the loop and (None,) after.
  """
  # The name `t` appears in the expected error text; do not rename it.
  t = tf.constant([1])
  for _ in l:
    # A runtime-valued length hides the static size of the result.
    t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32))
  return t
89
90
def while_with_shape_invariant_violation():
  """While loop that declares a shape invariant for `t` and then violates it.

  `set_loop_options` pins `t` to `TensorShape([1])`, but the body rebinds `t`
  to a Tensor of unknown length. The test expects a ValueError starting
  "'t' has shape (None,) after one iteration, which does not conform".

  Returns:
    The loop variable `t`.
  """
  # The name `t` appears in the expected error text; do not rename it.
  t = tf.constant([1])
  while tf.constant(True):
    # Declares that `t` must keep shape [1] across iterations.
    tf.autograph.experimental.set_loop_options(
        shape_invariants=((t, tf.TensorShape([1])),))
    t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32))
  return t
98
99
def for_with_shape_invariant_violation(l):
  """For loop over `l` that declares a shape invariant for `t`, then violates it.

  Args:
    l: The iterable to loop over. The corresponding test omits plain
      `tf.data.Dataset` inputs, because dataset loops ignore shape invariants.

  Returns:
    The loop variable `t`. For the other inputs, the test expects a
    ValueError saying `t` does not conform after one iteration.
  """
  # The name `t` appears in the expected error text; do not rename it.
  t = tf.constant([1])
  for _ in l:
    # Declares that `t` must keep shape [1] across iterations.
    tf.autograph.experimental.set_loop_options(
        shape_invariants=((t, tf.TensorShape([1])),))
    t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32))
  return t
107
108
def while_with_variable_structure():
  """While loop that rebinds `s` from a dict of Tensors to a single Tensor.

  The test expects a TypeError saying "'s' does not have the same nested
  structure".

  Returns:
    The loop variable `s`.
  """
  # The name `s` appears in the expected error text; do not rename it.
  s = {'a': tf.constant(0)}
  while tf.constant(True):
    s = tf.constant(7.0)
  return s
114
115
def for_with_variable_structure(l):
  """For loop over `l` that grows the list `s` by one Tensor per iteration.

  Each iteration changes the nested structure of `s` (its list length). The
  test expects a TypeError saying "'s' does not have the same nested
  structure" for every iterable kind it passes, plain datasets included.

  Args:
    l: The iterable to loop over.

  Returns:
    The loop variable `s`.
  """
  # The name `s` appears in the expected error text; do not rename it.
  s = [tf.constant(0)]
  for _ in l:
    s = s + [tf.constant(0)]
  return s
121
122
def _tf_range(l):
  """Returns `tf.range` over as many elements as `l` has.

  Args:
    l: Any sized Python object, e.g. a list.
  """
  size = len(l)
  return tf.range(size)
125
126
def _dataset(l):
  """Builds a `tf.data.Dataset` from the slices of `l`."""
  dataset = tf.data.Dataset.from_tensor_slices(l)
  return dataset
129
130
def _dataset_iterator(l):
  """Returns a Python iterator over a dataset built from the slices of `l`."""
  dataset = tf.data.Dataset.from_tensor_slices(l)
  iterator = iter(dataset)
  return iterator
133
134
def _distributed_dataset(l):
  """Distributes a two-element dataset, each element being `l`.

  A fresh `tf.distribute.MirroredStrategy` is created on every call, after
  the dataset itself has been built.
  """
  two_copies = [l, l]
  dataset = tf.data.Dataset.from_tensor_slices(two_copies)
  strategy = tf.distribute.MirroredStrategy()
  return strategy.experimental_distribute_dataset(dataset)
138
139
class ReferenceTest(reference_test_base.TestCase, parameterized.TestCase):
  """Checks the errors raised by loops whose variables change type.

  Each test wraps one of this module's loop fixtures in `tf.function`, calls
  it, and asserts both the exception type and a regex over its message.
  """

  # Iterable factories for the `for` loop tests; each is called with
  # [1, 2, 3] before the fixture runs.
  _ALL_ITERABLES = (
      (tf.constant,),
      (_tf_range,),
      (_dataset,),
      (_dataset_iterator,),
      (_distributed_dataset,),
  )
  # Note: distributed datasets don't allow early stopping.
  _EARLY_STOPPING_ITERABLES = (
      (tf.constant,),
      (_tf_range,),
      (_dataset,),
      (_dataset_iterator,),
  )
  # Note: plain datasets allow variable shape and ignore shape invariants,
  # so they are left out of the shape tests.
  _SHAPE_CHECKED_ITERABLES = (
      (tf.constant,),
      (_tf_range,),
      (_dataset_iterator,),
      (_distributed_dataset,),
  )

  # Expected error-message patterns shared by several tests.
  _DTYPE_CHANGED = (
      "'n' has dtype int32 before the loop, but dtype float32 after")
  _SHAPE_CHANGED = (
      r"'t' has shape \(1,\) before the loop, but shape \(2,\) after")
  _SHAPE_ERASED = (
      r"'t' has shape \(1,\) before the loop, but shape \(None,\) after")
  _INVARIANT_VIOLATED = (
      r"'t' has shape \(None,\) after one iteration, which does not conform")
  _STRUCTURE_CHANGED = "'s' does not have the same nested structure"

  def _check_loop_raises(self, fixture, error_type, pattern, *args):
    """Calls `fixture(*args)` through `tf.function` and expects an error.

    Args:
      fixture: The loop function to wrap in `tf.function` and call.
      error_type: The exception class that must be raised.
      pattern: A string or compiled regex searched for in the message.
      *args: Positional arguments forwarded to the wrapped function.
    """
    with self.assertRaisesRegex(error_type, pattern):
      tf.function(fixture)(*args)

  def test_while_with_variable_py_type(self):
    expected = re.compile(
        r'.*condition of while loop started as non\-Tensor,'
        r' then changed to Tensor.*', re.DOTALL)
    self._check_loop_raises(
        while_with_variable_py_type, NotImplementedError, expected)

  def test_while_with_variable_dtype(self):
    self._check_loop_raises(
        while_with_variable_dtype, TypeError, self._DTYPE_CHANGED)

  def test_while_with_variable_dtype_and_early_stopping(self):
    self._check_loop_raises(
        while_with_variable_dtype_and_early_stopping, TypeError,
        self._DTYPE_CHANGED)

  @parameterized.parameters(*_ALL_ITERABLES)
  def test_for_with_variable_dtype(self, make_iterable):
    loop_input = make_iterable([1, 2, 3])
    self._check_loop_raises(
        for_with_variable_dtype, TypeError, self._DTYPE_CHANGED, loop_input)

  @parameterized.parameters(*_EARLY_STOPPING_ITERABLES)
  def test_for_with_variable_dtype_and_early_stopping(self, make_iterable):
    loop_input = make_iterable([1, 2, 3])
    self._check_loop_raises(
        for_with_variable_dtype_and_early_stopping, TypeError,
        self._DTYPE_CHANGED, loop_input)

  def test_while_with_variable_shape(self):
    self._check_loop_raises(
        while_with_variable_shape, ValueError, self._SHAPE_CHANGED)

  @parameterized.parameters(*_SHAPE_CHECKED_ITERABLES)
  def test_for_with_variable_shape(self, make_iterable):
    loop_input = make_iterable([1, 2, 3])
    self._check_loop_raises(
        for_with_variable_shape, ValueError, self._SHAPE_CHANGED, loop_input)

  def test_while_with_shape_erasure(self):
    self._check_loop_raises(
        while_with_shape_erasure, ValueError, self._SHAPE_ERASED)

  @parameterized.parameters(*_SHAPE_CHECKED_ITERABLES)
  def test_for_with_shape_erasure(self, make_iterable):
    loop_input = make_iterable([1, 2, 3])
    self._check_loop_raises(
        for_with_shape_erasure, ValueError, self._SHAPE_ERASED, loop_input)

  def test_while_with_shape_invariant_violation(self):
    self._check_loop_raises(
        while_with_shape_invariant_violation, ValueError,
        self._INVARIANT_VIOLATED)

  @parameterized.parameters(*_SHAPE_CHECKED_ITERABLES)
  def test_for_with_shape_invariant_violation(self, make_iterable):
    loop_input = make_iterable([1, 2, 3])
    self._check_loop_raises(
        for_with_shape_invariant_violation, ValueError,
        self._INVARIANT_VIOLATED, loop_input)

  def test_while_with_variable_structure(self):
    self._check_loop_raises(
        while_with_variable_structure, TypeError, self._STRUCTURE_CHANGED)

  @parameterized.parameters(*_ALL_ITERABLES)
  def test_for_with_variable_structure(self, make_iterable):
    loop_input = make_iterable([1, 2, 3])
    self._check_loop_raises(
        for_with_variable_structure, TypeError, self._STRUCTURE_CHANGED,
        loop_input)
269
270
if __name__ == '__main__':
  # Run this module's test cases with TensorFlow's test runner.
  tf.test.main()
273