• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for utilities working with arbitrarily nested structures."""
16
17import collections
18import functools
19
20import numpy as np
21import wrapt
22from absl.testing import parameterized
23
24from tensorflow.python.data.kernel_tests import test_base
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.data.util import nest
27from tensorflow.python.data.util import structure
28from tensorflow.python.framework import combinations
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import sparse_tensor
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import tensor_array_ops
37from tensorflow.python.ops import variables
38from tensorflow.python.ops.ragged import ragged_factory_ops
39from tensorflow.python.ops.ragged import ragged_tensor
40from tensorflow.python.ops.ragged import ragged_tensor_value
41from tensorflow.python.platform import test
42from tensorflow.python.util.compat import collections_abc
43
44# NOTE(mrry): Arguments of parameterized tests are lifted into lambdas to make
45# sure they are not executed before the (eager- or graph-mode) test environment
46# has been set up.
47#
48
49
def _test_flat_structure_combinations():
  """Builds test combinations for `testFlatStructure`.

  Each case is a `(name, value_fn, expected_structure_fn, expected_types_fn,
  expected_shapes_fn)` tuple. Every callable is wrapped in a
  `combinations.NamedObject` labelled with its argument name and the case name,
  and is only invoked inside the test body (see the NOTE above about lifting
  arguments into lambdas).

  Returns:
    A list of combinations, one per case, in case order.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
          lambda: tensor_spec.TensorSpec,
          lambda: [dtypes.float32],
          lambda: [[]],
      ),
      (
          "TensorArray",
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArraySpec,
          lambda: [dtypes.variant],
          lambda: [[]],
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensorSpec,
          lambda: [dtypes.variant],
          lambda: [None],
      ),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
          lambda: ragged_tensor.RaggedTensorSpec,
          lambda: [dtypes.variant],
          lambda: [None],
      ),
      (
          "Nested_0",
          lambda: (constant_op.constant(37.0),
                   constant_op.constant([1, 2, 3])),
          lambda: tuple,
          lambda: [dtypes.float32, dtypes.int32],
          lambda: [[], [3]],
      ),
      (
          "Nested_1",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3]),
          },
          lambda: dict,
          lambda: [dtypes.float32, dtypes.int32],
          lambda: [[], [3]],
      ),
      (
          "Nested_2",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": (
                  sparse_tensor.SparseTensor(
                      indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                  sparse_tensor.SparseTensor(
                      indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
              ),
          },
          lambda: dict,
          lambda: [dtypes.float32, dtypes.variant, dtypes.variant],
          lambda: [[], None, None],
      ),
  ]

  combos = []
  for (name, value_fn, expected_structure_fn, expected_types_fn,
       expected_shapes_fn) in cases:
    combos.extend(
        combinations.combine(
            value_fn=combinations.NamedObject(f"value_fn.{name}", value_fn),
            expected_structure_fn=combinations.NamedObject(
                f"expected_structure_fn.{name}", expected_structure_fn),
            expected_types_fn=combinations.NamedObject(
                f"expected_types_fn.{name}", expected_types_fn),
            expected_shapes_fn=combinations.NamedObject(
                f"expected_shapes_fn.{name}", expected_shapes_fn)))
  return combos
