# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental `scan()` input pipeline transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

import numpy as np

from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test

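# The transformation under test, `scan_ops.scan(initial_state, scan_func)`,
# threads explicit state through a dataset: `scan_func` maps
# (old_state, input_element) -> (new_state, output_element). As an
# illustrative sketch (not part of the tests below):
#
#   dataset_ops.Dataset.range(5).apply(
#       scan_ops.scan(0, lambda state, x: (state + x, state)))
#
# yields the exclusive running sums 0, 0, 1, 3, 6.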

class ScanDatasetTest(test.TestCase):

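  # Builds an infinite dataset that counts from `start` in increments of
  # `step`: the scan state always holds the next value to emit.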
  def _count(self, start, step):
    return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
        scan_ops.scan(start, lambda state, _: (state + step, state)))

  def testCount(self):
    start = array_ops.placeholder(dtypes.int32, shape=[])
    step = array_ops.placeholder(dtypes.int32, shape=[])
    take = array_ops.placeholder(dtypes.int64, shape=[])
    iterator = self._count(start, step).take(take).make_initializable_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:
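      # Each (start, step, take) combination re-initializes the same iterator
      # through the placeholders; expected values come from
      # itertools.count(start_val, step_val).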
      for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
                                            (10, 2, 10), (10, -1, 10),
                                            (10, -2, 10)]:
        sess.run(iterator.initializer,
                 feed_dict={start: start_val, step: step_val, take: take_val})
        for expected, _ in zip(
            itertools.count(start_val, step_val), range(take_val)):
          self.assertEqual(expected, sess.run(next_element))
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element)

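  # The scan state is the pair [fib(n), fib(n + 1)] and each step emits
  # fib(n + 1), so the iterator should yield 1, 1, 2, 3, 5, 8, ...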
  def testFibonacci(self):
    iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
        scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))
    ).make_one_shot_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:
      self.assertEqual(1, sess.run(next_element))
      self.assertEqual(1, sess.run(next_element))
      self.assertEqual(2, sess.run(next_element))
      self.assertEqual(3, sess.run(next_element))
      self.assertEqual(5, sess.run(next_element))
      self.assertEqual(8, sess.run(next_element))

  def testChangingStateShape(self):
    # Test the fixed-point shape invariant calculations: start with
    # initial values with known shapes, and use a scan function that
    # changes the size of the state on each element.
    def _scan_fn(state, input_value):
      # Statically known rank, but dynamic length.
      ret_longer_vector = array_ops.concat([state[0], state[0]], 0)
      # Statically unknown rank.
      ret_larger_rank = array_ops.expand_dims(state[1], 0)
      return (ret_longer_vector, ret_larger_rank), (state, input_value)

    dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
        scan_ops.scan(([0], 1), _scan_fn))
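    # Shape inference should relax the state shapes to their fixed point: a
    # vector of unknown length and a tensor of unknown rank, while the
    # pass-through input value remains a scalar.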
    self.assertEqual([None], dataset.output_shapes[0][0].as_list())
    self.assertIs(None, dataset.output_shapes[0][1].ndims)
    self.assertEqual([], dataset.output_shapes[1].as_list())

    iterator = dataset.make_one_shot_iterator()
    next_element = iterator.get_next()

    with self.test_session() as sess:
      for i in range(5):
        (longer_vector_val, larger_rank_val), _ = sess.run(next_element)
        self.assertAllEqual([0] * (2**i), longer_vector_val)
        self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_element)

  def testIncorrectStateType(self):

    def _scan_fn(state, _):
      return constant_op.constant(1, dtype=dtypes.int64), state

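    # The scan function produces an int64 state, but the initial state below is
    # int32, so applying the transformation should raise a TypeError.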
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegexp(
        TypeError,
        "The element types for the new state must match the initial state."):
      dataset.apply(
          scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))

  def testIncorrectReturnType(self):

    def _scan_fn(unused_state, unused_input_value):
      return constant_op.constant(1, dtype=dtypes.int64)

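    # The scan function returns a single tensor rather than a
    # (new_state, output_value) pair, so applying it should raise a TypeError.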
    dataset = dataset_ops.Dataset.range(10)
    with self.assertRaisesRegexp(
        TypeError,
        "The scan function must return a pair comprising the new state and the "
        "output value."):
      dataset.apply(
          scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))


class ScanDatasetSerializationTest(
    dataset_serialization_test_base.DatasetSerializationTestBase):

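  # Reuses the Fibonacci dataset from above so that saving and restoring the
  # iterator also exercises checkpointing of the accumulated scan state.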
  def _build_dataset(self, num_elements):
    return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
        scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))

  def testScanCore(self):
    num_output = 5
    self.run_core_tests(lambda: self._build_dataset(num_output),
                        lambda: self._build_dataset(2), num_output)


if __name__ == "__main__":
  test.main()