• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for utilities working with arbitrarily nested structures."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23
24import numpy as np
25import wrapt
26from absl.testing import parameterized
27
28from tensorflow.python.data.kernel_tests import test_base
29from tensorflow.python.data.ops import dataset_ops
30from tensorflow.python.data.util import nest
31from tensorflow.python.data.util import structure
32from tensorflow.python.framework import combinations
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import ops
36from tensorflow.python.framework import sparse_tensor
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.framework import tensor_spec
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import tensor_array_ops
41from tensorflow.python.ops import variables
42from tensorflow.python.ops.ragged import ragged_factory_ops
43from tensorflow.python.ops.ragged import ragged_tensor
44from tensorflow.python.ops.ragged import ragged_tensor_value
45from tensorflow.python.platform import test
46from tensorflow.python.util.compat import collections_abc
47
48
49# NOTE(mrry): Arguments of parameterized tests are lifted into lambdas to make
50# sure they are not executed before the (eager- or graph-mode) test environment
51# has been set up.
52#
53
54
def _test_flat_structure_combinations():
  """Returns the parameter combinations consumed by `testFlatStructure`.

  Every case is a tuple `(name, value_fn, expected_structure_fn,
  expected_types_fn, expected_shapes_fn)`. All entries except `name` are
  zero-argument callables, so nothing is constructed until the test invokes
  them. `name` is only used to label the `NamedObject` wrappers.

  Returns:
    The concatenation of the `combinations.combine(...)` results of all cases,
    in case order.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
          lambda: tensor_spec.TensorSpec,
          lambda: [dtypes.float32],
          lambda: [[]],
      ),
      (
          "TensorArray",
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArraySpec,
          lambda: [dtypes.variant],
          lambda: [[]],
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensorSpec,
          lambda: [dtypes.variant],
          lambda: [None],
      ),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
          lambda: ragged_tensor.RaggedTensorSpec,
          lambda: [dtypes.variant],
          lambda: [None],
      ),
      (
          "Nested_0",
          lambda: (constant_op.constant(37.0),
                   constant_op.constant([1, 2, 3])),
          lambda: tuple,
          lambda: [dtypes.float32, dtypes.int32],
          lambda: [[], [3]],
      ),
      (
          "Nested_1",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3]),
          },
          lambda: dict,
          lambda: [dtypes.float32, dtypes.int32],
          lambda: [[], [3]],
      ),
      (
          "Nested_2",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": (
                  sparse_tensor.SparseTensor(
                      indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(
                      indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
              ),
          },
          lambda: dict,
          lambda: [dtypes.float32, dtypes.variant, dtypes.variant],
          lambda: [[], None, None],
      ),
  ]

  combos = []
  for name, value_fn, structure_fn, types_fn, shapes_fn in cases:
    combos += combinations.combine(
        value_fn=combinations.NamedObject("value_fn.{}".format(name), value_fn),
        expected_structure_fn=combinations.NamedObject(
            "expected_structure_fn.{}".format(name), structure_fn),
        expected_types_fn=combinations.NamedObject(
            "expected_types_fn.{}".format(name), types_fn),
        expected_shapes_fn=combinations.NamedObject(
            "expected_shapes_fn.{}".format(name), shapes_fn))
  return combos
104
105
def _test_is_compatible_with_structure_combinations():
  """Returns the combinations consumed by `testIsCompatibleWithStructure`.

  Every case is `(name, original_value_fn, compatible_values_fn,
  incompatible_values_fn)`. The three callables take no arguments; the last
  two each return a list of values whose specs the test expects to be
  compatible / incompatible with the spec of the original value.

  Returns:
    The concatenation of the `combinations.combine(...)` results of all cases,
    in case order.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
          lambda: [
              constant_op.constant(38.0),
              array_ops.placeholder(dtypes.float32),
              variables.Variable(100.0),
              42.0,
              np.array(42.0, dtype=np.float32),
          ],
          lambda: [
              constant_op.constant([1.0, 2.0]),
              constant_op.constant(37),
          ],
      ),
      (
          "TensorArray",
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=10),
          ],
          lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.int32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(), size=0),
          ],
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
          lambda: [
              sparse_tensor.SparseTensor(
                  indices=[[1, 1], [3, 4]],
                  values=[10, -1],
                  dense_shape=[4, 5]),
              sparse_tensor.SparseTensorValue(
                  indices=[[1, 1], [3, 4]],
                  values=[10, -1],
                  dense_shape=[4, 5]),
              array_ops.sparse_placeholder(dtype=dtypes.int32),
              array_ops.sparse_placeholder(
                  dtype=dtypes.int32, shape=[None, None]),
          ],
          lambda: [
              constant_op.constant(37, shape=[4, 5]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
              array_ops.sparse_placeholder(
                  dtype=dtypes.int32, shape=[None, None, None]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5]),
          ],
      ),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
          lambda: [
              ragged_factory_ops.constant([[1, 2], [3, 4], []]),
              ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
          ],
          lambda: [
              ragged_factory_ops.constant(1),
              ragged_factory_ops.constant([1, 2]),
              ragged_factory_ops.constant([[1], [2]]),
              ragged_factory_ops.constant([["a", "b"]]),
          ],
      ),
      (
          "Nested",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3]),
          },
          lambda: [{
              "a": constant_op.constant(15.0),
              "b": constant_op.constant([4, 5, 6]),
          }],
          lambda: [
              {
                  "a": constant_op.constant(15.0),
                  "b": constant_op.constant([4, 5, 6, 7]),
              },
              {
                  "a": constant_op.constant(15),
                  "b": constant_op.constant([4, 5, 6]),
              },
              {
                  "a": constant_op.constant(15),
                  "b": sparse_tensor.SparseTensor(
                      indices=[[0], [1], [2]],
                      values=[4, 5, 6],
                      dense_shape=[3]),
              },
              (constant_op.constant(15.0), constant_op.constant([4, 5, 6])),
          ],
      ),
  ]

  combos = []
  for name, original_fn, compatible_fn, incompatible_fn in cases:
    combos += combinations.combine(
        original_value_fn=combinations.NamedObject(
            "original_value_fn.{}".format(name), original_fn),
        compatible_values_fn=combinations.NamedObject(
            "compatible_values_fn.{}".format(name), compatible_fn),
        incompatible_values_fn=combinations.NamedObject(
            "incompatible_values_fn.{}".format(name), incompatible_fn))
  return combos
