• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Tests for tf.data utilities handling sparse tensors in nested structures."""
16
17import functools
18
19from absl.testing import parameterized
20
21from tensorflow.python.data.kernel_tests import test_base
22from tensorflow.python.data.util import nest
23from tensorflow.python.data.util import sparse
24from tensorflow.python.framework import combinations
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.platform import test
31
32
33# NOTE(vikoth18): Arguments of parameterized tests are lifted into lambdas to make
34# sure they are not executed before the (eager- or graph-mode) test environment
35# has been set up.
36#
37
38
def _test_any_sparse_combinations():
  """Returns the test combinations for `sparse.any_sparse`.

  Each case is `(name, classes_fn, expected)`: `classes_fn` builds a (possibly
  nested) structure of classes and `expected` says whether any leaf of that
  structure is `SparseTensor`. Single-element tuples are spelled with a
  trailing comma, because `(ops.Tensor)` is just `ops.Tensor`, not a tuple.
  """
  cases = [
      ("TestCase_0", lambda: (), False),
      ("TestCase_1", lambda: (ops.Tensor,), False),
      ("TestCase_2", lambda: (((ops.Tensor,),),), False),
      ("TestCase_3", lambda: (ops.Tensor, ops.Tensor), False),
      ("TestCase_4", lambda: (ops.Tensor, sparse_tensor.SparseTensor), True),
      ("TestCase_5",
       lambda: (sparse_tensor.SparseTensor, sparse_tensor.SparseTensor), True),
      ("TestCase_6", lambda: (((sparse_tensor.SparseTensor,),),), True),
      # Bare (non-tuple) classes.
      ("TestCase_7", lambda: ops.Tensor, False),
      ("TestCase_8", lambda: sparse_tensor.SparseTensor, True),
  ]

  def reduce_fn(acc, case):
    name, classes_fn, expected = case
    return acc + combinations.combine(
        classes_fn=combinations.NamedObject("classes_fn.{}".format(name),
                                            classes_fn),
        expected=expected)

  return functools.reduce(reduce_fn, cases, [])
59
60
def _test_as_dense_shapes_combinations():
  """Returns the test combinations for `sparse.as_dense_shapes`.

  Each case is `(name, types_fn, classes_fn, expected_fn)`. Despite its name,
  `types_fn` builds the input structure of `TensorShape`s; `classes_fn` builds
  the matching structure of classes; `expected_fn` builds the shapes expected
  back, where every `SparseTensor` position holds an unknown shape.
  Single-element tuples are spelled with a trailing comma, because `(x)` is
  just `x`, not a tuple.
  """
  cases = [
      ("TestCase_0", lambda: (), lambda: (), lambda: ()),
      ("TestCase_1", lambda: tensor_shape.TensorShape([]), lambda: ops.Tensor,
       lambda: tensor_shape.TensorShape([])),
      (
          "TestCase_2",
          lambda: tensor_shape.TensorShape([]),
          lambda: sparse_tensor.SparseTensor,
          lambda: tensor_shape.unknown_shape()  # pylint: disable=unnecessary-lambda
      ),
      ("TestCase_3", lambda: (tensor_shape.TensorShape([]),),
       lambda: (ops.Tensor,), lambda: (tensor_shape.TensorShape([]),)),
      ("TestCase_4", lambda: (tensor_shape.TensorShape([]),),
       lambda: (sparse_tensor.SparseTensor,),
       lambda: (tensor_shape.unknown_shape(),)),
      ("TestCase_5", lambda: (tensor_shape.TensorShape([]), ()),
       lambda: (ops.Tensor, ()),
       lambda: (tensor_shape.TensorShape([]), ())),
      ("TestCase_6", lambda: ((), tensor_shape.TensorShape([])),
       lambda: ((), ops.Tensor),
       lambda: ((), tensor_shape.TensorShape([]))),
      ("TestCase_7", lambda: (tensor_shape.TensorShape([]), ()),
       lambda: (sparse_tensor.SparseTensor, ()),
       lambda: (tensor_shape.unknown_shape(), ())),
      ("TestCase_8", lambda: ((), tensor_shape.TensorShape([])),
       lambda: ((), sparse_tensor.SparseTensor),
       lambda: ((), tensor_shape.unknown_shape())),
      ("TestCase_9",
       lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([])),
       lambda: (ops.Tensor, (), ops.Tensor),
       lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([]))),
      ("TestCase_10",
       lambda: (tensor_shape.TensorShape([]), (), tensor_shape.TensorShape([])),
       lambda: (sparse_tensor.SparseTensor, (), sparse_tensor.SparseTensor),
       lambda: (tensor_shape.unknown_shape(), (), tensor_shape.unknown_shape())),
      ("TestCase_11", lambda: ((), tensor_shape.TensorShape([]), ()),
       lambda: ((), ops.Tensor, ()),
       lambda: ((), tensor_shape.TensorShape([]), ())),
      ("TestCase_12", lambda: ((), tensor_shape.TensorShape([]), ()),
       lambda: ((), sparse_tensor.SparseTensor, ()),
       lambda: ((), tensor_shape.unknown_shape(), ())),
  ]

  def reduce_fn(acc, case):
    name, types_fn, classes_fn, expected_fn = case
    return acc + combinations.combine(
        types_fn=combinations.NamedObject("types_fn.{}".format(name), types_fn),
        classes_fn=combinations.NamedObject("classes_fn.{}".format(name),
                                            classes_fn),
        expected_fn=combinations.NamedObject("expected_fn.{}".format(name),
                                             expected_fn))

  return functools.reduce(reduce_fn, cases, [])
116
117
def _test_as_dense_types_combinations():
  """Returns the test combinations for `sparse.as_dense_types`.

  Each case is `(name, types_fn, classes_fn, expected_fn)`: `types_fn` builds
  the input structure of dtypes, `classes_fn` the matching structure of
  classes, and `expected_fn` the dtypes expected back, where every
  `SparseTensor` position holds `dtypes.variant`. Single-element tuples are
  spelled with a trailing comma, because `(x)` is just `x`, not a tuple.
  """
  cases = [
      ("TestCase_0", lambda: (), lambda: (), lambda: ()),
      ("TestCase_1", lambda: dtypes.int32, lambda: ops.Tensor,
       lambda: dtypes.int32),
      ("TestCase_2", lambda: dtypes.int32, lambda: sparse_tensor.SparseTensor,
       lambda: dtypes.variant),
      ("TestCase_3", lambda: (dtypes.int32,), lambda: (ops.Tensor,),
       lambda: (dtypes.int32,)),
      ("TestCase_4", lambda: (dtypes.int32,),
       lambda: (sparse_tensor.SparseTensor,), lambda: (dtypes.variant,)),
      ("TestCase_5", lambda: (dtypes.int32, ()), lambda: (ops.Tensor, ()),
       lambda: (dtypes.int32, ())),
      ("TestCase_6", lambda: ((), dtypes.int32), lambda: ((), ops.Tensor),
       lambda: ((), dtypes.int32)),
      ("TestCase_7", lambda: (dtypes.int32, ()),
       lambda: (sparse_tensor.SparseTensor, ()),
       lambda: (dtypes.variant, ())),
      ("TestCase_8", lambda: ((), dtypes.int32),
       lambda: ((), sparse_tensor.SparseTensor),
       lambda: ((), dtypes.variant)),
      ("TestCase_9", lambda: (dtypes.int32, (), dtypes.int32),
       lambda: (ops.Tensor, (), ops.Tensor),
       lambda: (dtypes.int32, (), dtypes.int32)),
      ("TestCase_10", lambda: (dtypes.int32, (), dtypes.int32),
       lambda: (sparse_tensor.SparseTensor, (), sparse_tensor.SparseTensor),
       lambda: (dtypes.variant, (), dtypes.variant)),
      ("TestCase_11", lambda: ((), dtypes.int32, ()),
       lambda: ((), ops.Tensor, ()),
       lambda: ((), dtypes.int32, ())),
      ("TestCase_12", lambda: ((), dtypes.int32, ()),
       lambda: ((), sparse_tensor.SparseTensor, ()),
       lambda: ((), dtypes.variant, ())),
  ]

  def reduce_fn(acc, case):
    name, types_fn, classes_fn, expected_fn = case
    return acc + combinations.combine(
        types_fn=combinations.NamedObject("types_fn.{}".format(name), types_fn),
        classes_fn=combinations.NamedObject("classes_fn.{}".format(name),
                                            classes_fn),
        expected_fn=combinations.NamedObject("expected_fn.{}".format(name),
                                             expected_fn))

  return functools.reduce(reduce_fn, cases, [])
158
159
def _test_get_classes_combinations():
  """Returns the test combinations for `sparse.get_classes`.

  Each case is `(name, classes_fn, expected_fn)`: `classes_fn` builds a
  (possibly nested) structure of tensors and `expected_fn` the matching
  structure of classes. Single-element tuples are spelled with a trailing
  comma, because `(x)` is just `x`, not a tuple.
  """

  def make_sparse():
    # Called only from inside the case lambdas, so the tensor is still created
    # after the test environment has been set up.
    return sparse_tensor.SparseTensor(
        indices=[[0]], values=[1], dense_shape=[1])

  cases = [
      ("TestCase_0", lambda: (), lambda: ()),
      ("TestCase_1", lambda: make_sparse(),  # pylint: disable=unnecessary-lambda
       lambda: sparse_tensor.SparseTensor),
      ("TestCase_2", lambda: constant_op.constant([1]), lambda: ops.Tensor),
      ("TestCase_3", lambda: (make_sparse(),),
       lambda: (sparse_tensor.SparseTensor,)),
      ("TestCase_4", lambda: (constant_op.constant([1]),),
       lambda: (ops.Tensor,)),
      ("TestCase_5", lambda: (make_sparse(), ()),
       lambda: (sparse_tensor.SparseTensor, ())),
      ("TestCase_6", lambda: ((), make_sparse()),
       lambda: ((), sparse_tensor.SparseTensor)),
      ("TestCase_7", lambda: (constant_op.constant([1]), ()),
       lambda: (ops.Tensor, ())),
      ("TestCase_8", lambda: ((), constant_op.constant([1])),
       lambda: ((), ops.Tensor)),
      ("TestCase_9", lambda: (make_sparse(), (), constant_op.constant([1])),
       lambda: (sparse_tensor.SparseTensor, (), ops.Tensor)),
      ("TestCase_10", lambda: ((), make_sparse(), ()),
       lambda: ((), sparse_tensor.SparseTensor, ())),
      ("TestCase_11", lambda: ((), constant_op.constant([1]), ()),
       lambda: ((), ops.Tensor, ())),
  ]

  def reduce_fn(acc, case):
    name, classes_fn, expected_fn = case
    return acc + combinations.combine(
        classes_fn=combinations.NamedObject("classes_fn.{}".format(name),
                                            classes_fn),
        expected_fn=combinations.NamedObject("expected_fn.{}".format(name),
                                             expected_fn))

  return functools.reduce(reduce_fn, cases, [])
203
204
def _test_serialize_deserialize_combinations():
  """Returns the test combinations for the serialize/deserialize round trip.

  Each case is `(name, input_fn)`, where `input_fn` builds a (possibly nested)
  structure containing `SparseTensor`s. Single-element tuples are spelled with
  a trailing comma, because `(x)` is just `x`, not a tuple.
  """

  def make_1x1():
    # Called only from inside the case lambdas, so the tensor is still created
    # after the test environment has been set up.
    return sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])

  cases = [
      ("TestCase_0", lambda: ()),
      ("TestCase_1", lambda: make_1x1()),  # pylint: disable=unnecessary-lambda
      ("TestCase_2", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5])),
      ("TestCase_3", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5])),
      ("TestCase_4", lambda: (make_1x1(),)),
      ("TestCase_5", lambda: (make_1x1(), ())),
      ("TestCase_6", lambda: ((), make_1x1())),
  ]

  def reduce_fn(acc, case):
    name, input_fn = case
    return acc + combinations.combine(
        input_fn=combinations.NamedObject("input_fn.{}".format(name), input_fn))

  return functools.reduce(reduce_fn, cases, [])
228
229
def _test_serialize_many_deserialize_combinations():
  """Returns the test combinations for the serialize-many/deserialize trip.

  Each case is `(name, input_fn)`, where `input_fn` builds a (possibly nested)
  structure containing `SparseTensor`s. Single-element tuples are spelled with
  a trailing comma, because `(x)` is just `x`, not a tuple.
  """

  def make_1x1():
    # Called only from inside the case lambdas, so the tensor is still created
    # after the test environment has been set up.
    return sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])

  cases = [
      ("TestCase_0", lambda: ()),
      ("TestCase_1", lambda: make_1x1()),  # pylint: disable=unnecessary-lambda
      ("TestCase_2", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5])),
      ("TestCase_3", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5])),
      ("TestCase_4", lambda: (make_1x1(),)),
      ("TestCase_5", lambda: (make_1x1(), ())),
      ("TestCase_6", lambda: ((), make_1x1())),
  ]

  def reduce_fn(acc, case):
    name, input_fn = case
    return acc + combinations.combine(
        input_fn=combinations.NamedObject("input_fn.{}".format(name), input_fn))

  return functools.reduce(reduce_fn, cases, [])
253
254
class SparseTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for the helpers in `tensorflow.python.data.util.sparse`."""

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_any_sparse_combinations()))
  def testAnySparse(self, classes_fn, expected):
    classes = classes_fn()
    self.assertEqual(sparse.any_sparse(classes), expected)

  def assertShapesEqual(self, a, b):
    """Asserts that two nested structures of `TensorShape`s are equal.

    Two shapes of unknown rank compare equal. Shapes of known rank must agree
    on every dimension.

    Args:
      a: A (possibly nested) structure of `TensorShape` objects.
      b: A structure of `TensorShape` objects nested like `a`.
    """
    # Check the nesting first: `zip` below would otherwise silently drop
    # trailing leaves if one structure had more of them than the other.
    nest.assert_same_structure(a, b)
    for shape_a, shape_b in zip(nest.flatten(a), nest.flatten(b)):
      self.assertEqual(shape_a.ndims, shape_b.ndims)
      if shape_a.ndims is not None:
        self.assertEqual(shape_a.as_list(), shape_b.as_list())

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_as_dense_shapes_combinations()))
  def testAsDenseShapes(self, types_fn, classes_fn, expected_fn):
    types = types_fn()
    classes = classes_fn()
    expected = expected_fn()
    self.assertShapesEqual(sparse.as_dense_shapes(types, classes), expected)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_as_dense_types_combinations()))
  def testAsDenseTypes(self, types_fn, classes_fn, expected_fn):
    types = types_fn()
    classes = classes_fn()
    expected = expected_fn()
    self.assertEqual(sparse.as_dense_types(types, classes), expected)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_get_classes_combinations()))
  def testGetClasses(self, classes_fn, expected_fn):
    classes = classes_fn()
    expected = expected_fn()
    self.assertEqual(sparse.get_classes(classes), expected)

  def assertSparseValuesEqual(self, a, b):
    """Asserts `a == b`, comparing `SparseTensor`s by their evaluated values.

    If `a` is not a `SparseTensor`, `b` must not be one either, and the two are
    compared with `assertEqual`. Otherwise both are evaluated, and their
    indices, values and dense shapes must match.
    """
    if not isinstance(a, sparse_tensor.SparseTensor):
      self.assertNotIsInstance(b, sparse_tensor.SparseTensor)
      self.assertEqual(a, b)
      return
    self.assertIsInstance(b, sparse_tensor.SparseTensor)
    with self.cached_session():
      # Fetch both tensors in one run instead of once per compared component.
      a_value, b_value = self.evaluate([a, b])
    self.assertAllEqual(a_value.indices, b_value.indices)
    self.assertAllEqual(a_value.values, b_value.values)
    self.assertAllEqual(a_value.dense_shape, b_value.dense_shape)

  def _assertSerializationRoundTrips(self, test_case, serialize_fn):
    """Checks that deserializing `serialize_fn(test_case)` gives `test_case`.

    Args:
      test_case: A (possibly nested) structure, possibly containing
        `SparseTensor`s with int32 values.
      serialize_fn: The serialization helper under test, e.g.
        `sparse.serialize_sparse_tensors`.
    """
    classes = sparse.get_classes(test_case)
    shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                classes)
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    actual = sparse.deserialize_sparse_tensors(
        serialize_fn(test_case), types, shapes, classes)
    nest.assert_same_structure(test_case, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(test_case)):
      self.assertSparseValuesEqual(a, e)

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         _test_serialize_deserialize_combinations()))
  def testSerializeDeserialize(self, input_fn):
    self._assertSerializationRoundTrips(input_fn(),
                                        sparse.serialize_sparse_tensors)

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         _test_serialize_many_deserialize_combinations()))
  def testSerializeManyDeserialize(self, input_fn):
    self._assertSerializationRoundTrips(input_fn(),
                                        sparse.serialize_many_sparse_tensors)
340
341
def _run_tests():
  """Runs this module's tests with TensorFlow's test runner."""
  test.main()


if __name__ == "__main__":
  _run_tests()
344