• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for `tf.data.Dataset.concatenate()`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.python.data.kernel_tests import test_base
24from tensorflow.python.data.ops import dataset_ops
25from tensorflow.python.data.util import nest
26from tensorflow.python.framework import combinations
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.platform import test
30
31
class ConcatenateTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for concatenating one `tf.data` dataset onto another."""

  def _assertConcatenatedOutput(self, concatenated, input_components,
                                to_concatenate_components):
    """Checks that `concatenated` yields both inputs' slices, in order.

    The element counts are derived from the fixtures, so the check stays in
    sync if a fixture's length changes.

    Args:
      concatenated: Dataset produced by `input.concatenate(other)`.
      input_components: Tuple of arrays that were sliced (along axis 0) to
        build the first dataset.
      to_concatenate_components: Tuple of arrays that were sliced (along
        axis 0) to build the second dataset.
    """
    num_input = len(input_components[0])
    num_to_concatenate = len(to_concatenate_components[0])
    get_next = self.getNext(concatenated)
    for i in range(num_input + num_to_concatenate):
      result = self.evaluate(get_next())
      if i < num_input:
        expected_components, index = input_components, i
      else:
        expected_components, index = (to_concatenate_components,
                                      i - num_input)
      # `zip` would silently drop a missing component, so check the count
      # explicitly before comparing values.
      self.assertEqual(len(expected_components), len(result))
      for component, result_component in zip(expected_components, result):
        self.assertAllEqual(component[index], result_component)
    # The concatenated dataset must be exhausted after both inputs are.
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenateDataset(self):
    """Identical component shapes are preserved; elements come out in order."""
    input_components = (
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 15),
        np.array([37.0, 38.0, 39.0, 40.0]))
    to_concatenate_components = (
        np.tile(np.array([[1], [2], [3], [4], [5]]), 20),
        np.tile(np.array([[12], [13], [14], [15], [16]]), 15),
        np.array([37.0, 38.0, 39.0, 40.0, 41.0]))

    input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
    dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
        to_concatenate_components)
    concatenated = input_dataset.concatenate(dataset_to_concatenate)
    self.assertEqual(
        dataset_ops.get_legacy_output_shapes(concatenated),
        (tensor_shape.TensorShape([20]), tensor_shape.TensorShape([15]),
         tensor_shape.TensorShape([])))

    self._assertConcatenatedOutput(concatenated, input_components,
                                   to_concatenate_components)

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenateDatasetDifferentShape(self):
    """A component whose static shapes differ gets an unknown dimension."""
    input_components = (
        np.tile(np.array([[1], [2], [3], [4]]), 20),
        np.tile(np.array([[12], [13], [14], [15]]), 4))
    to_concatenate_components = (
        np.tile(np.array([[1], [2], [3], [4], [5]]), 20),
        np.tile(np.array([[12], [13], [14], [15], [16]]), 15))

    input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
    dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
        to_concatenate_components)
    concatenated = input_dataset.concatenate(dataset_to_concatenate)
    # The second component is [4] in one input and [15] in the other, so the
    # concatenated dataset can only report [None] for it.
    self.assertEqual(
        [ts.as_list()
         for ts in nest.flatten(
             dataset_ops.get_legacy_output_shapes(concatenated))],
        [[20], [None]])

    self._assertConcatenatedOutput(concatenated, input_components,
                                   to_concatenate_components)

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenateDatasetDifferentStructure(self):
    """Inputs with a different number of components raise `TypeError`."""
    input_components = (
        np.tile(np.array([[1], [2], [3], [4]]), 5),
        np.tile(np.array([[12], [13], [14], [15]]), 4))
    to_concatenate_components = (
        np.tile(np.array([[1], [2], [3], [4], [5]]), 20),
        np.tile(np.array([[12], [13], [14], [15], [16]]), 15),
        np.array([37.0, 38.0, 39.0, 40.0, 41.0]))

    input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
    dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
        to_concatenate_components)

    with self.assertRaisesRegex(TypeError, "have different types"):
      input_dataset.concatenate(dataset_to_concatenate)

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenateDatasetDifferentKeys(self):
    """Dict-structured inputs with different keys raise `TypeError`."""
    input_components = {
        "foo": np.array([[1], [2], [3], [4]]),
        "bar": np.array([[12], [13], [14], [15]])
    }
    to_concatenate_components = {
        "foo": np.array([[1], [2], [3], [4]]),
        "baz": np.array([[5], [6], [7], [8]])
    }

    input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
    dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
        to_concatenate_components)

    with self.assertRaisesRegex(TypeError, "have different types"):
      input_dataset.concatenate(dataset_to_concatenate)

  @combinations.generate(test_base.default_test_combinations())
  def testConcatenateDatasetDifferentType(self):
    """An integer vs. floating-point component dtype raises `TypeError`."""
    input_components = (
        np.tile(np.array([[1], [2], [3], [4]]), 5),
        np.tile(np.array([[12], [13], [14], [15]]), 4))
    to_concatenate_components = (
        np.tile(np.array([[1.0], [2.0], [3.0], [4.0]]), 5),
        np.tile(np.array([[12], [13], [14], [15]]), 15))

    input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components)
    dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices(
        to_concatenate_components)

    with self.assertRaisesRegex(TypeError, "have different types"):
      input_dataset.concatenate(dataset_to_concatenate)
150
if __name__ == "__main__":
  # Executed directly (not imported): hand control to TensorFlow's test runner.
  test.main()
153