187
188
def _test_structure_from_value_equality_combinations():
  """Returns the combinations consumed by `testStructureFromValueEquality`.

  Every case is `(name, value1_fn, value2_fn, *not_equal_value_fns)`: two
  zero-argument callables whose values the test expects to yield equal specs,
  followed by one or more callables whose values it expects to yield specs
  that differ from both.

  Returns:
    The concatenation of the `combinations.combine(...)` results of all cases,
    in case order. The trailing callables of a case are wrapped together, as a
    single list, in one `NamedObject`.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
          lambda: constant_op.constant(42.0),
          lambda: constant_op.constant([5]),
      ),
      (
          "TensorArray",
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.int32, element_shape=(), size=0),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[3]], values=[-1], dense_shape=[5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[1.0], dense_shape=[4, 5]),
      ),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
          lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]),
          lambda: ragged_factory_ops.constant(
              [[[1]], [[2], [3]]], ragged_rank=1),
          lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
          lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]]),
      ),
      (
          "Nested",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3]),
          },
          lambda: {
              "a": constant_op.constant(42.0),
              "b": constant_op.constant([4, 5, 6]),
          },
          lambda: {
              "a": constant_op.constant([1, 2, 3]),
              "b": constant_op.constant(37.0),
          },
      ),
  ]

  combos = []
  # Star-unpacking collects the variable-length tail of each case into a list.
  for name, first_fn, second_fn, *unequal_fns in cases:
    combos += combinations.combine(
        value1_fn=combinations.NamedObject("value1_fn.{}".format(name),
                                           first_fn),
        value2_fn=combinations.NamedObject("value2_fn.{}".format(name),
                                           second_fn),
        not_equal_value_fns=combinations.NamedObject(
            "not_equal_value_fns.{}".format(name), unequal_fns))
  return combos
234
235
def _test_ragged_structure_inequality_combinations():
  """Returns the combinations consumed by `testRaggedStructureInequality`.

  Each entry pairs two `RaggedTensorSpec` objects that differ in exactly one
  constructor argument: the third positional argument, the first positional
  argument, and the dtype, respectively.

  Returns:
    The concatenation of `combinations.combine(spec1=..., spec2=...)` for
    every pair, in order.
  """
  spec_pairs = [
      (
          ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
          ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 2),
      ),
      (
          ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32, 1),
          ragged_tensor.RaggedTensorSpec([5, None], dtypes.int32, 1),
      ),
      (
          ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
          ragged_tensor.RaggedTensorSpec(None, dtypes.float32, 1),
      ),
  ]

  combos = []
  for left, right in spec_pairs:
    combos += combinations.combine(spec1=left, spec2=right)
  return combos
251
252
def _test_hash_combinations():
  """Returns the parameter combinations consumed by `testHash`.

  Every case is `(name, value1_fn, value2_fn, value3_fn)`, all three being
  zero-argument callables. The test expects the specs of the first two values
  to hash identically and the spec of the third to hash differently.

  Returns:
    The concatenation of the `combinations.combine(...)` results of all cases,
    in case order.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
          lambda: constant_op.constant(42.0),
          lambda: constant_op.constant([5]),
      ),
      (
          "TensorArray",
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.int32, element_shape=(), size=0),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[3]], values=[-1], dense_shape=[5]),
      ),
      (
          "Nested",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3]),
          },
          lambda: {
              "a": constant_op.constant(42.0),
              "b": constant_op.constant([4, 5, 6]),
          },
          lambda: {
              "a": constant_op.constant([1, 2, 3]),
              "b": constant_op.constant(37.0),
          },
      ),
  ]

  combos = []
  for name, first_fn, second_fn, third_fn in cases:
    combos += combinations.combine(
        value1_fn=combinations.NamedObject("value1_fn.{}".format(name),
                                           first_fn),
        value2_fn=combinations.NamedObject("value2_fn.{}".format(name),
                                           second_fn),
        value3_fn=combinations.NamedObject("value3_fn.{}".format(name),
                                           third_fn))
  return combos
