# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test


class GroupByWindowTest(test.TestCase):

  def testSimple(self):
    components = np.random.randint(100, size=(200,)).astype(np.int64)
    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
        .apply(
            grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
                                     4)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      counts = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          result = sess.run(get_next)
          # Each batch must be homogeneous: all even or all odd.
          self.assertTrue(
              all(x % 2 == 0 for x in result) or
              all(x % 2 == 1 for x in result))
          counts.append(result.shape[0])

      self.assertEqual(len(components), sum(counts))
      num_full_batches = len([c for c in counts if c == 4])
      self.assertGreaterEqual(num_full_batches, 23)
      self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))

  def testImmediateOutput(self):
    components = np.array(
        [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
            grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
                                     4)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      # The input is infinite, so this test demonstrates that:
      # 1. We produce output without having to consume the entire input,
      # 2. Different buckets can produce output at different rates, and
      # 3. For deterministic input, the output is deterministic.
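      # With `window_size=4`, the first four 0s fill bucket 0 and the four 1s
      # fill bucket 1; the 2s and the remaining 0s complete their buckets
      # later in the pass, so every pass over the 16-element input emits the
      # same four batches in the same order.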
      for _ in range(3):
        self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
        self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
        self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
        self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))

  def testSmallGroups(self):
    components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components).apply(
            grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
                                     4)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
      self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
      # The small outputs at the end are deterministically produced in key
      # order.
      self.assertAllEqual([0, 0, 0], sess.run(get_next))
      self.assertAllEqual([1], sess.run(get_next))

  def testReduceFuncError(self):
    components = np.random.randint(100, size=(200,)).astype(np.int64)

    def reduce_func(_, xs):
      # Introduce an incorrect padded shape that cannot (currently) be
      # detected at graph construction time.
      return xs.padded_batch(
          4,
          padded_shapes=(tensor_shape.TensorShape([]),
                         constant_op.constant([5], dtype=dtypes.int64) * -1))

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
            grouping.group_by_window(lambda x, _: x % 2, reduce_func,
                                     32)).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)

  def testConsumeWindowDatasetMoreThanOnce(self):
    components = np.random.randint(50, size=(200,)).astype(np.int64)

    def reduce_func(key, window):
      # Apply two different kinds of padding to the input: tight
      # padding, and quantized (to a multiple of 10) padding.
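      # Zipping two padded views of the same `window` requires iterating over
      # the window dataset twice, which is the behavior this test exercises.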
      return dataset_ops.Dataset.zip((
          window.padded_batch(
              4, padded_shapes=tensor_shape.TensorShape([None])),
          window.padded_batch(
              4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
      ))

    iterator = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
        .apply(grouping.group_by_window(
            lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
            reduce_func, 4))
        .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      counts = []
      with self.assertRaises(errors.OutOfRangeError):
        while True:
          tight_result, multiple_of_10_result = sess.run(get_next)
          self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
          self.assertAllEqual(tight_result,
                              multiple_of_10_result[:, :tight_result.shape[1]])
          counts.append(tight_result.shape[0])
      self.assertEqual(len(components), sum(counts))


class GroupByWindowSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

  def _build_dataset(self, components):
    return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
        grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))

  def testCoreGroupByWindow(self):
    components = np.array(
        [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
    self.verify_unused_iterator(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    self.verify_init_before_restore(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    self.verify_multiple_breaks(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    self.verify_reset_restored_iterator(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    self.verify_restore_in_empty_graph(
        lambda: self._build_dataset(components), 12, verify_exhausted=False)
    diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
    self.verify_restore_in_modified_graph(
        lambda: self._build_dataset(components),
        lambda: self._build_dataset(diff_components),
        12,
        verify_exhausted=False)


# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though they should be made to
# use a different batch size per key.
class BucketTest(test.TestCase):

  def _dynamicPad(self, bucket, window, window_size):
    # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
    # generic form of padded_batch that pads every component
    # dynamically and does not rely on static shape information about
    # the arguments.
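    # Pair each padded batch with its scalar bucket key. The padded shapes
    # correspond to the three components produced by `_map_fn` in the tests
    # below: a scalar, a variable-length vector, and a length-3 vector.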
    return dataset_ops.Dataset.zip(
        (dataset_ops.Dataset.from_tensors(bucket),
         window.padded_batch(
             32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
                 [None]), tensor_shape.TensorShape([3])))))

  def testSingleBucket(self):

    def _map_fn(v):
      return (v, array_ops.fill([v], v),
              array_ops.fill([3], string_ops.as_string(v)))

    input_dataset = (
        dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))

    bucketed_dataset = input_dataset.apply(
        grouping.group_by_window(
            lambda x, y, z: 0,
            lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))

    iterator = bucketed_dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)

      which_bucket, bucketed_values = sess.run(get_next)

      self.assertEqual(0, which_bucket)

      expected_scalar_int = np.arange(32, dtype=np.int64)
      expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
      for i in range(32):
        expected_unk_int64[i, :i] = i
      expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T

      self.assertAllEqual(expected_scalar_int, bucketed_values[0])
      self.assertAllEqual(expected_unk_int64, bucketed_values[1])
      self.assertAllEqual(expected_vec3_str, bucketed_values[2])

  def testEvenOddBuckets(self):

    def _map_fn(v):
      return (v, array_ops.fill([v], v),
              array_ops.fill([3], string_ops.as_string(v)))

    input_dataset = (
        dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))

    bucketed_dataset = input_dataset.apply(
        grouping.group_by_window(
            lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
            lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))

    iterator = bucketed_dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)

      # Get two minibatches (one containing even values, one containing odds).
      which_bucket_even, bucketed_values_even = sess.run(get_next)
      which_bucket_odd, bucketed_values_odd = sess.run(get_next)

      # Count the number of components in each bucketed minibatch.
      self.assertEqual(3, len(bucketed_values_even))
      self.assertEqual(3, len(bucketed_values_odd))

      # Ensure that the even batch came from bucket 0 and the odd batch from
      # bucket 1.
      self.assertAllEqual(0, which_bucket_even)
      self.assertAllEqual(1, which_bucket_odd)

      # Test the first bucket output: the evens, starting at 0.
      expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
      expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2 * i] = 2 * i
      expected_vec3_str = np.vstack(
          3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T

      self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
      self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
      self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])

      # Test the second bucket output: the odds, starting at 1.
      expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
      expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
      for i in range(0, 32):
        expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
      expected_vec3_str = np.vstack(
          3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T

      self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
      self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
      self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])

  def testEvenOddBucketsFilterOutAllOdd(self):

    def _map_fn(v):
      return {
          "x": v,
          "y": array_ops.fill([v], v),
          "z": array_ops.fill([3], string_ops.as_string(v))
      }

    def _dynamic_pad_fn(bucket, window, _):
      return dataset_ops.Dataset.zip(
          (dataset_ops.Dataset.from_tensors(bucket),
           window.padded_batch(
               32, {
                   "x": tensor_shape.TensorShape([]),
                   "y": tensor_shape.TensorShape([None]),
                   "z": tensor_shape.TensorShape([3])
               })))

    input_dataset = (
        dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
        .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))

    bucketed_dataset = input_dataset.apply(
        grouping.group_by_window(
            lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
            lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))

    iterator = bucketed_dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)

      # Get two minibatches ([0, 2, ...] and [64, 66, ...]).
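      # Filtering range(128) down to the evens leaves 64 elements, i.e.
      # exactly two full windows of 32.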
      which_bucket0, bucketed_values_even0 = sess.run(get_next)
      which_bucket1, bucketed_values_even1 = sess.run(get_next)

      # Ensure that bucket 1 was completely filtered out.
      self.assertAllEqual(0, which_bucket0)
      self.assertAllEqual(0, which_bucket1)
      self.assertAllEqual(
          np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
      self.assertAllEqual(
          np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])

  def testDynamicWindowSize(self):
    components = np.arange(100).astype(np.int64)

    # Key fn: even/odd
    # Reduce fn: batches of up to 20 (so each window yields a single batch)
    # Window size fn: even=5, odd=10

    def window_size_func(key):
      window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
      return window_sizes[key]

    dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
        grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
                                 None, window_size_func))
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      with self.assertRaises(errors.OutOfRangeError):
        batches = 0
        while True:
          result = sess.run(get_next)
          is_even = all(x % 2 == 0 for x in result)
          is_odd = all(x % 2 == 1 for x in result)
          self.assertTrue(is_even or is_odd)
          expected_batch_size = 5 if is_even else 10
          self.assertEqual(expected_batch_size, result.shape[0])
          batches += 1

      self.assertEqual(batches, 15)


if __name__ == "__main__":
  test.main()