1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.data.Dataset.concatenate().""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from absl.testing import parameterized 21import numpy as np 22 23from tensorflow.python.data.kernel_tests import test_base 24from tensorflow.python.data.ops import dataset_ops 25from tensorflow.python.data.util import nest 26from tensorflow.python.framework import combinations 27from tensorflow.python.framework import errors 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.platform import test 30 31 32class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase): 33 34 @combinations.generate(test_base.default_test_combinations()) 35 def testConcatenateDataset(self): 36 input_components = ( 37 np.tile(np.array([[1], [2], [3], [4]]), 20), 38 np.tile(np.array([[12], [13], [14], [15]]), 15), 39 np.array([37.0, 38.0, 39.0, 40.0])) 40 to_concatenate_components = ( 41 np.tile(np.array([[1], [2], [3], [4], [5]]), 20), 42 np.tile(np.array([[12], [13], [14], [15], [16]]), 15), 43 np.array([37.0, 38.0, 39.0, 40.0, 41.0])) 44 45 input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) 46 dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( 47 to_concatenate_components) 48 concatenated = input_dataset.concatenate(dataset_to_concatenate) 49 self.assertEqual( 50 dataset_ops.get_legacy_output_shapes(concatenated), 51 (tensor_shape.TensorShape([20]), tensor_shape.TensorShape([15]), 52 tensor_shape.TensorShape([]))) 53 54 get_next = self.getNext(concatenated) 55 56 for i in range(9): 57 result = self.evaluate(get_next()) 58 if i < 4: 59 for component, result_component in zip(input_components, result): 60 self.assertAllEqual(component[i], result_component) 61 else: 62 for component, result_component in zip(to_concatenate_components, 63 result): 64 self.assertAllEqual(component[i - 4], result_component) 65 with self.assertRaises(errors.OutOfRangeError): 66 self.evaluate(get_next()) 67 68 @combinations.generate(test_base.default_test_combinations()) 69 def testConcatenateDatasetDifferentShape(self): 70 input_components = ( 71 np.tile(np.array([[1], [2], [3], [4]]), 20), 72 np.tile(np.array([[12], [13], [14], [15]]), 4)) 73 to_concatenate_components = ( 74 np.tile(np.array([[1], [2], [3], [4], [5]]), 20), 75 np.tile(np.array([[12], [13], [14], [15], [16]]), 15)) 76 77 input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) 78 dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( 79 to_concatenate_components) 80 concatenated = input_dataset.concatenate(dataset_to_concatenate) 81 self.assertEqual( 82 [ts.as_list() 83 for ts in nest.flatten( 84 dataset_ops.get_legacy_output_shapes(concatenated))], 85 [[20], [None]]) 86 get_next = self.getNext(concatenated) 87 for i in range(9): 88 result = self.evaluate(get_next()) 89 if i < 4: 90 for component, result_component in zip(input_components, result): 91 self.assertAllEqual(component[i], result_component) 92 else: 93 for component, result_component in zip(to_concatenate_components, 94 result): 95 self.assertAllEqual(component[i - 4], result_component) 96 with self.assertRaises(errors.OutOfRangeError): 97 self.evaluate(get_next()) 98 99 @combinations.generate(test_base.default_test_combinations()) 100 def testConcatenateDatasetDifferentStructure(self): 101 input_components = ( 102 np.tile(np.array([[1], [2], [3], [4]]), 5), 103 np.tile(np.array([[12], [13], [14], [15]]), 4)) 104 to_concatenate_components = ( 105 np.tile(np.array([[1], [2], [3], [4], [5]]), 20), 106 np.tile(np.array([[12], [13], [14], [15], [16]]), 15), 107 np.array([37.0, 38.0, 39.0, 40.0, 41.0])) 108 109 input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) 110 dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( 111 to_concatenate_components) 112 113 with self.assertRaisesRegex(TypeError, "have different types"): 114 input_dataset.concatenate(dataset_to_concatenate) 115 116 @combinations.generate(test_base.default_test_combinations()) 117 def testConcatenateDatasetDifferentKeys(self): 118 input_components = { 119 "foo": np.array([[1], [2], [3], [4]]), 120 "bar": np.array([[12], [13], [14], [15]]) 121 } 122 to_concatenate_components = { 123 "foo": np.array([[1], [2], [3], [4]]), 124 "baz": np.array([[5], [6], [7], [8]]) 125 } 126 127 input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) 128 dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( 129 to_concatenate_components) 130 131 with self.assertRaisesRegex(TypeError, "have different types"): 132 input_dataset.concatenate(dataset_to_concatenate) 133 134 @combinations.generate(test_base.default_test_combinations()) 135 def testConcatenateDatasetDifferentType(self): 136 input_components = ( 137 np.tile(np.array([[1], [2], [3], [4]]), 5), 138 np.tile(np.array([[12], [13], [14], [15]]), 4)) 139 to_concatenate_components = ( 140 np.tile(np.array([[1.0], [2.0], [3.0], [4.0]]), 5), 141 np.tile(np.array([[12], [13], [14], [15]]), 15)) 142 143 input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) 144 dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( 145 to_concatenate_components) 146 147 with self.assertRaisesRegex(TypeError, "have different types"): 148 input_dataset.concatenate(dataset_to_concatenate) 149 150 151if __name__ == "__main__": 152 test.main() 153