99
100
def _test_is_compatible_with_structure_combinations():
  """Builds test combinations for `testIsCompatibleWithStructure`.

  Each case is `(name, original_value_fn, compatible_values_fn,
  incompatible_values_fn)`. The last two callables return lists of values whose
  specs should, respectively, be and not be compatible with the spec of the
  original value. Some cases build placeholders, so these combinations are
  paired with graph-only test modes.

  Returns:
    A list of combinations, one per case, with each callable wrapped in a
    `combinations.NamedObject` labelled by argument name and case name.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
          lambda: [
              constant_op.constant(38.0),
              array_ops.placeholder(dtypes.float32),
              42.0,
              np.array(42.0, dtype=np.float32),
          ],
          lambda: [
              constant_op.constant([1.0, 2.0]),
              constant_op.constant(37),
          ],
      ),
      # TODO(b/209081027): add Python constant and TF constant to the
      # incompatible branch after ResourceVariable becoming a CompositeTensor.
      (
          "Variable",
          lambda: variables.Variable(100.0),
          lambda: [variables.Variable(99.0)],
          lambda: [1],
      ),
      (
          "TensorArray",
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=10),
          ],
          lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.int32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(), size=0),
          ],
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
          lambda: [
              sparse_tensor.SparseTensor(
                  indices=[[1, 1], [3, 4]], values=[10, -1],
                  dense_shape=[4, 5]),
              sparse_tensor.SparseTensorValue(
                  indices=[[1, 1], [3, 4]], values=[10, -1],
                  dense_shape=[4, 5]),
              array_ops.sparse_placeholder(dtype=dtypes.int32),
              array_ops.sparse_placeholder(
                  dtype=dtypes.int32, shape=[None, None]),
          ],
          lambda: [
              constant_op.constant(37, shape=[4, 5]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
              array_ops.sparse_placeholder(
                  dtype=dtypes.int32, shape=[None, None, None]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5]),
          ],
      ),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
          lambda: [
              ragged_factory_ops.constant([[1, 2], [3, 4], []]),
              ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
          ],
          lambda: [
              ragged_factory_ops.constant(1),
              ragged_factory_ops.constant([1, 2]),
              ragged_factory_ops.constant([[1], [2]]),
              ragged_factory_ops.constant([["a", "b"]]),
          ],
      ),
      (
          "Nested",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3]),
          },
          lambda: [
              {
                  "a": constant_op.constant(15.0),
                  "b": constant_op.constant([4, 5, 6]),
              },
          ],
          lambda: [
              {
                  "a": constant_op.constant(15.0),
                  "b": constant_op.constant([4, 5, 6, 7]),
              },
              {
                  "a": constant_op.constant(15),
                  "b": constant_op.constant([4, 5, 6]),
              },
              {
                  "a": constant_op.constant(15),
                  "b": sparse_tensor.SparseTensor(
                      indices=[[0], [1], [2]], values=[4, 5, 6],
                      dense_shape=[3]),
              },
              (constant_op.constant(15.0), constant_op.constant([4, 5, 6])),
          ],
      ),
  ]

  def _named(arg_name, case_name, fn):
    # Labels each callable so generated test names identify the case.
    return combinations.NamedObject("{}.{}".format(arg_name, case_name), fn)

  return [
      combo
      for name, original_fn, compatible_fn, incompatible_fn in cases
      for combo in combinations.combine(
          original_value_fn=_named("original_value_fn", name, original_fn),
          compatible_values_fn=_named("compatible_values_fn", name,
                                      compatible_fn),
          incompatible_values_fn=_named("incompatible_values_fn", name,
                                        incompatible_fn))
  ]
186
187
def _test_structure_from_value_equality_combinations():
  """Builds test combinations for `testStructureFromValueEquality`.

  Each case is `(name, value1_fn, value2_fn, *not_equal_value_fns)`. The first
  two callables should produce values with equal specs. Every remaining
  callable produces a value whose spec should differ from both. The trailing
  callables are gathered into one list and wrapped as a single `NamedObject`.

  Returns:
    A list of combinations, one per case, in case order.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
          lambda: constant_op.constant(42.0),
          lambda: constant_op.constant([5]),
      ),
      (
          "TensorArray",
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.int32, element_shape=(), size=0),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[3]], values=[-1], dense_shape=[5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[1.0], dense_shape=[4, 5]),
      ),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
          lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]),
          lambda: ragged_factory_ops.constant(
              [[[1]], [[2], [3]]], ragged_rank=1),
          lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
          lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]]),
      ),
      (
          "Nested",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3]),
          },
          lambda: {
              "a": constant_op.constant(42.0),
              "b": constant_op.constant([4, 5, 6]),
          },
          lambda: {
              "a": constant_op.constant([1, 2, 3]),
              "b": constant_op.constant(37.0),
          },
      ),
  ]

  combos = []
  for name, value1_fn, value2_fn, *not_equal_value_fns in cases:
    combos.extend(
        combinations.combine(
            value1_fn=combinations.NamedObject(f"value1_fn.{name}", value1_fn),
            value2_fn=combinations.NamedObject(f"value2_fn.{name}", value2_fn),
            not_equal_value_fns=combinations.NamedObject(
                f"not_equal_value_fns.{name}", not_equal_value_fns)))
  return combos
233
234
def _test_ragged_structure_inequality_combinations():
  """Builds `(spec1, spec2)` combinations of ragged specs that must differ.

  Each pair differs in exactly one constructor argument: ragged rank, static
  shape, or dtype.

  Returns:
    A list of combinations with `spec1` and `spec2` keyword arguments.
  """
  spec_pairs = [
      # Same shape and dtype, different ragged_rank.
      (ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 2)),
      # Same dtype and ragged_rank, different static outer dimension.
      (ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec([5, None], dtypes.int32, 1)),
      # Same shape and ragged_rank, different dtype.
      (ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec(None, dtypes.float32, 1)),
  ]
  return [
      combo for left, right in spec_pairs
      for combo in combinations.combine(spec1=left, spec2=right)
  ]
