# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loops with type changing variables.

Each module-level function below is a loop whose variables change their
Python type, dtype, shape, or nested structure across iterations. The
ReferenceTest cases run each one under tf.function and assert the specific
error that is raised.

The fixture functions are inputs to AutoGraph, which converts them from
their source. Their local variable names ('n', 't', 's') appear in the
expected error messages, so keep names and loop shapes unchanged.
"""

import re

from absl.testing import parameterized
import tensorflow as tf

from tensorflow.python.autograph.tests import reference_test_base


def while_with_variable_py_type():
  """Loop whose condition starts as a Python bool and becomes a Tensor."""
  n = tf.constant(0, dtype=tf.int32)
  c = True
  while c:
    # Rebinding the condition to a Tensor after it started as a Python bool.
    c = tf.constant(True)
  return n


def while_with_variable_dtype():
  """Tensor-conditioned while loop that changes `n` from int32 to float32."""
  n = tf.constant(0, dtype=tf.int32)
  while tf.constant(True):
    n = tf.constant(0, dtype=tf.float32)
  return n


def while_with_variable_dtype_and_early_stopping():
  """Like while_with_variable_dtype, but the body exits with `break`."""
  n = tf.constant(0, dtype=tf.int32)
  while tf.constant(True):
    n = tf.constant(0, dtype=tf.float32)
    break
  return n


def for_with_variable_dtype(l):
  """For loop over `l` that changes `n` from int32 to float32.

  Args:
    l: the iterable to loop over; the tests pass the output of one of the
      iterable helpers (tf.constant, _tf_range, _dataset, ...).
  """
  n = tf.constant(0, dtype=tf.int32)
  for _ in l:
    n = tf.constant(0, dtype=tf.float32)
  return n


def for_with_variable_dtype_and_early_stopping(l):
  """Like for_with_variable_dtype, but the body exits with `break`."""
  n = tf.constant(0, dtype=tf.int32)
  for _ in l:
    n = tf.constant(0, dtype=tf.float32)
    break
  return n


def while_with_variable_shape():
  """While loop that changes `t` from shape (1,) to shape (2,)."""
  t = tf.constant([1])
  while tf.constant(True):
    t = tf.constant([1, 1])
  return t


def for_with_variable_shape(l):
  """For loop over `l` that changes `t` from shape (1,) to shape (2,)."""
  t = tf.constant([1])
  for _ in l:
    t = tf.constant([1, 1])
  return t


def while_with_shape_erasure():
  """While loop that replaces `t` with a tensor of statically unknown length.

  The new length comes from a tf.random.uniform tensor, so the static shape
  of `t` goes from (1,) to (None,).
  """
  t = tf.constant([1])
  while tf.constant(True):
    t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32))
  return t


def for_with_shape_erasure(l):
  """For-loop counterpart of while_with_shape_erasure, iterating over `l`."""
  t = tf.constant([1])
  for _ in l:
    t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32))
  return t


def while_with_shape_invariant_violation():
  """While loop that declares shape [1] for `t`, then breaks that invariant."""
  t = tf.constant([1])
  while tf.constant(True):
    tf.autograph.experimental.set_loop_options(
        shape_invariants=((t, tf.TensorShape([1])),))
    t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32))
  return t


def for_with_shape_invariant_violation(l):
  """For loop over `l` that declares shape [1] for `t`, then breaks it."""
  t = tf.constant([1])
  for _ in l:
    tf.autograph.experimental.set_loop_options(
        shape_invariants=((t, tf.TensorShape([1])),))
    t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32))
  return t


def while_with_variable_structure():
  """While loop that replaces dict `s` with a bare tensor (new structure)."""
  s = {'a': tf.constant(0)}
  while tf.constant(True):
    s = tf.constant(7.0)
  return s


def for_with_variable_structure(l):
  """For loop over `l` that grows list `s` by one element per iteration."""
  s = [tf.constant(0)]
  for _ in l:
    s = s + [tf.constant(0)]
  return s


def _tf_range(l):
  """Returns tf.range(len(l)): a tensor with one element per item of `l`."""
  return tf.range(len(l))


def _dataset(l):
  """Returns a tf.data.Dataset whose elements are the slices of `l`."""
  return tf.data.Dataset.from_tensor_slices(l)


def _dataset_iterator(l):
  """Returns an iterator over a tf.data.Dataset built from `l`."""
  return iter(tf.data.Dataset.from_tensor_slices(l))


def _distributed_dataset(l):
  """Returns a MirroredStrategy-distributed dataset of two copies of `l`."""
  ds = tf.data.Dataset.from_tensor_slices([l] * 2)
  return tf.distribute.MirroredStrategy().experimental_distribute_dataset(ds)


class ReferenceTest(reference_test_base.TestCase, parameterized.TestCase):
  """Checks that tf.function rejects each illegal loop with the right error.

  The parameterized cases build the loop iterable with one of: tf.constant,
  _tf_range, _dataset, _dataset_iterator, _distributed_dataset.
  """

  def test_while_with_variable_py_type(self):
    # DOTALL lets the pattern match even if the message spans several lines.
    with self.assertRaisesRegex(
        NotImplementedError,
        re.compile(
            r'.*condition of while loop started as non\-Tensor,'
            r' then changed to Tensor.*', re.DOTALL)):
      tf.function(while_with_variable_py_type)()

  def test_while_with_variable_dtype(self):
    with self.assertRaisesRegex(
        TypeError,
        "'n' has dtype int32 before the loop, but dtype float32 after"):
      tf.function(while_with_variable_dtype)()

  def test_while_with_variable_dtype_and_early_stopping(self):
    with self.assertRaisesRegex(
        TypeError,
        "'n' has dtype int32 before the loop, but dtype float32 after"):
      tf.function(while_with_variable_dtype_and_early_stopping)()

  @parameterized.parameters(
      (tf.constant,),
      (_tf_range,),
      (_dataset,),
      (_dataset_iterator,),
      (_distributed_dataset,),
  )
  def test_for_with_variable_dtype(self, type_):
    l = type_([1, 2, 3])
    with self.assertRaisesRegex(
        TypeError,
        "'n' has dtype int32 before the loop, but dtype float32 after"):
      tf.function(for_with_variable_dtype)(l)

  # Note: _distributed_dataset is omitted because distributed datasets don't
  # allow early stopping (`break`).
  @parameterized.parameters(
      (tf.constant,),
      (_tf_range,),
      (_dataset,),
      (_dataset_iterator,),
  )
  def test_for_with_variable_dtype_and_early_stopping(self, type_):
    l = type_([1, 2, 3])
    with self.assertRaisesRegex(
        TypeError,
        "'n' has dtype int32 before the loop, but dtype float32 after"):
      tf.function(for_with_variable_dtype_and_early_stopping)(l)

  def test_while_with_variable_shape(self):
    with self.assertRaisesRegex(
        ValueError,
        r"'t' has shape \(1,\) before the loop, but shape \(2,\) after"):
      tf.function(while_with_variable_shape)()

  # Note: _dataset is omitted because loops over a tf.data.Dataset do allow
  # variable shape.
  @parameterized.parameters(
      (tf.constant,),
      (_tf_range,),
      (_dataset_iterator,),
      (_distributed_dataset,),
  )
  def test_for_with_variable_shape(self, type_):
    l = type_([1, 2, 3])
    with self.assertRaisesRegex(
        ValueError,
        r"'t' has shape \(1,\) before the loop, but shape \(2,\) after"):
      tf.function(for_with_variable_shape)(l)

  def test_while_with_shape_erasure(self):
    with self.assertRaisesRegex(
        ValueError,
        r"'t' has shape \(1,\) before the loop, but shape \(None,\) after"):
      tf.function(while_with_shape_erasure)()

  # Note: _dataset is omitted because loops over a tf.data.Dataset do allow
  # variable shape.
  @parameterized.parameters(
      (tf.constant,),
      (_tf_range,),
      (_dataset_iterator,),
      (_distributed_dataset,),
  )
  def test_for_with_shape_erasure(self, type_):
    l = type_([1, 2, 3])
    with self.assertRaisesRegex(
        ValueError,
        r"'t' has shape \(1,\) before the loop, but shape \(None,\) after"):
      tf.function(for_with_shape_erasure)(l)

  def test_while_with_shape_invariant_violation(self):
    with self.assertRaisesRegex(
        ValueError,
        r"'t' has shape \(None,\) after one iteration, which does not conform"):
      tf.function(while_with_shape_invariant_violation)()

  # Note: _dataset is omitted because loops over a tf.data.Dataset ignore
  # shape invariants.
  @parameterized.parameters(
      (tf.constant,),
      (_tf_range,),
      (_dataset_iterator,),
      (_distributed_dataset,),
  )
  def test_for_with_shape_invariant_violation(self, type_):
    l = type_([1, 2, 3])
    with self.assertRaisesRegex(
        ValueError,
        r"'t' has shape \(None,\) after one iteration, which does not conform"):
      tf.function(for_with_shape_invariant_violation)(l)

  def test_while_with_variable_structure(self):
    with self.assertRaisesRegex(
        TypeError,
        "'s' does not have the same nested structure"):
      tf.function(while_with_variable_structure)()

  @parameterized.parameters(
      (tf.constant,),
      (_tf_range,),
      (_dataset,),
      (_dataset_iterator,),
      (_distributed_dataset,),
  )
  def test_for_with_variable_structure(self, type_):
    l = type_([1, 2, 3])
    with self.assertRaisesRegex(
        TypeError,
        "'s' does not have the same nested structure"):
      tf.function(for_with_variable_structure)(l)


if __name__ == '__main__':
  tf.test.main()