1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for utilities working with arbitrarily nested structures.""" 16 17import collections 18import functools 19 20import numpy as np 21import wrapt 22from absl.testing import parameterized 23 24from tensorflow.python.data.kernel_tests import test_base 25from tensorflow.python.data.ops import dataset_ops 26from tensorflow.python.data.util import nest 27from tensorflow.python.data.util import structure 28from tensorflow.python.framework import combinations 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import sparse_tensor 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.framework import tensor_spec 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import tensor_array_ops 37from tensorflow.python.ops import variables 38from tensorflow.python.ops.ragged import ragged_factory_ops 39from tensorflow.python.ops.ragged import ragged_tensor 40from tensorflow.python.ops.ragged import ragged_tensor_value 41from tensorflow.python.platform import test 42from tensorflow.python.util.compat import collections_abc 43 44# NOTE(mrry): Arguments of parameterized tests are lifted into lambdas to make 
# sure they are not executed before the (eager- or graph-mode) test environment
# has been set up.
#


def _test_flat_structure_combinations():
  """Builds the parameter combinations for `testFlatStructure`.

  Each case is `(name, value_fn, expected_structure_fn, expected_types_fn,
  expected_shapes_fn)`. Everything after the name is a zero-argument lambda so
  that nothing is constructed before the test environment has been set up.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case, with
    every lambda wrapped in a `combinations.NamedObject`.
  """
  cases = [
      ("Tensor", lambda: constant_op.constant(37.0),
       lambda: tensor_spec.TensorSpec, lambda: [dtypes.float32], lambda: [[]]),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArraySpec, lambda: [dtypes.variant],
       lambda: [[]]),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensorSpec, lambda: [dtypes.variant],
       lambda: [None]),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
       lambda: ragged_tensor.RaggedTensorSpec, lambda: [dtypes.variant],
       lambda: [None]),
      ("Nested_0",
       lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
       lambda: tuple, lambda: [dtypes.float32, dtypes.int32],
       lambda: [[], [3]]),
      ("Nested_1", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: dict, lambda: [dtypes.float32, dtypes.int32],
       lambda: [[], [3]]),
      ("Nested_2", lambda: {
          "a":
              constant_op.constant(37.0),
          "b": (sparse_tensor.SparseTensor(
              indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                sparse_tensor.SparseTensor(
                    indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
      }, lambda: dict, lambda: [dtypes.float32, dtypes.variant, dtypes.variant],
       lambda: [[], None, None]),
  ]

  def reduce_fn(x, y):
    # workaround for long line
    name, value_fn = y[:2]
    expected_structure_fn, expected_types_fn, expected_shapes_fn = y[2:]
    return x + combinations.combine(
        value_fn=combinations.NamedObject("value_fn.{}".format(name), value_fn),
        expected_structure_fn=combinations.NamedObject(
            "expected_structure_fn.{}".format(name), expected_structure_fn),
        expected_types_fn=combinations.NamedObject(
            "expected_types_fn.{}".format(name), expected_types_fn),
        expected_shapes_fn=combinations.NamedObject(
            "expected_shapes_fn.{}".format(name), expected_shapes_fn))

  return functools.reduce(reduce_fn, cases, [])


def _test_is_compatible_with_structure_combinations():
  """Builds the parameter combinations for `testIsCompatibleWithStructure`.

  Each case is `(name, original_value_fn, compatible_values_fn,
  incompatible_values_fn)`. The last two lambdas return lists of values.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case, with
    every lambda wrapped in a `combinations.NamedObject`.
  """
  cases = [
      ("Tensor", lambda: constant_op.constant(37.0), lambda: [
          constant_op.constant(38.0),
          array_ops.placeholder(dtypes.float32), 42.0,
          np.array(42.0, dtype=np.float32)
      ], lambda: [constant_op.constant([1.0, 2.0]),
                  constant_op.constant(37)]),
      # TODO(b/209081027): add Python constant and TF constant to the
      # incompatible branch after ResourceVariable becoming a CompositeTensor.
      ("Variable", lambda: variables.Variable(100.0),
       lambda: [variables.Variable(99.0)],
       lambda: [1]),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0), lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(3,), size=10)
          ], lambda: [
              tensor_array_ops.TensorArray(
                  dtype=dtypes.int32, element_shape=(3,), size=0),
              tensor_array_ops.TensorArray(
                  dtype=dtypes.float32, element_shape=(), size=0)
          ]),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: [
           sparse_tensor.SparseTensor(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           sparse_tensor.SparseTensorValue(
               indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]),
           array_ops.sparse_placeholder(dtype=dtypes.int32),
           array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None])
       ], lambda: [
           constant_op.constant(37, shape=[4, 5]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1], dense_shape=[5, 6]),
           array_ops.sparse_placeholder(
               dtype=dtypes.int32, shape=[None, None, None]),
           sparse_tensor.SparseTensor(
               indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5])
       ]),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
       lambda: [
           ragged_factory_ops.constant([[1, 2], [3, 4], []]),
           ragged_factory_ops.constant([[1], [2, 3, 4], [5]]),
       ], lambda: [
           ragged_factory_ops.constant(1),
           ragged_factory_ops.constant([1, 2]),
           ragged_factory_ops.constant([[1], [2]]),
           ragged_factory_ops.constant([["a", "b"]]),
       ]),
      ("Nested", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6])
      }], lambda: [{
          "a": constant_op.constant(15.0),
          "b": constant_op.constant([4, 5, 6, 7])
      }, {
          "a": constant_op.constant(15),
          "b": constant_op.constant([4, 5, 6])
      }, {
          "a":
              constant_op.constant(15),
          "b":
              sparse_tensor.SparseTensor(
                  indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3])
      }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]),
  ]

  def reduce_fn(x, y):
    name, original_value_fn, compatible_values_fn, incompatible_values_fn = y
    return x + combinations.combine(
        original_value_fn=combinations.NamedObject(
            "original_value_fn.{}".format(name), original_value_fn),
        compatible_values_fn=combinations.NamedObject(
            "compatible_values_fn.{}".format(name), compatible_values_fn),
        incompatible_values_fn=combinations.NamedObject(
            "incompatible_values_fn.{}".format(name), incompatible_values_fn))

  return functools.reduce(reduce_fn, cases, [])


def _test_structure_from_value_equality_combinations():
  """Builds the parameter combinations for `testStructureFromValueEquality`.

  Each case is `(name, value1_fn, value2_fn, *not_equal_value_fns)`: the first
  two lambdas are followed by one or more lambdas that the test compares
  against them for inequality.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case. The
    trailing lambdas are passed as a single list wrapped in one
    `combinations.NamedObject`.
  """
  cases = [
      ("Tensor", lambda: constant_op.constant(37.0),
       lambda: constant_op.constant(42.0), lambda: constant_op.constant([5])),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.int32, element_shape=(), size=0)),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3]], values=[-1], dense_shape=[5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]),
       lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2], [3]]], ragged_rank=1),
       lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]),
       lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])),
      ("Nested", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: {
          "a": constant_op.constant(42.0),
          "b": constant_op.constant([4, 5, 6])
      }, lambda: {
          "a": constant_op.constant([1, 2, 3]),
          "b": constant_op.constant(37.0)
      }),
  ]

  def reduce_fn(x, y):
    name, value1_fn, value2_fn, *not_equal_value_fns = y
    return x + combinations.combine(
        value1_fn=combinations.NamedObject("value1_fn.{}".format(name),
                                           value1_fn),
        value2_fn=combinations.NamedObject("value2_fn.{}".format(name),
                                           value2_fn),
        not_equal_value_fns=combinations.NamedObject(
            "not_equal_value_fns.{}".format(name), not_equal_value_fns))

  return functools.reduce(reduce_fn, cases, [])


def _test_ragged_structure_inequality_combinations():
  """Builds the parameter combinations for `testRaggedStructureInequality`.

  Each case is a `(spec1, spec2)` pair of `RaggedTensorSpec` objects built with
  differing constructor arguments.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case.
  """
  cases = [
      (ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 2)),
      (ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec([5, None], dtypes.int32, 1)),
      (ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec(None, dtypes.float32, 1)),
  ]

  def reduce_fn(x, y):
    spec1, spec2 = y
    return x + combinations.combine(spec1=spec1, spec2=spec2)

  return functools.reduce(reduce_fn, cases, [])


def _test_hash_combinations():
  """Builds the parameter combinations for `testHash`.

  Each case is `(name, value1_fn, value2_fn, value3_fn)`; the test expects the
  first two to hash equally and the third to hash differently.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case, with
    every lambda wrapped in a `combinations.NamedObject`.
  """
  cases = [
      ("Tensor", lambda: constant_op.constant(37.0),
       lambda: constant_op.constant(42.0), lambda: constant_op.constant([5])),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.float32, element_shape=(3,), size=0),
       lambda: tensor_array_ops.TensorArray(
           dtype=dtypes.int32, element_shape=(), size=0)),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[1, 2]], values=[42], dense_shape=[4, 5]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[3]], values=[-1], dense_shape=[5])),
      ("Nested", lambda: {
          "a": constant_op.constant(37.0),
          "b": constant_op.constant([1, 2, 3])
      }, lambda: {
          "a": constant_op.constant(42.0),
          "b": constant_op.constant([4, 5, 6])
      }, lambda: {
          "a": constant_op.constant([1, 2, 3]),
          "b": constant_op.constant(37.0)
      }),
  ]

  def reduce_fn(x, y):
    name, value1_fn, value2_fn, value3_fn = y
    return x + combinations.combine(
        value1_fn=combinations.NamedObject("value1_fn.{}".format(name),
                                           value1_fn),
        value2_fn=combinations.NamedObject("value2_fn.{}".format(name),
                                           value2_fn),
        value3_fn=combinations.NamedObject("value3_fn.{}".format(name),
                                           value3_fn))

  return functools.reduce(reduce_fn, cases, [])


def _test_round_trip_conversion_combinations():
  """Builds the parameter combinations for `testRoundTripConversion`.

  Each case is `(name, value_fn)` where `value_fn` is a zero-argument lambda
  producing the value to convert.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case, with
    every lambda wrapped in a `combinations.NamedObject`.
  """
  cases = [
      (
          "Tensor",
          lambda: constant_op.constant(37.0),
      ),
      (
          "SparseTensor",
          lambda: sparse_tensor.SparseTensor(
              indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      ),
      ("TensorArray", lambda: tensor_array_ops.TensorArray(
          dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
      (
          "RaggedTensor",
          lambda: ragged_factory_ops.constant([[1, 2], [], [3]]),
      ),
      (
          "Nested_0",
          lambda: {
              "a": constant_op.constant(37.0),
              "b": constant_op.constant([1, 2, 3])
          },
      ),
      (
          "Nested_1",
          lambda: {
              "a":
                  constant_op.constant(37.0),
              "b": (sparse_tensor.SparseTensor(
                  indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
                    sparse_tensor.SparseTensor(
                        indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
          },
      ),
  ]

  def reduce_fn(x, y):
    name, value_fn = y
    return x + combinations.combine(
        value_fn=combinations.NamedObject("value_fn.{}".format(name), value_fn))

  return functools.reduce(reduce_fn, cases, [])


def _test_convert_legacy_structure_combinations():
  """Builds the parameter combinations for `testConvertLegacyStructure`.

  Each case is `(output_types, output_shapes, output_classes,
  expected_structure)`, given as concrete objects rather than lambdas.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case.
  """
  cases = [
      (dtypes.float32, tensor_shape.TensorShape([]), ops.Tensor,
       tensor_spec.TensorSpec([], dtypes.float32)),
      (dtypes.int32, tensor_shape.TensorShape([2, 2]),
       sparse_tensor.SparseTensor,
       sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)),
      (dtypes.int32, tensor_shape.TensorShape([None, True, 2, 2]),
       tensor_array_ops.TensorArray,
       tensor_array_ops.TensorArraySpec([2, 2],
                                        dtypes.int32,
                                        dynamic_size=None,
                                        infer_shape=True)),
      (dtypes.int32, tensor_shape.TensorShape([True, None, 2, 2]),
       tensor_array_ops.TensorArray,
       tensor_array_ops.TensorArraySpec([2, 2],
                                        dtypes.int32,
                                        dynamic_size=True,
                                        infer_shape=None)),
      (dtypes.int32, tensor_shape.TensorShape([True, False, 2, 2]),
       tensor_array_ops.TensorArray,
       tensor_array_ops.TensorArraySpec([2, 2],
                                        dtypes.int32,
                                        dynamic_size=True,
                                        infer_shape=False)),
      (dtypes.int32, tensor_shape.TensorShape([2, None]),
       ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1),
       ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1)),
      ({
          "a": dtypes.float32,
          "b": (dtypes.int32, dtypes.string)
      }, {
          "a": tensor_shape.TensorShape([]),
          "b": (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([]))
      }, {
          "a": ops.Tensor,
          "b": (sparse_tensor.SparseTensor, ops.Tensor)
      }, {
          "a":
              tensor_spec.TensorSpec([], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                tensor_spec.TensorSpec([], dtypes.string))
      })
  ]

  def reduce_fn(x, y):
    output_types, output_shapes, output_classes, expected_structure = y
    return x + combinations.combine(
        output_types=output_types,
        output_shapes=output_shapes,
        output_classes=output_classes,
        expected_structure=expected_structure)

  return functools.reduce(reduce_fn, cases, [])


def _test_batch_combinations():
  """Builds the parameter combinations for `testBatch`.

  Each case is `(element_structure, batch_size, expected_batched_structure)`.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case.
  """
  cases = [
      (tensor_spec.TensorSpec([], dtypes.float32), 32,
       tensor_spec.TensorSpec([32], dtypes.float32)),
      (tensor_spec.TensorSpec([], dtypes.float32), None,
       tensor_spec.TensorSpec([None], dtypes.float32)),
      (sparse_tensor.SparseTensorSpec([None], dtypes.float32), 32,
       sparse_tensor.SparseTensorSpec([32, None], dtypes.float32)),
      (sparse_tensor.SparseTensorSpec([4], dtypes.float32), None,
       sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32)),
      (ragged_tensor.RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
       ragged_tensor.RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
      (ragged_tensor.RaggedTensorSpec([4, None], dtypes.float32, 1), None,
       ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2)),
      ({
          "a":
              tensor_spec.TensorSpec([], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                tensor_spec.TensorSpec([], dtypes.string))
      }, 128, {
          "a":
              tensor_spec.TensorSpec([128], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                tensor_spec.TensorSpec([128], dtypes.string))
      }),
  ]

  def reduce_fn(x, y):
    element_structure, batch_size, expected_batched_structure = y
    return x + combinations.combine(
        element_structure=element_structure,
        batch_size=batch_size,
        expected_batched_structure=expected_batched_structure)

  return functools.reduce(reduce_fn, cases, [])


def _test_unbatch_combinations():
  """Builds the parameter combinations for `testUnbatch`.

  Each case is `(element_structure, expected_unbatched_structure)`.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case.
  """
  cases = [
      (tensor_spec.TensorSpec([32], dtypes.float32),
       tensor_spec.TensorSpec([], dtypes.float32)),
      (tensor_spec.TensorSpec([None], dtypes.float32),
       tensor_spec.TensorSpec([], dtypes.float32)),
      (sparse_tensor.SparseTensorSpec([32, None], dtypes.float32),
       sparse_tensor.SparseTensorSpec([None], dtypes.float32)),
      (sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32),
       sparse_tensor.SparseTensorSpec([4], dtypes.float32)),
      (ragged_tensor.RaggedTensorSpec([32, None, None], dtypes.float32, 2),
       ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)),
      (ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.float32, 2),
       ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)),
      ({
          "a":
              tensor_spec.TensorSpec([128], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32),
                tensor_spec.TensorSpec([None], dtypes.string))
      }, {
          "a":
              tensor_spec.TensorSpec([], dtypes.float32),
          "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32),
                tensor_spec.TensorSpec([], dtypes.string))
      }),
  ]

  def reduce_fn(x, y):
    element_structure, expected_unbatched_structure = y
    return x + combinations.combine(
        element_structure=element_structure,
        expected_unbatched_structure=expected_unbatched_structure)

  return functools.reduce(reduce_fn, cases, [])


def _test_to_batched_tensor_list_combinations():
  """Builds the parameter combinations for `testToBatchedTensorList`.

  Each case is `(name, value_fn, element_0_fn)`, both zero-argument lambdas.

  Returns:
    The concatenation of one `combinations.combine(...)` result per case, with
    every lambda wrapped in a `combinations.NamedObject`.
  """
  cases = [
      ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
       lambda: constant_op.constant([1.0, 2.0])),
      ("SparseTensor", lambda: sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]),
       lambda: sparse_tensor.SparseTensor(
           indices=[[0]], values=[13], dense_shape=[2])),
      ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1]], [[2]]]),
       lambda: ragged_factory_ops.constant([[1]])),
      ("Nest",
       lambda: (
           constant_op.constant([[1.0, 2.0], [3.0, 4.0]]),
           sparse_tensor.SparseTensor(
               indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])),
       lambda: (
           constant_op.constant([1.0, 2.0]),
           sparse_tensor.SparseTensor(
               indices=[[0]], values=[13], dense_shape=[2]))),
  ]

  def reduce_fn(x, y):
    name, value_fn, element_0_fn = y
    return x + combinations.combine(
        value_fn=combinations.NamedObject("value_fn.{}".format(name), value_fn),
        element_0_fn=combinations.NamedObject("element_0_fn.{}".format(name),
                                              element_0_fn))

  return functools.reduce(reduce_fn, cases, [])


# TODO(jsimsa): Add tests for OptionalStructure and DatasetStructure.
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):

  # pylint: disable=g-long-lambda,protected-access
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_flat_structure_combinations()))
  def testFlatStructure(self, value_fn, expected_structure_fn,
                        expected_types_fn, expected_shapes_fn):
    """Checks the spec class, flat dtypes and flat shapes inferred for a value."""
    value = value_fn()
    expected_structure = expected_structure_fn()
    expected_types = expected_types_fn()
    expected_shapes = expected_shapes_fn()
    s = structure.type_spec_from_value(value)
    self.assertIsInstance(s, expected_structure)
    flat_types = structure.get_flat_tensor_types(s)
    self.assertEqual(expected_types, flat_types)
    flat_shapes = structure.get_flat_tensor_shapes(s)
    self.assertLen(flat_shapes, len(expected_shapes))
    for expected, actual in zip(expected_shapes, flat_shapes):
      if expected is None:
        # An expected shape of `None` stands for a shape of unknown rank.
        self.assertEqual(actual.ndims, None)
      else:
        self.assertEqual(actual.as_list(), expected)

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         _test_is_compatible_with_structure_combinations()))
  def testIsCompatibleWithStructure(self, original_value_fn,
                                    compatible_values_fn,
                                    incompatible_values_fn):
    """Checks `structure.are_compatible` for compatible and incompatible values."""
    original_value = original_value_fn()
    compatible_values = compatible_values_fn()
    incompatible_values = incompatible_values_fn()

    s = structure.type_spec_from_value(original_value)
    for compatible_value in compatible_values:
      self.assertTrue(
          structure.are_compatible(
              s, structure.type_spec_from_value(compatible_value)))
    for incompatible_value in incompatible_values:
      self.assertFalse(
          structure.are_compatible(
              s, structure.type_spec_from_value(incompatible_value)))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_structure_from_value_equality_combinations()))
  def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                     not_equal_value_fns):
    """Checks `==`, `!=` and component hashes of specs inferred from values."""
    # pylint: disable=g-generic-assert
    # Unwrap the list of lambdas from its `combinations.NamedObject` wrapper.
    not_equal_value_fns = not_equal_value_fns._obj
    s1 = structure.type_spec_from_value(value1_fn())
    s2 = structure.type_spec_from_value(value2_fn())
    self.assertEqual(s1, s1)  # check __eq__ operator.
    self.assertEqual(s1, s2)  # check __eq__ operator.
    self.assertFalse(s1 != s1)  # check __ne__ operator.
    self.assertFalse(s1 != s2)  # check __ne__ operator.
    for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)):
      self.assertEqual(hash(c1), hash(c1))
      self.assertEqual(hash(c1), hash(c2))
    for value_fn in not_equal_value_fns:
      s3 = structure.type_spec_from_value(value_fn())
      self.assertNotEqual(s1, s3)  # check __ne__ operator.
      self.assertNotEqual(s2, s3)  # check __ne__ operator.
      self.assertFalse(s1 == s3)  # check __eq__ operator.
      self.assertFalse(s2 == s3)  # check __eq__ operator.

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_ragged_structure_inequality_combinations()))
  def testRaggedStructureInequality(self, spec1, spec2):
    """Checks that the two given ragged specs compare as unequal."""
    # pylint: disable=g-generic-assert
    self.assertNotEqual(spec1, spec2)  # check __ne__ operator.
    self.assertFalse(spec1 == spec2)  # check __eq__ operator.

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_hash_combinations()))
  def testHash(self, value1_fn, value2_fn, value3_fn):
    """Checks component hashes: values 1 and 2 agree, value 3 differs."""
    s1 = structure.type_spec_from_value(value1_fn())
    s2 = structure.type_spec_from_value(value2_fn())
    s3 = structure.type_spec_from_value(value3_fn())
    for c1, c2, c3 in zip(nest.flatten(s1), nest.flatten(s2), nest.flatten(s3)):
      self.assertEqual(hash(c1), hash(c1))
      self.assertEqual(hash(c1), hash(c2))
      self.assertNotEqual(hash(c1), hash(c3))
      self.assertNotEqual(hash(c2), hash(c3))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_round_trip_conversion_combinations()))
  def testRoundTripConversion(self, value_fn):
    """Checks that `to_tensor_list` then `from_tensor_list` preserves a value."""
    value = value_fn()
    s = structure.type_spec_from_value(value)

    def maybe_stack_ta(v):
      # A TensorArray is stacked before it is handed to `self.evaluate`.
      if isinstance(v, tensor_array_ops.TensorArray):
        return v.stack()
      return v

    before = self.evaluate(maybe_stack_ta(value))
    after = self.evaluate(
        maybe_stack_ta(
            structure.from_tensor_list(s, structure.to_tensor_list(s, value))))

    flat_before = nest.flatten(before)
    flat_after = nest.flatten(after)
    for b, a in zip(flat_before, flat_after):
      if isinstance(b, sparse_tensor.SparseTensorValue):
        self.assertAllEqual(b.indices, a.indices)
        self.assertAllEqual(b.values, a.values)
        self.assertAllEqual(b.dense_shape, a.dense_shape)
      elif isinstance(
          b,
          (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
        # Ragged values currently use the same check as the fallback branch.
        self.assertAllEqual(b, a)
      else:
        self.assertAllEqual(b, a)

  # pylint: enable=g-long-lambda

  # NOTE(review): this method has neither a `test` name prefix nor a
  # `@combinations.generate` decorator, so it does not look like it is collected
  # by the test runner. Confirm whether that is intentional before renaming it.
  def preserveStaticShape(self):
    """Checks static shapes of ragged/sparse components after a round trip."""
    rt = ragged_factory_ops.constant([[1, 2], [], [3]])
    rt_s = structure.type_spec_from_value(rt)
    rt_after = structure.from_tensor_list(rt_s,
                                          structure.to_tensor_list(rt_s, rt))
    self.assertEqual(rt_after.row_splits.shape.as_list(),
                     rt.row_splits.shape.as_list())
    self.assertEqual(rt_after.values.shape.as_list(), [None])

    st = sparse_tensor.SparseTensor(
        indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
    st_s = structure.type_spec_from_value(st)
    st_after = structure.from_tensor_list(st_s,
                                          structure.to_tensor_list(st_s, st))
    self.assertEqual(st_after.indices.shape.as_list(), [None, 2])
    self.assertEqual(st_after.values.shape.as_list(), [None])
    self.assertEqual(st_after.dense_shape.shape.as_list(),
                     st.dense_shape.shape.as_list())

  @combinations.generate(test_base.default_test_combinations())
  def testPreserveTensorArrayShape(self):
    """Checks that an explicit TensorArray element shape survives a round trip."""
    ta = tensor_array_ops.TensorArray(
        dtype=dtypes.int32, size=1, element_shape=(3,))
    ta_s = structure.type_spec_from_value(ta)
    ta_after = structure.from_tensor_list(ta_s,
                                          structure.to_tensor_list(ta_s, ta))
    self.assertEqual(ta_after.element_shape.as_list(), [3])

  @combinations.generate(test_base.default_test_combinations())
  def testPreserveInferredTensorArrayShape(self):
    """Checks that an inferred TensorArray element shape survives a round trip."""
    ta = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=1)
    # Shape is inferred from the write.
    ta = ta.write(0, [1, 2, 3])
    ta_s = structure.type_spec_from_value(ta)
    ta_after = structure.from_tensor_list(ta_s,
                                          structure.to_tensor_list(ta_s, ta))
    self.assertEqual(ta_after.element_shape.as_list(), [3])

  @combinations.generate(test_base.default_test_combinations())
  def testIncompatibleStructure(self):
    """Checks the errors raised when a spec is used with a mismatched value."""
    # Define three mutually incompatible values/structures, and assert that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.
    value_tensor = constant_op.constant(42.0)
    s_tensor = structure.type_spec_from_value(value_tensor)
    flat_tensor = structure.to_tensor_list(s_tensor, value_tensor)

    value_sparse_tensor = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor)
    flat_sparse_tensor = structure.to_tensor_list(s_sparse_tensor,
                                                  value_sparse_tensor)

    value_nest = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_nest = structure.type_spec_from_value(value_nest)
    flat_nest = structure.to_tensor_list(s_nest, value_nest)

    with self.assertRaisesRegex(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*float32.* and shape \(\)"):
      structure.to_tensor_list(s_tensor, value_sparse_tensor)
    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_tensor, value_nest)

    with self.assertRaisesRegex(TypeError,
                                "neither a SparseTensor nor SparseTensorValue"):
      structure.to_tensor_list(s_sparse_tensor, value_tensor)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_sparse_tensor, value_nest)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_nest, value_tensor)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_nest, value_sparse_tensor)

    with self.assertRaisesRegex(
        ValueError,
        "Cannot create a Tensor from the tensor list because item 0 "
        ".*tf.Tensor.* is incompatible with the expected TypeSpec "
        ".*TensorSpec.*"):
      structure.from_tensor_list(s_tensor, flat_sparse_tensor)

    with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
      structure.from_tensor_list(s_tensor, flat_nest)

    with self.assertRaisesRegex(
        ValueError, "Cannot create a SparseTensor from the tensor list because "
        "item 0 .*tf.Tensor.* is incompatible with the expected TypeSpec "
        ".*TensorSpec.*"):
      structure.from_tensor_list(s_sparse_tensor, flat_tensor)

    with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."):
      structure.from_tensor_list(s_sparse_tensor, flat_nest)

    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
      structure.from_tensor_list(s_nest, flat_tensor)

    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."):
      structure.from_tensor_list(s_nest, flat_sparse_tensor)

  @combinations.generate(test_base.default_test_combinations())
  def testIncompatibleNestedStructure(self):
    """Checks the errors raised for mismatched nested specs and values."""
    # Define three mutually incompatible nested values/structures, and assert
    # that:
    # 1. Using one structure to flatten a value with an incompatible structure
    #    fails.
    # 2. Using one structure to restructure a flattened value with an
    #    incompatible structure fails.

    value_0 = {
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_0 = structure.type_spec_from_value(value_0)
    flat_s_0 = structure.to_tensor_list(s_0, value_0)

    # `value_1` has compatible nested structure with `value_0`, but different
    # classes.
    value_1 = {
        "a":
            constant_op.constant(37.0),
        "b":
            sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    }
    s_1 = structure.type_spec_from_value(value_1)
    flat_s_1 = structure.to_tensor_list(s_1, value_1)

    # `value_2` has incompatible nested structure with `value_0` and `value_1`.
    value_2 = {
        "a":
            constant_op.constant(37.0),
        "b": (sparse_tensor.SparseTensor(
            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
              sparse_tensor.SparseTensor(
                  indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
    }
    s_2 = structure.type_spec_from_value(value_2)
    flat_s_2 = structure.to_tensor_list(s_2, value_2)

    with self.assertRaisesRegex(
        ValueError, r"SparseTensor.* is not convertible to a tensor with "
        r"dtype.*int32.* and shape \(3,\)"):
      structure.to_tensor_list(s_0, value_1)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_0, value_2)

    with self.assertRaisesRegex(TypeError,
                                "neither a SparseTensor nor SparseTensorValue"):
      structure.to_tensor_list(s_1, value_0)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_1, value_2)

    # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
    # needs to account for "a" coming before or after "b". It might be worth
    # adding a deterministic repr for these error messages (among other
    # improvements).
    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_2, value_0)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure."):
      structure.to_tensor_list(s_2, value_1)

    with self.assertRaisesRegex(ValueError,
                                r"Cannot create a Tensor from the tensor list"):
      structure.from_tensor_list(s_0, flat_s_1)

    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3"):
      structure.from_tensor_list(s_0, flat_s_2)

    with self.assertRaisesRegex(
        ValueError, "Cannot create a SparseTensor from the tensor list"):
      structure.from_tensor_list(s_1, flat_s_0)

    with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3"):
      structure.from_tensor_list(s_1, flat_s_2)

    with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2"):
      structure.from_tensor_list(s_2, flat_s_0)

    with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2"):
      structure.from_tensor_list(s_2, flat_s_1)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_convert_legacy_structure_combinations()))
  def testConvertLegacyStructure(self, output_types, output_shapes,
                                 output_classes, expected_structure):
    """Checks `convert_legacy_structure` against the expected structure."""
    actual_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self.assertEqual(actual_structure, expected_structure)

  @combinations.generate(test_base.default_test_combinations())
  def testConvertLegacyStructureFail(self):
    """Checks the TypeError raised for the `_EagerTensorArray` output class."""
    with self.assertRaisesRegex(
        TypeError, "Could not build a structure for output class "
        "_EagerTensorArray. Make sure any component class in "
        "`output_classes` inherits from one of the following classes: "
        "`tf.TypeSpec`, `tf.sparse.SparseTensor`, `tf.Tensor`, "
        "`tf.TensorArray`."):
      structure.convert_legacy_structure(dtypes.int32,
                                         tensor_shape.TensorShape([2, None]),
                                         tensor_array_ops._EagerTensorArray)

  @combinations.generate(test_base.default_test_combinations())
  def testNestedNestedStructure(self):
    """Checks that nested tuples flatten and rebuild to the same tensor objects."""
    s = (tensor_spec.TensorSpec([], dtypes.int64),
         (tensor_spec.TensorSpec([], dtypes.float32),
          tensor_spec.TensorSpec([], dtypes.string)))

    int64_t = constant_op.constant(37, dtype=dtypes.int64)
    float32_t = constant_op.constant(42.0)
    string_t = constant_op.constant("Foo")

    nested_tensors = (int64_t, (float32_t, string_t))

    tensor_list = structure.to_tensor_list(s, nested_tensors)
    for expected, actual in zip([int64_t, float32_t, string_t], tensor_list):
      self.assertIs(expected, actual)

    (actual_int64_t,
     (actual_float32_t,
      actual_string_t)) = structure.from_tensor_list(s, tensor_list)
    self.assertIs(int64_t, actual_int64_t)
    self.assertIs(float32_t, actual_float32_t)
    self.assertIs(string_t, actual_string_t)

    (actual_int64_t, (actual_float32_t, actual_string_t)) = (
        structure.from_compatible_tensor_list(s, tensor_list))
    self.assertIs(int64_t, actual_int64_t)
    self.assertIs(float32_t, actual_float32_t)
    self.assertIs(string_t, actual_string_t)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_batch_combinations()))
  def testBatch(self, element_structure, batch_size,
                expected_batched_structure):
    """Checks the result of calling `_batch(batch_size)` on every component."""
    batched_structure = nest.map_structure(
        lambda component_spec: component_spec._batch(batch_size),
        element_structure)
    self.assertEqual(batched_structure, expected_batched_structure)

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_unbatch_combinations()))
  def testUnbatch(self, element_structure, expected_unbatched_structure):
    """Checks the result of calling `_unbatch()` on every component."""
    unbatched_structure = nest.map_structure(
        lambda component_spec: component_spec._unbatch(), element_structure)
    self.assertEqual(unbatched_structure, expected_unbatched_structure)

  # pylint: disable=g-long-lambda
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_to_batched_tensor_list_combinations()))
  def testToBatchedTensorList(self, value_fn, element_0_fn):
    batched_value = value_fn()
    s = structure.type_spec_from_value(batched_value)
    batched_tensor_list = structure.to_batched_tensor_list(s, batched_value)

    # The batch dimension is 2 for all of the test cases.
    # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
    # tensors in which we store sparse tensors.
    for t in batched_tensor_list:
      if t.dtype != dtypes.variant:
        self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

    # Test that the 0th element from the unbatched tensor is equal to the
    # expected value.
904 expected_element_0 = self.evaluate(element_0_fn()) 905 unbatched_s = nest.map_structure( 906 lambda component_spec: component_spec._unbatch(), s) 907 actual_element_0 = structure.from_tensor_list( 908 unbatched_s, [t[0] for t in batched_tensor_list]) 909 910 for expected, actual in zip( 911 nest.flatten(expected_element_0), nest.flatten(actual_element_0)): 912 self.assertValuesEqual(expected, actual) 913 914 # pylint: enable=g-long-lambda 915 916 @combinations.generate(test_base.default_test_combinations()) 917 def testDatasetSpecConstructor(self): 918 rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32) 919 st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32) 920 t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string) 921 element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec} 922 ds_struct = dataset_ops.DatasetSpec(element_spec, [5]) 923 self.assertEqual(ds_struct._element_spec, element_spec) 924 # Note: shape was automatically converted from a list to a TensorShape. 
925 self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5])) 926 927 @combinations.generate(test_base.default_test_combinations()) 928 def testCustomMapping(self): 929 elem = CustomMap(foo=constant_op.constant(37.)) 930 spec = structure.type_spec_from_value(elem) 931 self.assertIsInstance(spec, CustomMap) 932 self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32)) 933 934 @combinations.generate(test_base.default_test_combinations()) 935 def testObjectProxy(self): 936 nt_type = collections.namedtuple("A", ["x", "y"]) 937 proxied = wrapt.ObjectProxy(nt_type(1, 2)) 938 proxied_spec = structure.type_spec_from_value(proxied) 939 self.assertEqual( 940 structure.type_spec_from_value(nt_type(1, 2)), proxied_spec) 941 942 @combinations.generate(test_base.default_test_combinations()) 943 def testTypeSpecNotBuild(self): 944 with self.assertRaisesRegex( 945 TypeError, "Could not build a `TypeSpec` for 100 with type int"): 946 structure.type_spec_from_value(100, use_fallback=False) 947 948 @combinations.generate(test_base.default_test_combinations()) 949 def testTypeSpecNotCompatible(self): 950 test_obj = structure.NoneTensorSpec() 951 with self.assertRaisesRegex( 952 ValueError, r"No `TypeSpec` is compatible with both NoneTensorSpec\(\) " 953 "and 100"): 954 test_obj.most_specific_compatible_shape(100) 955 self.assertEqual(test_obj, 956 test_obj.most_specific_compatible_shape(test_obj)) 957 958 959class CustomMap(collections_abc.Mapping): 960 """Custom, immutable map.""" 961 962 def __init__(self, *args, **kwargs): 963 self.__dict__.update(dict(*args, **kwargs)) 964 965 def __getitem__(self, x): 966 return self.__dict__[x] 967 968 def __iter__(self): 969 return iter(self.__dict__) 970 971 def __len__(self): 972 return len(self.__dict__) 973 974 975if __name__ == "__main__": 976 test.main() 977