250
251
def _test_hash_combinations():
  """Builds test combinations for `testHash`.

  Each case is `(name, value1_fn, value2_fn, value3_fn)`. Specs built from the
  first two values should hash alike component-wise; the spec from the third
  should hash differently.

  Returns:
    A list of combinations with `value1_fn`, `value2_fn` and `value3_fn`
    keyword arguments, each wrapped in a `combinations.NamedObject`.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
          lambda: constant_op.constant(42.0),
          lambda: constant_op.constant([5]),
      ),
      (
          "TensorArray",
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.float32, element_shape=(3,), size=0),
          lambda: tensor_array_ops.TensorArray(
              dtype=dtypes.int32, element_shape=(), size=0),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[3]], values=[-1], dense_shape=[5]),
      ),
      (
          "Nested",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3]),
          },
          lambda: {
              "a": constant_op.constant(42.0),
              "b": constant_op.constant([4, 5, 6]),
          },
          lambda: {
              "a": constant_op.constant([1, 2, 3]),
              "b": constant_op.constant(37.0),
          },
      ),
  ]

  def named(arg_name, case_name, fn):
    return combinations.NamedObject("{}.{}".format(arg_name, case_name), fn)

  combos = []
  for name, value1_fn, value2_fn, value3_fn in cases:
    combos += combinations.combine(
        value1_fn=named("value1_fn", name, value1_fn),
        value2_fn=named("value2_fn", name, value2_fn),
        value3_fn=named("value3_fn", name, value3_fn))
  return combos
290
291
def _test_round_trip_conversion_combinations():
  """Builds `value_fn` combinations for `testRoundTripConversion`.

  Each case is `(name, value_fn)`. The callable is wrapped in a
  `combinations.NamedObject` labelled `value_fn.<name>`.

  Returns:
    A list of combinations, one per case, in case order.
  """
  cases = [
      ("Tensor", lambda: constant_op.constant(37.0)),
      ("SparseTensor",
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5])),
      ("TensorArray",
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]])),
      ("Nested_0",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": constant_op.constant([1, 2, 3]),
       }),
      ("Nested_1",
       lambda: {
           "a": constant_op.constant(37.0),
           "b": (
               sparse_tensor.SparseTensor(
                   indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
               sparse_tensor.SparseTensor(
                   indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
           ),
       }),
  ]
  return [
      combo for name, value_fn in cases
      for combo in combinations.combine(
          value_fn=combinations.NamedObject(f"value_fn.{name}", value_fn))
  ]
335
336
def _test_convert_legacy_structure_combinations():
  """Builds combinations for converting legacy types/shapes/classes to specs.

  Each case is `(output_types, output_shapes, output_classes,
  expected_structure)`. In the TensorArray cases, `output_shapes` carries two
  extra leading entries ahead of the element shape `[2, 2]`. Judging from the
  expected specs, entry 0 maps to `dynamic_size` and entry 1 to `infer_shape`
  (NOTE(review): confirm against the legacy converter). The ragged case passes
  a `RaggedTensorSpec` directly as the output class.

  Returns:
    A list of combinations, one per case, in case order.
  """
  cases = [
      (
          dtypes.float32,
          tensor_shape.TensorShape([]),
          ops.Tensor,
          tensor_spec.TensorSpec([], dtypes.float32),
      ),
      (
          dtypes.int32,
          tensor_shape.TensorShape([2, 2]),
          sparse_tensor.SparseTensor,
          sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
      ),
      (
          dtypes.int32,
          tensor_shape.TensorShape([None, True, 2, 2]),
          tensor_array_ops.TensorArray,
          tensor_array_ops.TensorArraySpec(
              [2, 2], dtypes.int32, dynamic_size=None, infer_shape=True),
      ),
      (
          dtypes.int32,
          tensor_shape.TensorShape([True, None, 2, 2]),
          tensor_array_ops.TensorArray,
          tensor_array_ops.TensorArraySpec(
              [2, 2], dtypes.int32, dynamic_size=True, infer_shape=None),
      ),
      (
          dtypes.int32,
          tensor_shape.TensorShape([True, False, 2, 2]),
          tensor_array_ops.TensorArray,
          tensor_array_ops.TensorArraySpec(
              [2, 2], dtypes.int32, dynamic_size=True, infer_shape=False),
      ),
      (
          dtypes.int32,
          tensor_shape.TensorShape([2, None]),
          ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
          ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
      ),
      (
          {"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)},
          {
              "a": tensor_shape.TensorShape([]),
              "b": (tensor_shape.TensorShape([2, 2]),
                    tensor_shape.TensorShape([])),
          },
          {"a": ops.Tensor, "b": (sparse_tensor.SparseTensor, ops.Tensor)},
          {
              "a": tensor_spec.TensorSpec([], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([], dtypes.string)),
          },
      ),
  ]
  return [
      combo for types, shapes, classes, expected in cases
      for combo in combinations.combine(
          output_types=types,
          output_shapes=shapes,
          output_classes=classes,
          expected_structure=expected)
  ]
391
392
def _test_batch_combinations():
  """Builds combinations for batching an element spec.

  Each case is `(element_structure, batch_size, expected_batched_structure)`.
  The expected spec gains a leading dimension equal to `batch_size`, which is
  unknown (`None`) when `batch_size` is `None`. Ragged specs additionally gain
  one ragged rank.

  Returns:
    A list of combinations, one per case, in case order.
  """
  cases = [
      (
          tensor_spec.TensorSpec([], dtypes.float32),
          32,
          tensor_spec.TensorSpec([32], dtypes.float32),
      ),
      (
          tensor_spec.TensorSpec([], dtypes.float32),
          None,
          tensor_spec.TensorSpec([None], dtypes.float32),
      ),
      (
          sparse_tensor.SparseTensorSpec([None], dtypes.float32),
          32,
          sparse_tensor.SparseTensorSpec([32, None], dtypes.float32),
      ),
      (
          sparse_tensor.SparseTensorSpec([4], dtypes.float32),
          None,
          sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
      ),
      (
          ragged_tensor.RaggedTensorSpec([2, None], dtypes.float32, 1),
          32,
          ragged_tensor.RaggedTensorSpec([32, 2, None], dtypes.float32, 2),
      ),
      (
          ragged_tensor.RaggedTensorSpec([4, None], dtypes.float32, 1),
          None,
          ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2),
      ),
      (
          {
              "a": tensor_spec.TensorSpec([], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([], dtypes.string)),
          },
          128,
          {
              "a": tensor_spec.TensorSpec([128], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([128], dtypes.string)),
          },
      ),
  ]

  result = []
  for element_structure, batch_size, expected_batched_structure in cases:
    result.extend(
        combinations.combine(
            element_structure=element_structure,
            batch_size=batch_size,
            expected_batched_structure=expected_batched_structure))
  return result
428
429
def _test_unbatch_combinations():
  """Builds combinations for unbatching a spec.

  Each case is `(element_structure, expected_unbatched_structure)`. The
  expected spec drops the leading dimension, whether it is static or `None`.
  Ragged specs also lose one ragged rank.

  Returns:
    A list of combinations, one per case, in case order.
  """
  cases = [
      (
          tensor_spec.TensorSpec([32], dtypes.float32),
          tensor_spec.TensorSpec([], dtypes.float32),
      ),
      (
          tensor_spec.TensorSpec([None], dtypes.float32),
          tensor_spec.TensorSpec([], dtypes.float32),
      ),
      (
          sparse_tensor.SparseTensorSpec([32, None], dtypes.float32),
          sparse_tensor.SparseTensorSpec([None], dtypes.float32),
      ),
      (
          sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
          sparse_tensor.SparseTensorSpec([4], dtypes.float32),
      ),
      (
          ragged_tensor.RaggedTensorSpec([32, None, None], dtypes.float32, 2),
          ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1),
      ),
      (
          ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.float32,
                                         2),
          ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1),
      ),
      (
          {
              "a": tensor_spec.TensorSpec([128], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([None], dtypes.string)),
          },
          {
              "a": tensor_spec.TensorSpec([], dtypes.float32),
              "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                    tensor_spec.TensorSpec([], dtypes.string)),
          },
      ),
  ]
  return [
      combo for element_structure, expected_unbatched_structure in cases
      for combo in combinations.combine(
          element_structure=element_structure,
          expected_unbatched_structure=expected_unbatched_structure)
  ]
464
465
def _test_to_batched_tensor_list_combinations():
  """Builds combinations pairing a batched value with its first element.

  Each case is `(name, value_fn, element_0_fn)`. In every case the value from
  `element_0_fn` is the value from `value_fn` indexed at position 0 along the
  leading dimension. Both callables are wrapped in `combinations.NamedObject`s
  labelled with the case name.

  Returns:
    A list of combinations, one per case, in case order.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
          lambda: constant_op.constant([1.0, 2.0]),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
          lambda: sparse_tensor.SparseTensor(
              indices=[[0]], values=[13], dense_shape=[2]),
      ),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
          lambda: ragged_factory_ops.constant([[1]]),
      ),
      (
          "Nest",
          lambda: (
              constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
              sparse_tensor.SparseTensor(
                  indices=[[0, 0], [1, 1]], values=[13, 27],
                  dense_shape=[2, 2]),
          ),
          lambda: (
              constant_op.constant([1.0, 2.0]),
              sparse_tensor.SparseTensor(
                  indices=[[0]], values=[13], dense_shape=[2]),
          ),
      ),
  ]

  combos = []
  for name, value_fn, element_0_fn in cases:
    combos += combinations.combine(
        value_fn=combinations.NamedObject("value_fn." + name, value_fn),
        element_0_fn=combinations.NamedObject("element_0_fn." + name,
                                              element_0_fn))
  return combos
