# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental `scan_ops.scan` dataset transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np

from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class ScanDatasetTest(test.TestCase):
  """Functional tests for `scan_ops.scan`."""

  def _count(self, start, step):
    """Returns an unbounded dataset yielding start, start + step, ...

    The scan function emits the *current* state and advances it by `step`,
    so the first element produced is `start` itself.
    """
    return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
        scan_ops.scan(start, lambda state, _: (state + step, state)))

  def testCount(self):
    """A scan-based counter yields `take` terms of an arithmetic sequence."""
    start = array_ops.placeholder(dtypes.int32, shape=[])
    step = array_ops.placeholder(dtypes.int32, shape=[])
    take = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = self._count(start, step).take(take).make_initializable_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:

      for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
                                            (10, 2, 10), (10, -1, 10),
                                            (10, -2, 10)]:
        sess.run(iterator.initializer,
                 feed_dict={start: start_val, step: step_val, take: take_val})
        # The i-th element must equal start_val + i * step_val.
        for expected in itertools.islice(
            itertools.count(start_val, step_val), take_val):
          self.assertEqual(expected, sess.run(next_element))
        # `take(take_val)` is what bounds the otherwise infinite dataset.
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)

  def testFibonacci(self):
    """A scan whose state is a two-element list produces Fibonacci numbers."""
    iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
        scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))
    ).make_one_shot_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:
      for expected in [1, 1, 2, 3, 5, 8]:
        self.assertEqual(expected, sess.run(next_element))

  def testChangingStateShape(self):
    """The state may change shape between steps; shapes are relaxed."""
    # Test the fixed-point shape invariant calculations: start with
    # initial values with known shapes, and use a scan function that
    # changes the size of the state on each element.
    def _scan_fn(state, input_value):
      # Statically known rank, but dynamic length.
      ret_longer_vector = array_ops.concat([state[0], state[0]], 0)
      # Statically unknown rank.
      ret_larger_rank = array_ops.expand_dims(state[1], 0)
      return (ret_longer_vector, ret_larger_rank), (state, input_value)

    dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
        scan_ops.scan(([0], 1), _scan_fn))
    # The vector component keeps its rank but loses its static length; the
    # expand_dims component loses its static rank entirely.
    self.assertEqual([None], dataset.output_shapes[0][0].as_list())
    self.assertIsNone(dataset.output_shapes[0][1].ndims)
    self.assertEqual([], dataset.output_shapes[1].as_list())

    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:
      for i in range(5):
        # Element i carries the state *before* step i: the vector has been
        # doubled i times and the scalar has gained i leading dimensions.
        (longer_vector_val, larger_rank_val), _ = sess.run(next_element)
        self.assertAllEqual([0] * (2**i), longer_vector_val)
        self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testIncorrectStateType(self):
    """A scan_fn returning a state of a different dtype raises TypeError."""

    def _scan_fn(state, _):
      # New state is int64 while the initial state below is int32.
      return constant_op.constant(1, dtype=dtypes.int64), state

    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegexp(
        TypeError,
        "The element types for the new state must match the initial state."):
      dataset.apply(
          scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))

  def testIncorrectReturnType(self):
    """A scan_fn that does not return a (state, output) pair raises TypeError."""

    def _scan_fn(unused_state, unused_input_value):
      # Returns a single tensor instead of a (new_state, output) pair.
      return constant_op.constant(1, dtype=dtypes.int64)

    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegexp(
        TypeError,
        "The scan function must return a pair comprising the new state and the "
        "output value."):
      dataset.apply(
          scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))


# NOTE(review): "Serialzation" is misspelled; the name is kept unchanged in
# case tooling or other files refer to this test class by name.
class ScanDatasetSerialzationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):
  """Checkpoint save/restore tests for a scan dataset."""

  def _build_dataset(self, num_elements):
    """Returns a finite Fibonacci scan dataset of `num_elements` elements."""
    return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
        scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))

  def testScanCore(self):
    num_output = 5
    # NOTE(review): presumably the second factory supplies a differently
    # sized dataset for the base class's restore checks -- confirm against
    # DatasetSerializationTestBase.run_core_tests.
    self.run_core_tests(lambda: self._build_dataset(num_output),
                        lambda: self._build_dataset(2), num_output)


if __name__ == "__main__":
  test.main()