291
292
def _test_round_trip_conversion_combinations():
  """Returns the parameter combinations consumed by `testRoundTripConversion`.

  Every case is `(name, value_fn)`, where `value_fn` is a zero-argument
  callable producing the value that the test converts to a tensor list and
  back.

  Returns:
    The concatenation of the `combinations.combine(...)` results of all cases,
    in case order.
  """
  cases = [
      ("Tensor", lambda: constant_op.constant(37.0)),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5])),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]])),
      ("Nested_0", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3]),
      }),
      ("Nested_1", lambda: {
          "a": constant_op.constant(37.0),
          "b": (
              sparse_tensor.SparseTensor(
                  indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
          ),
      }),
  ]

  combos = []
  for name, value_fn in cases:
    named_value_fn = combinations.NamedObject(
        "value_fn.{}".format(name), value_fn)
    combos += combinations.combine(value_fn=named_value_fn)
  return combos
336
337
def _test_convert_legacy_structure_combinations():
  """Returns combinations of legacy type/shape/class triples and a spec.

  Every case is `(output_types, output_shapes, output_classes,
  expected_structure)`. Unlike the lambda-based builders in this file, the
  entries are constructed eagerly when this function runs.

  Returns:
    The concatenation of the `combinations.combine(...)` results of all cases,
    in case order, keyed by `output_types`, `output_shapes`, `output_classes`
    and `expected_structure`.
  """
  cases = [
      (
          dtypes.float32,
          tensor_shape.TensorShape([]),
          ops.Tensor,
          tensor_spec.TensorSpec([], dtypes.float32),
      ),
      (
          dtypes.int32,
          tensor_shape.TensorShape([2, 2]),
          sparse_tensor.SparseTensor,
          sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
      ),
      # In the three TensorArray cases the first two entries of the legacy
      # shape line up with the expected `dynamic_size` and `infer_shape`.
      (
          dtypes.int32,
          tensor_shape.TensorShape([None, True, 2, 2]),
          tensor_array_ops.TensorArray,
          tensor_array_ops.TensorArraySpec(
              [2, 2], dtypes.int32, dynamic_size=None, infer_shape=True),
      ),
      (
          dtypes.int32,
          tensor_shape.TensorShape([True, None, 2, 2]),
          tensor_array_ops.TensorArray,
          tensor_array_ops.TensorArraySpec(
              [2, 2], dtypes.int32, dynamic_size=True, infer_shape=None),
      ),
      (
          dtypes.int32,
          tensor_shape.TensorShape([True, False, 2, 2]),
          tensor_array_ops.TensorArray,
          tensor_array_ops.TensorArraySpec(
              [2, 2], dtypes.int32, dynamic_size=True, infer_shape=False),
      ),
      # Here the "class" entry is itself a spec instance, equal to the
      # expected structure.
      (
          dtypes.int32,
          tensor_shape.TensorShape([2, None]),
          ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
          ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
      ),
      (
          {
              "a": dtypes.float32,
              "b": (dtypes.int32, dtypes.string),
          },
          {
              "a": tensor_shape.TensorShape([]),
              "b": (tensor_shape.TensorShape([2, 2]),
                    tensor_shape.TensorShape([])),
          },
          {
              "a": ops.Tensor,
              "b": (sparse_tensor.SparseTensor, ops.Tensor),
          },
          {
              "a": tensor_spec.TensorSpec([], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([], dtypes.string)),
          },
      ),
  ]

  combos = []
  for types, shapes, classes, expected in cases:
    combos += combinations.combine(
        output_types=types,
        output_shapes=shapes,
        output_classes=classes,
        expected_structure=expected)
  return combos
