• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Dataset.range()`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.python.data.kernel_tests import test_base
24from tensorflow.python.data.ops import dataset_ops
25from tensorflow.python.framework import combinations
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.platform import test
29
30
class RangeTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Checks `Dataset.range()` against `np.arange()` for several output dtypes.

  Every test is parameterized over `_OUTPUT_TYPES` crossed with
  `test_base.default_test_combinations()`.
  """

  # Output dtypes each test below is parameterized over.
  # NOTE(review): keep this a list; `combinations.combine()` appears to treat a
  # non-list value (e.g. a tuple) as a single option -- confirm before changing.
  _OUTPUT_TYPES = [dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]

  def _assert_range_matches_numpy(self, output_type, *range_args):
    """Asserts `Dataset.range(*range_args)` matches `np.arange(*range_args)`.

    Checks both the produced elements and the dataset's legacy output type.

    Args:
      output_type: The `DType` requested from `Dataset.range()`; its
        `as_numpy_dtype` is used as the dtype of the expected numpy output.
      *range_args: Positional start/stop/step arguments, passed unchanged to
        both `Dataset.range()` and `np.arange()` so the two stay in lockstep.
    """
    dataset = dataset_ops.Dataset.range(*range_args, output_type=output_type)
    expected_output = np.arange(*range_args, dtype=output_type.as_numpy_dtype)
    self.assertDatasetProduces(dataset, expected_output=expected_output)
    self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=_OUTPUT_TYPES)))
  def testStop(self, output_type):
    """A single argument is treated as `stop`."""
    stop = 5
    self._assert_range_matches_numpy(output_type, stop)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=_OUTPUT_TYPES)))
  def testStartStop(self, output_type):
    """Two arguments are treated as `start, stop`."""
    start, stop = 2, 5
    self._assert_range_matches_numpy(output_type, start, stop)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=_OUTPUT_TYPES)))
  def testStartStopStep(self, output_type):
    """Three arguments are treated as `start, stop, step`."""
    start, stop, step = 2, 10, 2
    self._assert_range_matches_numpy(output_type, start, stop, step)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=_OUTPUT_TYPES)))
  def testZeroStep(self, output_type):
    """A step of zero raises `InvalidArgumentError`."""
    start, stop, step = 2, 10, 0
    # Both construction and evaluation sit inside the assertRaises block;
    # presumably the error can surface at either point depending on the
    # execution mode chosen by the test combination -- TODO confirm.
    with self.assertRaises(errors.InvalidArgumentError):
      dataset = dataset_ops.Dataset.range(
          start, stop, step, output_type=output_type)
      self.evaluate(dataset._variant_tensor)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=_OUTPUT_TYPES)))
  def testNegativeStep(self, output_type):
    """A negative step with `start < stop` matches numpy (an empty range)."""
    start, stop, step = 2, 10, -1
    self._assert_range_matches_numpy(output_type, start, stop, step)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=_OUTPUT_TYPES)))
  def testStopLessThanStart(self, output_type):
    """`stop < start` with the default step matches numpy (an empty range)."""
    start, stop = 10, 2
    self._assert_range_matches_numpy(output_type, start, stop)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=_OUTPUT_TYPES)))
  def testStopLessThanStartWithPositiveStep(self, output_type):
    """`stop < start` with a positive step matches numpy (an empty range)."""
    start, stop, step = 10, 2, 2
    self._assert_range_matches_numpy(output_type, start, stop, step)

  @combinations.generate(
      combinations.times(
          test_base.default_test_combinations(),
          combinations.combine(output_type=_OUTPUT_TYPES)))
  def testStopLessThanStartWithNegativeStep(self, output_type):
    """`stop < start` with a negative step counts down, matching numpy."""
    start, stop, step = 10, 2, -1
    self._assert_range_matches_numpy(output_type, start, stop, step)
144
145
if __name__ == "__main__":
  # Run this module's test cases through TensorFlow's test entry point
  # (`tensorflow.python.platform.test.main`).
  test.main()
148