494
495
496# TODO(jsimsa): Add tests for OptionalStructure and DatasetStructure.
497class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
498
499  # pylint: disable=g-long-lambda,protected-access
500  @combinations.generate(
501      combinations.times(test_base.default_test_combinations(),
502                         _test_flat_structure_combinations()))
503  def testFlatStructure(self, value_fn, expected_structure_fn,
504                        expected_types_fn, expected_shapes_fn):
505    value = value_fn()
506    expected_structure = expected_structure_fn()
507    expected_types = expected_types_fn()
508    expected_shapes = expected_shapes_fn()
509    s = structure.type_spec_from_value(value)
510    self.assertIsInstance(s, expected_structure)
511    flat_types = structure.get_flat_tensor_types(s)
512    self.assertEqual(expected_types, flat_types)
513    flat_shapes = structure.get_flat_tensor_shapes(s)
514    self.assertLen(flat_shapes, len(expected_shapes))
515    for expected, actual in zip(expected_shapes, flat_shapes):
516      if expected is None:
517        self.assertEqual(actual.ndims, None)
518      else:
519        self.assertEqual(actual.as_list(), expected)
520
521  @combinations.generate(
522      combinations.times(test_base.graph_only_combinations(),
523                         _test_is_compatible_with_structure_combinations()))
524  def testIsCompatibleWithStructure(self, original_value_fn,
525                                    compatible_values_fn,
526                                    incompatible_values_fn):
527    original_value = original_value_fn()
528    compatible_values = compatible_values_fn()
529    incompatible_values = incompatible_values_fn()
530
531    s = structure.type_spec_from_value(original_value)
532    for compatible_value in compatible_values:
533      self.assertTrue(
534          structure.are_compatible(
535              s, structure.type_spec_from_value(compatible_value)))
536    for incompatible_value in incompatible_values:
537      self.assertFalse(
538          structure.are_compatible(
539              s, structure.type_spec_from_value(incompatible_value)))
540
541  @combinations.generate(
542      combinations.times(test_base.default_test_combinations(),
543                         _test_structure_from_value_equality_combinations()))
544  def testStructureFromValueEquality(self, value1_fn, value2_fn,
545                                     not_equal_value_fns):
546    # pylint: disable=g-generic-assert
547    not_equal_value_fns = not_equal_value_fns._obj
548    s1 = structure.type_spec_from_value(value1_fn())
549    s2 = structure.type_spec_from_value(value2_fn())
550    self.assertEqual(s1, s1)  # check __eq__ operator.
551    self.assertEqual(s1, s2)  # check __eq__ operator.
552    self.assertFalse(s1 != s1)  # check __ne__ operator.
553    self.assertFalse(s1 != s2)  # check __ne__ operator.
554    for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)):
555      self.assertEqual(hash(c1), hash(c1))
556      self.assertEqual(hash(c1), hash(c2))
557    for value_fn in not_equal_value_fns:
558      s3 = structure.type_spec_from_value(value_fn())
559      self.assertNotEqual(s1, s3)  # check __ne__ operator.
560      self.assertNotEqual(s2, s3)  # check __ne__ operator.
561      self.assertFalse(s1 == s3)  # check __eq_ operator.
562      self.assertFalse(s2 == s3)  # check __eq_ operator.
563
564  @combinations.generate(
565      combinations.times(test_base.default_test_combinations(),
566                         _test_ragged_structure_inequality_combinations()))
567  def testRaggedStructureInequality(self, spec1, spec2):
568    # pylint: disable=g-generic-assert
569    self.assertNotEqual(spec1, spec2)  # check __ne__ operator.
570    self.assertFalse(spec1 == spec2)  # check __eq__ operator.
571
572  @combinations.generate(
573      combinations.times(test_base.default_test_combinations(),
574                         _test_hash_combinations()))
575  def testHash(self, value1_fn, value2_fn, value3_fn):
576    s1 = structure.type_spec_from_value(value1_fn())
577    s2 = structure.type_spec_from_value(value2_fn())
578    s3 = structure.type_spec_from_value(value3_fn())
579    for c1, c2, c3 in zip(nest.flatten(s1), nest.flatten(s2), nest.flatten(s3)):
580      self.assertEqual(hash(c1), hash(c1))
581      self.assertEqual(hash(c1), hash(c2))
582      self.assertNotEqual(hash(c1), hash(c3))
583      self.assertNotEqual(hash(c2), hash(c3))
584
585  @combinations.generate(
586      combinations.times(test_base.default_test_combinations(),
587                         _test_round_trip_conversion_combinations()))
588  def testRoundTripConversion(self, value_fn):
589    value = value_fn()
590    s = structure.type_spec_from_value(value)
591
592    def maybe_stack_ta(v):
593      if isinstance(v, tensor_array_ops.TensorArray):
594        return v.stack()
595      return v
596
597    before = self.evaluate(maybe_stack_ta(value))
598    after = self.evaluate(
599        maybe_stack_ta(
600            structure.from_tensor_list(s, structure.to_tensor_list(s, value))))
601
602    flat_before = nest.flatten(before)
603    flat_after = nest.flatten(after)
604    for b, a in zip(flat_before, flat_after):
605      if isinstance(b, sparse_tensor.SparseTensorValue):
606        self.assertAllEqual(b.indices, a.indices)
607        self.assertAllEqual(b.values, a.values)
608        self.assertAllEqual(b.dense_shape, a.dense_shape)
609      elif isinstance(
610          b,
611          (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
612        self.assertAllEqual(b, a)
613      else:
614        self.assertAllEqual(b, a)
615
616  # pylint: enable=g-long-lambda
617
618  def preserveStaticShape(self):
619    rt = ragged_factory_ops.constant([[1, 2], [], [3]])
620    rt_s = structure.type_spec_from_value(rt)
621    rt_after = structure.from_tensor_list(rt_s,
622                                          structure.to_tensor_list(rt_s, rt))
623    self.assertEqual(rt_after.row_splits.shape.as_list(),
624                     rt.row_splits.shape.as_list())
625    self.assertEqual(rt_after.values.shape.as_list(), [None])
626
627    st = sparse_tensor.SparseTensor(
628        indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
629    st_s = structure.type_spec_from_value(st)
630    st_after = structure.from_tensor_list(st_s,
631                                          structure.to_tensor_list(st_s, st))
632    self.assertEqual(st_after.indices.shape.as_list(), [None, 2])
633    self.assertEqual(st_after.values.shape.as_list(), [None])
634    self.assertEqual(st_after.dense_shape.shape.as_list(),
635                     st.dense_shape.shape.as_list())
636
637  @combinations.generate(test_base.default_test_combinations())
638  def testPreserveTensorArrayShape(self):
639    ta = tensor_array_ops.TensorArray(
640        dtype=dtypes.int32, size=1, element_shape=(3,))
641    ta_s = structure.type_spec_from_value(ta)
642    ta_after = structure.from_tensor_list(ta_s,
643                                          structure.to_tensor_list(ta_s, ta))
644    self.assertEqual(ta_after.element_shape.as_list(), [3])
645
646  @combinations.generate(test_base.default_test_combinations())
647  def testPreserveInferredTensorArrayShape(self):
648    ta = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=1)
649    # Shape is inferred from the write.
650    ta = ta.write(0, [1, 2, 3])
651    ta_s = structure.type_spec_from_value(ta)
652    ta_after = structure.from_tensor_list(ta_s,
653                                          structure.to_tensor_list(ta_s, ta))
654    self.assertEqual(ta_after.element_shape.as_list(), [3])
655
  @combinations.generate(test_base.default_test_combinations())
  def testIncompatibleStructure(self):
    """Checks that mismatched specs and values are rejected with clear errors.

    Covers both conversion directions:
    - flattening a value with the spec of a different structure
      (`structure.to_tensor_list`);
    - rebuilding a value from a flat tensor list produced for a different
      structure (`structure.from_tensor_list`).
    """
    # Define three mutually incompatible values/structures, and assert that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.
    # Scalar float32 tensor; its flat list holds a single tensor.
    value_tensor = constant_op.constant(42.0)
    s_tensor = structure.type_spec_from_value(value_tensor)
    flat_tensor = structure.to_tensor_list(s_tensor, value_tensor)

    # SparseTensor; its flat list also holds a single tensor.
    value_sparse_tensor = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor)
    flat_sparse_tensor = structure.to_tensor_list(s_sparse_tensor,
                                                  value_sparse_tensor)

    # Two-entry dict; its flat list holds two tensors.
    value_nest = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_nest = structure.type_spec_from_value(value_nest)
    flat_nest = structure.to_tensor_list(s_nest, value_nest)

    # --- to_tensor_list: value does not match the spec. ---
    with self.assertRaisesRegex(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      structure.to_tensor_list(s_tensor, value_sparse_tensor)
    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_tensor, value_nest)

    # A sparse spec rejects a dense value with a TypeError, not a ValueError.
    with self.assertRaisesRegex(TypeError,
                                "neither a SparseTensor nor SparseTensorValue"):
      structure.to_tensor_list(s_sparse_tensor, value_tensor)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_sparse_tensor, value_nest)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_nest, value_tensor)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_nest, value_sparse_tensor)

    # --- from_tensor_list: flat list was produced for a different spec. ---
    # Same length (1) but the wrong kind of tensor: reported per item.
    with self.assertRaisesRegex(
        ValueError,
        "Cannot create a Tensor from the tensor list because item 0 "
        ".*tf.Tensor.* is incompatible with the expected TypeSpec "
        ".*TensorSpec.*"):
      structure.from_tensor_list(s_tensor, flat_sparse_tensor)

    # Length mismatches are reported as "Expected N tensors but got M.".
    with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
      structure.from_tensor_list(s_tensor, flat_nest)

    with self.assertRaisesRegex(
        ValueError, "Cannot create a SparseTensor from the tensor list because "
        "item 0 .*tf.Tensor.* is incompatible with the expected TypeSpec "
        ".*TensorSpec.*"):
      structure.from_tensor_list(s_sparse_tensor, flat_tensor)

    with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
      structure.from_tensor_list(s_sparse_tensor, flat_nest)

    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
      structure.from_tensor_list(s_nest, flat_tensor)

    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
      structure.from_tensor_list(s_nest, flat_sparse_tensor)
