1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.data.Dataset.range()`.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from absl.testing import parameterized 21import numpy as np 22 23from tensorflow.python.data.kernel_tests import test_base 24from tensorflow.python.data.ops import dataset_ops 25from tensorflow.python.framework import combinations 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import errors 28from tensorflow.python.platform import test 29 30 31class RangeTest(test_base.DatasetTestBase, parameterized.TestCase): 32 33 @combinations.generate( 34 combinations.times( 35 test_base.default_test_combinations(), 36 combinations.combine(output_type=[ 37 dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 38 ]))) 39 def testStop(self, output_type): 40 stop = 5 41 dataset = dataset_ops.Dataset.range(stop, output_type=output_type) 42 expected_output = np.arange(stop, dtype=output_type.as_numpy_dtype) 43 self.assertDatasetProduces(dataset, expected_output=expected_output) 44 self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset)) 45 46 @combinations.generate( 47 combinations.times( 48 test_base.default_test_combinations(), 49 combinations.combine(output_type=[ 50 
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 51 ]))) 52 def testStartStop(self, output_type): 53 start, stop = 2, 5 54 dataset = dataset_ops.Dataset.range(start, stop, output_type=output_type) 55 expected_output = np.arange(start, stop, dtype=output_type.as_numpy_dtype) 56 self.assertDatasetProduces(dataset, expected_output=expected_output) 57 self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset)) 58 59 @combinations.generate( 60 combinations.times( 61 test_base.default_test_combinations(), 62 combinations.combine(output_type=[ 63 dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 64 ]))) 65 def testStartStopStep(self, output_type): 66 start, stop, step = 2, 10, 2 67 dataset = dataset_ops.Dataset.range( 68 start, stop, step, output_type=output_type) 69 expected_output = np.arange( 70 start, stop, step, dtype=output_type.as_numpy_dtype) 71 self.assertDatasetProduces(dataset, expected_output=expected_output) 72 self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset)) 73 74 @combinations.generate( 75 combinations.times( 76 test_base.default_test_combinations(), 77 combinations.combine(output_type=[ 78 dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 79 ]))) 80 def testZeroStep(self, output_type): 81 start, stop, step = 2, 10, 0 82 with self.assertRaises(errors.InvalidArgumentError): 83 dataset = dataset_ops.Dataset.range( 84 start, stop, step, output_type=output_type) 85 self.evaluate(dataset._variant_tensor) 86 87 @combinations.generate( 88 combinations.times( 89 test_base.default_test_combinations(), 90 combinations.combine(output_type=[ 91 dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 92 ]))) 93 def testNegativeStep(self, output_type): 94 start, stop, step = 2, 10, -1 95 dataset = dataset_ops.Dataset.range( 96 start, stop, step, output_type=output_type) 97 expected_output = np.arange( 98 start, stop, step, dtype=output_type.as_numpy_dtype) 99 self.assertDatasetProduces(dataset, 
expected_output=expected_output) 100 self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset)) 101 102 @combinations.generate( 103 combinations.times( 104 test_base.default_test_combinations(), 105 combinations.combine(output_type=[ 106 dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 107 ]))) 108 def testStopLessThanStart(self, output_type): 109 start, stop = 10, 2 110 dataset = dataset_ops.Dataset.range(start, stop, output_type=output_type) 111 expected_output = np.arange(start, stop, dtype=output_type.as_numpy_dtype) 112 self.assertDatasetProduces(dataset, expected_output=expected_output) 113 self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset)) 114 115 @combinations.generate( 116 combinations.times( 117 test_base.default_test_combinations(), 118 combinations.combine(output_type=[ 119 dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 120 ]))) 121 def testStopLessThanStartWithPositiveStep(self, output_type): 122 start, stop, step = 10, 2, 2 123 dataset = dataset_ops.Dataset.range( 124 start, stop, step, output_type=output_type) 125 expected_output = np.arange( 126 start, stop, step, dtype=output_type.as_numpy_dtype) 127 self.assertDatasetProduces(dataset, expected_output=expected_output) 128 self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset)) 129 130 @combinations.generate( 131 combinations.times( 132 test_base.default_test_combinations(), 133 combinations.combine(output_type=[ 134 dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64 135 ]))) 136 def testStopLessThanStartWithNegativeStep(self, output_type): 137 start, stop, step = 10, 2, -1 138 dataset = dataset_ops.Dataset.range( 139 start, stop, step, output_type=output_type) 140 expected_output = np.arange( 141 start, stop, step, dtype=output_type.as_numpy_dtype) 142 self.assertDatasetProduces(dataset, expected_output=expected_output) 143 self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset)) 144 145 
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
  test.main()