392
393
def _test_batch_combinations():
  """Returns combinations of an element spec, a batch size and a batched spec.

  Every case is `(element_structure, batch_size,
  expected_batched_structure)`. A `batch_size` of `None` is paired with an
  expected spec whose leading dimension is `None`.

  Returns:
    The concatenation of the `combinations.combine(...)` results of all cases,
    in case order.
  """
  cases = [
      (
          tensor_spec.TensorSpec([], dtypes.float32),
          32,
          tensor_spec.TensorSpec([32], dtypes.float32),
      ),
      (
          tensor_spec.TensorSpec([], dtypes.float32),
          None,
          tensor_spec.TensorSpec([None], dtypes.float32),
      ),
      (
          sparse_tensor.SparseTensorSpec([None], dtypes.float32),
          32,
          sparse_tensor.SparseTensorSpec([32, None], dtypes.float32),
      ),
      (
          sparse_tensor.SparseTensorSpec([4], dtypes.float32),
          None,
          sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
      ),
      (
          ragged_tensor.RaggedTensorSpec([2, None], dtypes.float32, 1),
          32,
          ragged_tensor.RaggedTensorSpec([32, 2, None], dtypes.float32, 2),
      ),
      (
          ragged_tensor.RaggedTensorSpec([4, None], dtypes.float32, 1),
          None,
          ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2),
      ),
      (
          {
              "a": tensor_spec.TensorSpec([], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([], dtypes.string)),
          },
          128,
          {
              "a": tensor_spec.TensorSpec([128], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([128], dtypes.string)),
          },
      ),
  ]

  combos = []
  for element, size, batched in cases:
    combos += combinations.combine(
        element_structure=element,
        batch_size=size,
        expected_batched_structure=batched)
  return combos
429
430
def _test_unbatch_combinations():
  """Returns combinations of a batched spec and its expected unbatched spec.

  Every case is `(element_structure, expected_unbatched_structure)`; in each
  pair the expected spec has the shape of the first one minus its leading
  dimension.

  Returns:
    The concatenation of the `combinations.combine(...)` results of all cases,
    in case order.
  """
  cases = [
      (
          tensor_spec.TensorSpec([32], dtypes.float32),
          tensor_spec.TensorSpec([], dtypes.float32),
      ),
      (
          tensor_spec.TensorSpec([None], dtypes.float32),
          tensor_spec.TensorSpec([], dtypes.float32),
      ),
      (
          sparse_tensor.SparseTensorSpec([32, None], dtypes.float32),
          sparse_tensor.SparseTensorSpec([None], dtypes.float32),
      ),
      (
          sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
          sparse_tensor.SparseTensorSpec([4], dtypes.float32),
      ),
      (
          ragged_tensor.RaggedTensorSpec([32, None, None], dtypes.float32, 2),
          ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1),
      ),
      (
          ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.float32,
                                         2),
          ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1),
      ),
      (
          {
              "a": tensor_spec.TensorSpec([128], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([None], dtypes.string)),
          },
          {
              "a": tensor_spec.TensorSpec([], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([], dtypes.string)),
          },
      ),
  ]

  combos = []
  for batched, unbatched in cases:
    combos += combinations.combine(
        element_structure=batched, expected_unbatched_structure=unbatched)
  return combos