728
729  @combinations.generate(test_base.default_test_combinations())
730  def testIncompatibleNestedStructure(self):
731    # Define three mutually incompatible nested values/structures, and assert
732    # that:
733    # 1. Using one structure to flatten a value with an incompatible structure
734    #    fails.
735    # 2. Using one structure to restructure a flattened value with an
736    #    incompatible structure fails.
737
738    value_0 = {
739        "a": constant_op.constant(37.0),
740        "b": constant_op.constant([1, 2, 3])
741    }
742    s_0 = structure.type_spec_from_value(value_0)
743    flat_s_0 = structure.to_tensor_list(s_0, value_0)
744
745    # `value_1` has compatible nested structure with `value_0`, but different
746    # classes.
747    value_1 = {
748        "a":
749            constant_op.constant(37.0),
750        "b":
751            sparse_tensor.SparseTensor(
752                indices=[[0, 0]], values=[1], dense_shape=[1, 1])
753    }
754    s_1 = structure.type_spec_from_value(value_1)
755    flat_s_1 = structure.to_tensor_list(s_1, value_1)
756
757    # `value_2` has incompatible nested structure with `value_0` and `value_1`.
758    value_2 = {
759        "a":
760            constant_op.constant(37.0),
761        "b": (sparse_tensor.SparseTensor(
762            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
763              sparse_tensor.SparseTensor(
764                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
765    }
766    s_2 = structure.type_spec_from_value(value_2)
767    flat_s_2 = structure.to_tensor_list(s_2, value_2)
768
769    with self.assertRaisesRegex(
770        ValueError, r"SparseTensor.* is not convertible to a tensor with "
771        r"dtype.*int32.* and shape \(3,\)"):
772      structure.to_tensor_list(s_0, value_1)
773
774    with self.assertRaisesRegex(
775        ValueError, "The two structures don't have the same nested structure."):
776      structure.to_tensor_list(s_0, value_2)
777
778    with self.assertRaisesRegex(TypeError,
779                                "neither a SparseTensor nor SparseTensorValue"):
780      structure.to_tensor_list(s_1, value_0)
781
782    with self.assertRaisesRegex(
783        ValueError, "The two structures don't have the same nested structure."):
784      structure.to_tensor_list(s_1, value_2)
785
786    # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
787    # needs to account for "a" coming before or after "b". It might be worth
788    # adding a deterministic repr for these error messages (among other
789    # improvements).
790    with self.assertRaisesRegex(
791        ValueError, "The two structures don't have the same nested structure."):
792      structure.to_tensor_list(s_2, value_0)
793
794    with self.assertRaisesRegex(
795        ValueError, "The two structures don't have the same nested structure."):
796      structure.to_tensor_list(s_2, value_1)
797
798    with self.assertRaisesRegex(ValueError,
799                                r"Cannot create a Tensor from the tensor list"):
800      structure.from_tensor_list(s_0, flat_s_1)
801
802    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3"):
803      structure.from_tensor_list(s_0, flat_s_2)
804
805    with self.assertRaisesRegex(
806        ValueError, "Cannot create a SparseTensor from the tensor list"):
807      structure.from_tensor_list(s_1, flat_s_0)
808
809    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3"):
810      structure.from_tensor_list(s_1, flat_s_2)
811
812    with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2"):
813      structure.from_tensor_list(s_2, flat_s_0)
814
815    with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2"):
816      structure.from_tensor_list(s_2, flat_s_1)
817
818  @combinations.generate(
819      combinations.times(test_base.default_test_combinations(),
820                         _test_convert_legacy_structure_combinations()))
821  def testConvertLegacyStructure(self, output_types, output_shapes,
822                                 output_classes, expected_structure):
823    actual_structure = structure.convert_legacy_structure(
824        output_types, output_shapes, output_classes)
825    self.assertEqual(actual_structure, expected_structure)
826
827  @combinations.generate(test_base.default_test_combinations())
828  def testConvertLegacyStructureFail(self):
829    with self.assertRaisesRegex(
830        TypeError, "Could not build a structure for output class "
831        "_EagerTensorArray. Make sure any component class in "
832        "`output_classes` inherits from one of the following classes: "
833        "`tf.TypeSpec`, `tf.sparse.SparseTensor`, `tf.Tensor`, "
834        "`tf.TensorArray`."):
835      structure.convert_legacy_structure(dtypes.int32,
836                                         tensor_shape.TensorShape([2, None]),
837                                         tensor_array_ops._EagerTensorArray)
838
839  @combinations.generate(test_base.default_test_combinations())
840  def testNestedNestedStructure(self):
841    s = (tensor_spec.TensorSpec([], dtypes.int64),
842         (tensor_spec.TensorSpec([], dtypes.float32),
843          tensor_spec.TensorSpec([], dtypes.string)))
844
845    int64_t = constant_op.constant(37, dtype=dtypes.int64)
846    float32_t = constant_op.constant(42.0)
847    string_t = constant_op.constant("Foo")
848
849    nested_tensors = (int64_t, (float32_t, string_t))
850
851    tensor_list = structure.to_tensor_list(s, nested_tensors)
852    for expected, actual in zip([int64_t, float32_t, string_t], tensor_list):
853      self.assertIs(expected, actual)
854
855    (actual_int64_t,
856     (actual_float32_t,
857      actual_string_t)) = structure.from_tensor_list(s, tensor_list)
858    self.assertIs(int64_t, actual_int64_t)
859    self.assertIs(float32_t, actual_float32_t)
860    self.assertIs(string_t, actual_string_t)
861
862    (actual_int64_t, (actual_float32_t, actual_string_t)) = (
863        structure.from_compatible_tensor_list(s, tensor_list))
864    self.assertIs(int64_t, actual_int64_t)
865    self.assertIs(float32_t, actual_float32_t)
866    self.assertIs(string_t, actual_string_t)
867
868  @combinations.generate(
869      combinations.times(test_base.default_test_combinations(),
870                         _test_batch_combinations()))
871  def testBatch(self, element_structure, batch_size,
872                expected_batched_structure):
873    batched_structure = nest.map_structure(
874        lambda component_spec: component_spec._batch(batch_size),
875        element_structure)
876    self.assertEqual(batched_structure, expected_batched_structure)
877
878  @combinations.generate(
879      combinations.times(test_base.default_test_combinations(),
880                         _test_unbatch_combinations()))
881  def testUnbatch(self, element_structure, expected_unbatched_structure):
882    unbatched_structure = nest.map_structure(
883        lambda component_spec: component_spec._unbatch(), element_structure)
884    self.assertEqual(unbatched_structure, expected_unbatched_structure)
885
886  # pylint: disable=g-long-lambda
887  @combinations.generate(
888      combinations.times(test_base.default_test_combinations(),
889                         _test_to_batched_tensor_list_combinations()))
890  def testToBatchedTensorList(self, value_fn, element_0_fn):
891    batched_value = value_fn()
892    s = structure.type_spec_from_value(batched_value)
893    batched_tensor_list = structure.to_batched_tensor_list(s, batched_value)
894
895    # The batch dimension is 2 for all of the test cases.
896    # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
897    # tensors in which we store sparse tensors.
898    for t in batched_tensor_list:
899      if t.dtype != dtypes.variant:
900        self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))
901
902    # Test that the 0th element from the unbatched tensor is equal to the
903    # expected value.
904    expected_element_0 = self.evaluate(element_0_fn())
905    unbatched_s = nest.map_structure(
906        lambda component_spec: component_spec._unbatch(), s)
907    actual_element_0 = structure.from_tensor_list(
908        unbatched_s, [t[0] for t in batched_tensor_list])
909
910    for expected, actual in zip(
911        nest.flatten(expected_element_0), nest.flatten(actual_element_0)):
912      self.assertValuesEqual(expected, actual)
913
914  # pylint: enable=g-long-lambda
915
916  @combinations.generate(test_base.default_test_combinations())
917  def testDatasetSpecConstructor(self):
918    rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32)
919    st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
920    t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string)
921    element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec}
922    ds_struct = dataset_ops.DatasetSpec(element_spec, [5])
923    self.assertEqual(ds_struct._element_spec, element_spec)
924    # Note: shape was automatically converted from a list to a TensorShape.
925    self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5]))
926
927  @combinations.generate(test_base.default_test_combinations())
928  def testCustomMapping(self):
929    elem = CustomMap(foo=constant_op.constant(37.))
930    spec = structure.type_spec_from_value(elem)
931    self.assertIsInstance(spec, CustomMap)
932    self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32))
933
934  @combinations.generate(test_base.default_test_combinations())
935  def testObjectProxy(self):
936    nt_type = collections.namedtuple("A", ["x", "y"])
937    proxied = wrapt.ObjectProxy(nt_type(1, 2))
938    proxied_spec = structure.type_spec_from_value(proxied)
939    self.assertEqual(
940        structure.type_spec_from_value(nt_type(1, 2)), proxied_spec)
941
942  @combinations.generate(test_base.default_test_combinations())
943  def testTypeSpecNotBuild(self):
944    with self.assertRaisesRegex(
945        TypeError, "Could not build a `TypeSpec` for 100 with type int"):
946      structure.type_spec_from_value(100, use_fallback=False)
947
948  @combinations.generate(test_base.default_test_combinations())
949  def testTypeSpecNotCompatible(self):
950    test_obj = structure.NoneTensorSpec()
951    with self.assertRaisesRegex(
952        ValueError, r"No `TypeSpec` is compatible with both NoneTensorSpec\(\) "
953        "and 100"):
954      test_obj.most_specific_compatible_shape(100)
955    self.assertEqual(test_obj,
956                     test_obj.most_specific_compatible_shape(test_obj))
957
958
class CustomMap(collections_abc.Mapping):
  """Minimal read-only `Mapping` used to test custom mapping support.

  The constructor accepts the same arguments as `dict`. Entries are stored in
  the instance `__dict__`, so string keys can also be read as attributes.
  """

  def __init__(self, *args, **kwargs):
    entries = dict(*args, **kwargs)
    vars(self).update(entries)

  def __len__(self):
    return len(vars(self))

  def __iter__(self):
    return iter(vars(self))

  def __getitem__(self, key):
    return vars(self)[key]
973
974
if __name__ == "__main__":
  # Run this module's test cases with the TensorFlow test runner.
  test.main()
977