465
466
def _test_to_batched_tensor_list_combinations():
  """Returns combinations of a batched value and a second, smaller value.

  Every case is `(name, value_fn, element_0_fn)`, both callables taking no
  arguments. Judging by the parameter name, `element_0_fn` presumably yields
  the expected first element of the batch produced by `value_fn` -- confirm
  against the consuming test.

  Returns:
    The concatenation of the `combinations.combine(...)` results of all cases,
    in case order.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          lambda: constant_op.constant([1.0, 2.0]),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[0]], values=[13], dense_shape=[2]),
      ),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
          lambda: ragged_factory_ops.constant([[1]]),
      ),
      (
          "Nest",
          lambda: (
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
              sparse_tensor.SparseTensor(
                  indices=[[0, 0], [1, 1]],
                  values=[13, 27],
                  dense_shape=[2, 2]),
          ),
          lambda: (
              constant_op.constant([1.0, 2.0]),
              sparse_tensor.SparseTensor(
                  indices=[[0]], values=[13], dense_shape=[2]),
          ),
      ),
  ]

  combos = []
  for name, batched_fn, first_fn in cases:
    combos += combinations.combine(
        value_fn=combinations.NamedObject("value_fn.{}".format(name),
                                          batched_fn),
        element_0_fn=combinations.NamedObject("element_0_fn.{}".format(name),
                                              first_fn))
  return combos
495
496# TODO(jsimsa): Add tests for OptionalStructure and DatasetStructure.
497class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
498
499  # pylint: disable=g-long-lambda,protected-access
500  @combinations.generate(
501      combinations.times(test_base.default_test_combinations(),
502                         _test_flat_structure_combinations()))
503  def testFlatStructure(self, value_fn, expected_structure_fn,
504                        expected_types_fn, expected_shapes_fn):
505    value = value_fn()
506    expected_structure = expected_structure_fn()
507    expected_types = expected_types_fn()
508    expected_shapes = expected_shapes_fn()
509    s = structure.type_spec_from_value(value)
510    self.assertIsInstance(s, expected_structure)
511    flat_types = structure.get_flat_tensor_types(s)
512    self.assertEqual(expected_types, flat_types)
513    flat_shapes = structure.get_flat_tensor_shapes(s)
514    self.assertLen(flat_shapes, len(expected_shapes))
515    for expected, actual in zip(expected_shapes, flat_shapes):
516      if expected is None:
517        self.assertEqual(actual.ndims, None)
518      else:
519        self.assertEqual(actual.as_list(), expected)
520
521  @combinations.generate(
522      combinations.times(test_base.graph_only_combinations(),
523                         _test_is_compatible_with_structure_combinations()))
524  def testIsCompatibleWithStructure(self, original_value_fn,
525                                    compatible_values_fn,
526                                    incompatible_values_fn):
527    original_value = original_value_fn()
528    compatible_values = compatible_values_fn()
529    incompatible_values = incompatible_values_fn()
530
531    s = structure.type_spec_from_value(original_value)
532    for compatible_value in compatible_values:
533      self.assertTrue(
534          structure.are_compatible(
535              s, structure.type_spec_from_value(compatible_value)))
536    for incompatible_value in incompatible_values:
537      self.assertFalse(
538          structure.are_compatible(
539              s, structure.type_spec_from_value(incompatible_value)))
540
541  @combinations.generate(
542      combinations.times(test_base.default_test_combinations(),
543                         _test_structure_from_value_equality_combinations()))
544  def testStructureFromValueEquality(self, value1_fn, value2_fn,
545                                     not_equal_value_fns):
546    # pylint: disable=g-generic-assert
547    not_equal_value_fns = not_equal_value_fns._obj
548    s1 = structure.type_spec_from_value(value1_fn())
549    s2 = structure.type_spec_from_value(value2_fn())
550    self.assertEqual(s1, s1)  # check __eq__ operator.
551    self.assertEqual(s1, s2)  # check __eq__ operator.
552    self.assertFalse(s1 != s1)  # check __ne__ operator.
553    self.assertFalse(s1 != s2)  # check __ne__ operator.
554    for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)):
555      self.assertEqual(hash(c1), hash(c1))
556      self.assertEqual(hash(c1), hash(c2))
557    for value_fn in not_equal_value_fns:
558      s3 = structure.type_spec_from_value(value_fn())
559      self.assertNotEqual(s1, s3)  # check __ne__ operator.
560      self.assertNotEqual(s2, s3)  # check __ne__ operator.
561      self.assertFalse(s1 == s3)  # check __eq_ operator.
562      self.assertFalse(s2 == s3)  # check __eq_ operator.
563
564  @combinations.generate(
565      combinations.times(test_base.default_test_combinations(),
566                         _test_ragged_structure_inequality_combinations()))
567  def testRaggedStructureInequality(self, spec1, spec2):
568    # pylint: disable=g-generic-assert
569    self.assertNotEqual(spec1, spec2)  # check __ne__ operator.
570    self.assertFalse(spec1 == spec2)  # check __eq__ operator.
571
572  @combinations.generate(
573      combinations.times(test_base.default_test_combinations(),
574                         _test_hash_combinations()))
575  def testHash(self, value1_fn, value2_fn, value3_fn):
576    s1 = structure.type_spec_from_value(value1_fn())
577    s2 = structure.type_spec_from_value(value2_fn())
578    s3 = structure.type_spec_from_value(value3_fn())
579    for c1, c2, c3 in zip(nest.flatten(s1), nest.flatten(s2), nest.flatten(s3)):
580      self.assertEqual(hash(c1), hash(c1))
581      self.assertEqual(hash(c1), hash(c2))
582      self.assertNotEqual(hash(c1), hash(c3))
583      self.assertNotEqual(hash(c2), hash(c3))
584
585  @combinations.generate(
586      combinations.times(test_base.default_test_combinations(),
587                         _test_round_trip_conversion_combinations()))
588  def testRoundTripConversion(self, value_fn):
589    value = value_fn()
590    s = structure.type_spec_from_value(value)
591
592    def maybe_stack_ta(v):
593      if isinstance(v, tensor_array_ops.TensorArray):
594        return v.stack()
595      return v
596
597    before = self.evaluate(maybe_stack_ta(value))
598    after = self.evaluate(
599        maybe_stack_ta(
600            structure.from_tensor_list(s, structure.to_tensor_list(s, value))))
601
602    flat_before = nest.flatten(before)
603    flat_after = nest.flatten(after)
604    for b, a in zip(flat_before, flat_after):
605      if isinstance(b, sparse_tensor.SparseTensorValue):
606        self.assertAllEqual(b.indices, a.indices)
607        self.assertAllEqual(b.values, a.values)
608        self.assertAllEqual(b.dense_shape, a.dense_shape)
609      elif isinstance(
610          b,
611          (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
612        self.assertAllEqual(b, a)
613      else:
614        self.assertAllEqual(b, a)
615
616  # pylint: enable=g-long-lambda
617
618  def preserveStaticShape(self):
619    rt = ragged_factory_ops.constant([[1, 2], [], [3]])
620    rt_s = structure.type_spec_from_value(rt)
621    rt_after = structure.from_tensor_list(rt_s,
622                                          structure.to_tensor_list(rt_s, rt))
623    self.assertEqual(rt_after.row_splits.shape.as_list(),
624                     rt.row_splits.shape.as_list())
625    self.assertEqual(rt_after.values.shape.as_list(), [None])
626
627    st = sparse_tensor.SparseTensor(
628        indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
629    st_s = structure.type_spec_from_value(st)
630    st_after = structure.from_tensor_list(st_s,
631                                          structure.to_tensor_list(st_s, st))
632    self.assertEqual(st_after.indices.shape.as_list(), [None, 2])
633    self.assertEqual(st_after.values.shape.as_list(), [None])
634    self.assertEqual(st_after.dense_shape.shape.as_list(),
635                     st.dense_shape.shape.as_list())
636
637  @combinations.generate(test_base.default_test_combinations())
638  def testPreserveTensorArrayShape(self):
639    ta = tensor_array_ops.TensorArray(
640        dtype=dtypes.int32, size=1, element_shape=(3,))
641    ta_s = structure.type_spec_from_value(ta)
642    ta_after = structure.from_tensor_list(ta_s,
643                                          structure.to_tensor_list(ta_s, ta))
644    self.assertEqual(ta_after.element_shape.as_list(), [3])
645
646  @combinations.generate(test_base.default_test_combinations())
647  def testPreserveInferredTensorArrayShape(self):
648    ta = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=1)
649    # Shape is inferred from the write.
650    ta = ta.write(0, [1, 2, 3])
651    ta_s = structure.type_spec_from_value(ta)
652    ta_after = structure.from_tensor_list(ta_s,
653                                          structure.to_tensor_list(ta_s, ta))
654    self.assertEqual(ta_after.element_shape.as_list(), [3])
655
656  @combinations.generate(test_base.default_test_combinations())
657  def testIncompatibleStructure(self):
658    # Define three mutually incompatible values/structures, and assert that:
659    # 1. Using one structure to flatten a value with an incompatible structure
660    #    fails.
661    # 2. Using one structure to restructure a flattened value with an
662    #    incompatible structure fails.
663    value_tensor = constant_op.constant(42.0)
664    s_tensor = structure.type_spec_from_value(value_tensor)
665    flat_tensor = structure.to_tensor_list(s_tensor, value_tensor)
666
667    value_sparse_tensor = sparse_tensor.SparseTensor(
668        indices=[[0, 0]], values=[1], dense_shape=[1, 1])
669    s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor)
670    flat_sparse_tensor = structure.to_tensor_list(s_sparse_tensor,
671                                                  value_sparse_tensor)
672
673    value_nest = {
674        "a": constant_op.constant(37.0),
675        "b": constant_op.constant([1, 2, 3])
676    }
677    s_nest = structure.type_spec_from_value(value_nest)
678    flat_nest = structure.to_tensor_list(s_nest, value_nest)
679
680    with self.assertRaisesRegex(
681        ValueError, r"SparseTensor.* is not convertible to a tensor with "
682        r"dtype.*float32.* and shape \(\)"):
683      structure.to_tensor_list(s_tensor, value_sparse_tensor)
684    with self.assertRaisesRegex(
685        ValueError, "The two structures don't have the same nested structure."):
686      structure.to_tensor_list(s_tensor, value_nest)
687
688    with self.assertRaisesRegex(TypeError,
689                                "Neither a SparseTensor nor SparseTensorValue"):
690      structure.to_tensor_list(s_sparse_tensor, value_tensor)
691
692    with self.assertRaisesRegex(
693        ValueError, "The two structures don't have the same nested structure."):
694      structure.to_tensor_list(s_sparse_tensor, value_nest)
695
696    with self.assertRaisesRegex(
697        ValueError, "The two structures don't have the same nested structure."):
698      structure.to_tensor_list(s_nest, value_tensor)
699
700    with self.assertRaisesRegex(
701        ValueError, "The two structures don't have the same nested structure."):
702      structure.to_tensor_list(s_nest, value_sparse_tensor)
703
704    with self.assertRaisesRegex(ValueError, r"Incompatible input:"):
705      structure.from_tensor_list(s_tensor, flat_sparse_tensor)
706
707    with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
708      structure.from_tensor_list(s_tensor, flat_nest)
709
710    with self.assertRaisesRegex(ValueError, "Incompatible input: "):
711      structure.from_tensor_list(s_sparse_tensor, flat_tensor)
712
713    with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
714      structure.from_tensor_list(s_sparse_tensor, flat_nest)
715
716    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
717      structure.from_tensor_list(s_nest, flat_tensor)
718
719    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
720      structure.from_tensor_list(s_nest, flat_sparse_tensor)
721
722  @combinations.generate(test_base.default_test_combinations())
723  def testIncompatibleNestedStructure(self):
724    # Define three mutually incompatible nested values/structures, and assert
725    # that:
726    # 1. Using one structure to flatten a value with an incompatible structure
727    #    fails.
728    # 2. Using one structure to restructure a flattened value with an
729    #    incompatible structure fails.
730
731    value_0 = {
732        "a": constant_op.constant(37.0),
733        "b": constant_op.constant([1, 2, 3])
734    }
735    s_0 = structure.type_spec_from_value(value_0)
736    flat_s_0 = structure.to_tensor_list(s_0, value_0)
737
738    # `value_1` has compatible nested structure with `value_0`, but different
739    # classes.
740    value_1 = {
741        "a":
742            constant_op.constant(37.0),
743        "b":
744            sparse_tensor.SparseTensor(
745                indices=[[0, 0]], values=[1], dense_shape=[1, 1])
746    }
747    s_1 = structure.type_spec_from_value(value_1)
748    flat_s_1 = structure.to_tensor_list(s_1, value_1)
749
750    # `value_2` has incompatible nested structure with `value_0` and `value_1`.
751    value_2 = {
752        "a":
753            constant_op.constant(37.0),
754        "b": (sparse_tensor.SparseTensor(
755            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
756              sparse_tensor.SparseTensor(
757                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
758    }
759    s_2 = structure.type_spec_from_value(value_2)
760    flat_s_2 = structure.to_tensor_list(s_2, value_2)
761
762    with self.assertRaisesRegex(
763        ValueError, r"SparseTensor.* is not convertible to a tensor with "
764        r"dtype.*int32.* and shape \(3,\)"):
765      structure.to_tensor_list(s_0, value_1)
766
767    with self.assertRaisesRegex(
768        ValueError, "The two structures don't have the same nested structure."):
769      structure.to_tensor_list(s_0, value_2)
770
771    with self.assertRaisesRegex(TypeError,
772                                "Neither a SparseTensor nor SparseTensorValue"):
773      structure.to_tensor_list(s_1, value_0)
774
775    with self.assertRaisesRegex(
776        ValueError, "The two structures don't have the same nested structure."):
777      structure.to_tensor_list(s_1, value_2)
778
779    # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
780    # needs to account for "a" coming before or after "b". It might be worth
781    # adding a deterministic repr for these error messages (among other
782    # improvements).
783    with self.assertRaisesRegex(
784        ValueError, "The two structures don't have the same nested structure."):
785      structure.to_tensor_list(s_2, value_0)
786
787    with self.assertRaisesRegex(
788        ValueError, "The two structures don't have the same nested structure."):
789      structure.to_tensor_list(s_2, value_1)
790
791    with self.assertRaisesRegex(ValueError, r"Incompatible input:"):
792      structure.from_tensor_list(s_0, flat_s_1)
793
794    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3."):
795      structure.from_tensor_list(s_0, flat_s_2)
796
797    with self.assertRaisesRegex(ValueError, "Incompatible input: "):
798      structure.from_tensor_list(s_1, flat_s_0)
799
800    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3."):
801      structure.from_tensor_list(s_1, flat_s_2)
802
803    with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2."):
804      structure.from_tensor_list(s_2, flat_s_0)
805
806    with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2."):
807      structure.from_tensor_list(s_2, flat_s_1)
808
809  @combinations.generate(
810      combinations.times(test_base.default_test_combinations(),
811                         _test_convert_legacy_structure_combinations()))
812  def testConvertLegacyStructure(self, output_types, output_shapes,
813                                 output_classes, expected_structure):
814    actual_structure = structure.convert_legacy_structure(
815        output_types, output_shapes, output_classes)
816    self.assertEqual(actual_structure, expected_structure)
817
818  @combinations.generate(test_base.default_test_combinations())
819  def testNestedNestedStructure(self):
820    s = (tensor_spec.TensorSpec([], dtypes.int64),
821         (tensor_spec.TensorSpec([], dtypes.float32),
822          tensor_spec.TensorSpec([], dtypes.string)))
823
824    int64_t = constant_op.constant(37, dtype=dtypes.int64)
825    float32_t = constant_op.constant(42.0)
826    string_t = constant_op.constant("Foo")
827
828    nested_tensors = (int64_t, (float32_t, string_t))
829
830    tensor_list = structure.to_tensor_list(s, nested_tensors)
831    for expected, actual in zip([int64_t, float32_t, string_t], tensor_list):
832      self.assertIs(expected, actual)
833
834    (actual_int64_t,
835     (actual_float32_t,
836      actual_string_t)) = structure.from_tensor_list(s, tensor_list)
837    self.assertIs(int64_t, actual_int64_t)
838    self.assertIs(float32_t, actual_float32_t)
839    self.assertIs(string_t, actual_string_t)
840
841    (actual_int64_t, (actual_float32_t, actual_string_t)) = (
842        structure.from_compatible_tensor_list(s, tensor_list))
843    self.assertIs(int64_t, actual_int64_t)
844    self.assertIs(float32_t, actual_float32_t)
845    self.assertIs(string_t, actual_string_t)
846
847  @combinations.generate(
848      combinations.times(test_base.default_test_combinations(),
849                         _test_batch_combinations()))
850  def testBatch(self, element_structure, batch_size,
851                expected_batched_structure):
852    batched_structure = nest.map_structure(
853        lambda component_spec: component_spec._batch(batch_size),
854        element_structure)
855    self.assertEqual(batched_structure, expected_batched_structure)
856
857  @combinations.generate(
858      combinations.times(test_base.default_test_combinations(),
859                         _test_unbatch_combinations()))
860  def testUnbatch(self, element_structure, expected_unbatched_structure):
861    unbatched_structure = nest.map_structure(
862        lambda component_spec: component_spec._unbatch(), element_structure)
863    self.assertEqual(unbatched_structure, expected_unbatched_structure)
864
865  # pylint: disable=g-long-lambda
866  @combinations.generate(
867      combinations.times(test_base.default_test_combinations(),
868                         _test_to_batched_tensor_list_combinations()))
869  def testToBatchedTensorList(self, value_fn, element_0_fn):
870    batched_value = value_fn()
871    s = structure.type_spec_from_value(batched_value)
872    batched_tensor_list = structure.to_batched_tensor_list(s, batched_value)
873
874    # The batch dimension is 2 for all of the test cases.
875    # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
876    # tensors in which we store sparse tensors.
877    for t in batched_tensor_list:
878      if t.dtype != dtypes.variant:
879        self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))
880
881    # Test that the 0th element from the unbatched tensor is equal to the
882    # expected value.
883    expected_element_0 = self.evaluate(element_0_fn())
884    unbatched_s = nest.map_structure(
885        lambda component_spec: component_spec._unbatch(), s)
886    actual_element_0 = structure.from_tensor_list(
887        unbatched_s, [t[0] for t in batched_tensor_list])
888
889    for expected, actual in zip(
890        nest.flatten(expected_element_0), nest.flatten(actual_element_0)):
891      self.assertValuesEqual(expected, actual)
892
893  # pylint: enable=g-long-lambda
894
895  @combinations.generate(test_base.default_test_combinations())
896  def testDatasetSpecConstructor(self):
897    rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32)
898    st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
899    t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string)
900    element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec}
901    ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
902    self.assertEqual(ds_struct._element_spec, element_spec)
903    # Note: shape was automatically converted from a list to a TensorShape.
904    self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
905
906  @combinations.generate(test_base.default_test_combinations())
907  def testCustomMapping(self):
908    elem = CustomMap(foo=constant_op.constant(37.))
909    spec = structure.type_spec_from_value(elem)
910    self.assertIsInstance(spec, CustomMap)
911    self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32))
912
913  @combinations.generate(test_base.default_test_combinations())
914  def testObjectProxy(self):
915    nt_type = collections.namedtuple("A", ["x", "y"])
916    proxied = wrapt.ObjectProxy(nt_type(1, 2))
917    proxied_spec = structure.type_spec_from_value(proxied)
918    self.assertEqual(structure.type_spec_from_value(nt_type(1, 2)),
919                     proxied_spec)
920
921
class CustomMap(collections_abc.Mapping):
  """Custom map exposing only the read-only `Mapping` interface.

  Entries are stored in the instance's attribute dictionary, so lookup,
  iteration and length are all answered from `vars(self)`.
  """

  def __init__(self, *args, **kwargs):
    # Arguments are interpreted exactly as the builtin `dict` constructor would.
    entries = dict(*args, **kwargs)
    vars(self).update(entries)

  def __getitem__(self, x):
    return vars(self)[x]

  def __iter__(self):
    return iter(vars(self))

  def __len__(self):
    return len(vars(self))
936
937
# Run the tests in this module only when the file is executed as a script.
if __name__ == "__main__":
  test.main()
940