1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for utilities working with arbitrarily nested structures.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import functools 23 24import numpy as np 25import wrapt 26from absl.testing import parameterized 27 28from tensorflow.python.data.kernel_tests import test_base 29from tensorflow.python.data.ops import dataset_ops 30from tensorflow.python.data.util import nest 31from tensorflow.python.data.util import structure 32from tensorflow.python.framework import combinations 33from tensorflow.python.framework import constant_op 34from tensorflow.python.framework import dtypes 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import sparse_tensor 37from tensorflow.python.framework import tensor_shape 38from tensorflow.python.framework import tensor_spec 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import tensor_array_ops 41from tensorflow.python.ops import variables 42from tensorflow.python.ops.ragged import ragged_factory_ops 43from tensorflow.python.ops.ragged import ragged_tensor 44from tensorflow.python.ops.ragged import ragged_tensor_value 45from tensorflow.python.platform import test 46from 
tensorflow.python.util.compat import collections_abc 47 48 49# NOTE(mrry): Arguments of parameterized tests are lifted into lambdas to make 50# sure they are not executed before the (eager- or graph-mode) test environment 51# has been set up. 52# 53 54 55def _test_flat_structure_combinations(): 56 cases = [ 57 ("Tensor", lambda: constant_op.constant(37.0), 58 lambda: tensor_spec.TensorSpec, lambda: [dtypes.float32], lambda: [[]]), 59 ("TensorArray", lambda: tensor_array_ops.TensorArray( 60 dtype=dtypes.float32, element_shape=(3,), size=0), 61 lambda: tensor_array_ops.TensorArraySpec, lambda: [dtypes.variant], 62 lambda: [[]]), 63 ("SparseTensor", lambda: sparse_tensor.SparseTensor( 64 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), 65 lambda: sparse_tensor.SparseTensorSpec, lambda: [dtypes.variant], 66 lambda: [None]), 67 ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [4]]), 68 lambda: ragged_tensor.RaggedTensorSpec, lambda: [dtypes.variant], 69 lambda: [None]), 70 ("Nested_0", lambda: 71 (constant_op.constant(37.0), constant_op.constant([1, 2, 3])), 72 lambda: tuple, lambda: [dtypes.float32, dtypes.int32], 73 lambda: [[], [3]]), 74 ("Nested_1", lambda: { 75 "a": constant_op.constant(37.0), 76 "b": constant_op.constant([1, 2, 3]) 77 }, lambda: dict, lambda: [dtypes.float32, dtypes.int32], 78 lambda: [[], [3]]), 79 ("Nested_2", lambda: { 80 "a": 81 constant_op.constant(37.0), 82 "b": (sparse_tensor.SparseTensor( 83 indices=[[0, 0]], values=[1], dense_shape=[1, 1]), 84 sparse_tensor.SparseTensor( 85 indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) 86 }, lambda: dict, lambda: [dtypes.float32, dtypes.variant, dtypes.variant], 87 lambda: [[], None, None]), 88 ] 89 90 def reduce_fn(x, y): 91 # workaround for long line 92 name, value_fn = y[:2] 93 expected_structure_fn, expected_types_fn, expected_shapes_fn = y[2:] 94 return x + combinations.combine( 95 value_fn=combinations.NamedObject("value_fn.{}".format(name), value_fn), 96 
expected_structure_fn=combinations.NamedObject( 97 "expected_structure_fn.{}".format(name), expected_structure_fn), 98 expected_types_fn=combinations.NamedObject( 99 "expected_types_fn.{}".format(name), expected_types_fn), 100 expected_shapes_fn=combinations.NamedObject( 101 "expected_shapes_fn.{}".format(name), expected_shapes_fn)) 102 103 return functools.reduce(reduce_fn, cases, []) 104 105 106def _test_is_compatible_with_structure_combinations(): 107 cases = [ 108 ("Tensor", lambda: constant_op.constant(37.0), lambda: [ 109 constant_op.constant(38.0), 110 array_ops.placeholder(dtypes.float32), 111 variables.Variable(100.0), 42.0, 112 np.array(42.0, dtype=np.float32) 113 ], lambda: [constant_op.constant([1.0, 2.0]), 114 constant_op.constant(37)]), 115 ("TensorArray", lambda: tensor_array_ops.TensorArray( 116 dtype=dtypes.float32, element_shape=(3,), size=0), lambda: [ 117 tensor_array_ops.TensorArray( 118 dtype=dtypes.float32, element_shape=(3,), size=0), 119 tensor_array_ops.TensorArray( 120 dtype=dtypes.float32, element_shape=(3,), size=10) 121 ], lambda: [ 122 tensor_array_ops.TensorArray( 123 dtype=dtypes.int32, element_shape=(3,), size=0), 124 tensor_array_ops.TensorArray( 125 dtype=dtypes.float32, element_shape=(), size=0) 126 ]), 127 ("SparseTensor", lambda: sparse_tensor.SparseTensor( 128 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), 129 lambda: [ 130 sparse_tensor.SparseTensor( 131 indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), 132 sparse_tensor.SparseTensorValue( 133 indices=[[1, 1], [3, 4]], values=[10, -1], dense_shape=[4, 5]), 134 array_ops.sparse_placeholder(dtype=dtypes.int32), 135 array_ops.sparse_placeholder(dtype=dtypes.int32, shape=[None, None]) 136 ], lambda: [ 137 constant_op.constant(37, shape=[4, 5]), 138 sparse_tensor.SparseTensor( 139 indices=[[3, 4]], values=[-1], dense_shape=[5, 6]), 140 array_ops.sparse_placeholder( 141 dtype=dtypes.int32, shape=[None, None, None]), 142 sparse_tensor.SparseTensor( 143 
indices=[[3, 4]], values=[-1.0], dense_shape=[4, 5]) 144 ]), 145 ("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [3]]), 146 lambda: [ 147 ragged_factory_ops.constant([[1, 2], [3, 4], []]), 148 ragged_factory_ops.constant([[1], [2, 3, 4], [5]]), 149 ], lambda: [ 150 ragged_factory_ops.constant(1), 151 ragged_factory_ops.constant([1, 2]), 152 ragged_factory_ops.constant([[1], [2]]), 153 ragged_factory_ops.constant([["a", "b"]]), 154 ]), 155 ("Nested", lambda: { 156 "a": constant_op.constant(37.0), 157 "b": constant_op.constant([1, 2, 3]) 158 }, lambda: [{ 159 "a": constant_op.constant(15.0), 160 "b": constant_op.constant([4, 5, 6]) 161 }], lambda: [{ 162 "a": constant_op.constant(15.0), 163 "b": constant_op.constant([4, 5, 6, 7]) 164 }, { 165 "a": constant_op.constant(15), 166 "b": constant_op.constant([4, 5, 6]) 167 }, { 168 "a": 169 constant_op.constant(15), 170 "b": 171 sparse_tensor.SparseTensor( 172 indices=[[0], [1], [2]], values=[4, 5, 6], dense_shape=[3]) 173 }, (constant_op.constant(15.0), constant_op.constant([4, 5, 6]))]), 174 ] 175 176 def reduce_fn(x, y): 177 name, original_value_fn, compatible_values_fn, incompatible_values_fn = y 178 return x + combinations.combine( 179 original_value_fn=combinations.NamedObject( 180 "original_value_fn.{}".format(name), original_value_fn), 181 compatible_values_fn=combinations.NamedObject( 182 "compatible_values_fn.{}".format(name), compatible_values_fn), 183 incompatible_values_fn=combinations.NamedObject( 184 "incompatible_values_fn.{}".format(name), incompatible_values_fn)) 185 186 return functools.reduce(reduce_fn, cases, []) 187 188 189def _test_structure_from_value_equality_combinations(): 190 cases = [ 191 ("Tensor", lambda: constant_op.constant(37.0), 192 lambda: constant_op.constant(42.0), lambda: constant_op.constant([5])), 193 ("TensorArray", lambda: tensor_array_ops.TensorArray( 194 dtype=dtypes.float32, element_shape=(3,), size=0), 195 lambda: tensor_array_ops.TensorArray( 196 
dtype=dtypes.float32, element_shape=(3,), size=0), 197 lambda: tensor_array_ops.TensorArray( 198 dtype=dtypes.int32, element_shape=(), size=0)), 199 ("SparseTensor", lambda: sparse_tensor.SparseTensor( 200 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), 201 lambda: sparse_tensor.SparseTensor( 202 indices=[[1, 2]], values=[42], dense_shape=[4, 5]), lambda: 203 sparse_tensor.SparseTensor(indices=[[3]], values=[-1], dense_shape=[5]), 204 lambda: sparse_tensor.SparseTensor( 205 indices=[[3, 4]], values=[1.0], dense_shape=[4, 5])), 206 ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1, 2]], [[3]]]), 207 lambda: ragged_factory_ops.constant([[[5]], [[8], [3, 2]]]), 208 lambda: ragged_factory_ops.constant([[[1]], [[2], [3]]], ragged_rank=1), 209 lambda: ragged_factory_ops.constant([[[1.0, 2.0]], [[3.0]]]), 210 lambda: ragged_factory_ops.constant([[[1]], [[2]], [[3]]])), 211 ("Nested", lambda: { 212 "a": constant_op.constant(37.0), 213 "b": constant_op.constant([1, 2, 3]) 214 }, lambda: { 215 "a": constant_op.constant(42.0), 216 "b": constant_op.constant([4, 5, 6]) 217 }, lambda: { 218 "a": constant_op.constant([1, 2, 3]), 219 "b": constant_op.constant(37.0) 220 }), 221 ] 222 223 def reduce_fn(x, y): 224 name, value1_fn, value2_fn, *not_equal_value_fns = y 225 return x + combinations.combine( 226 value1_fn=combinations.NamedObject("value1_fn.{}".format(name), 227 value1_fn), 228 value2_fn=combinations.NamedObject("value2_fn.{}".format(name), 229 value2_fn), 230 not_equal_value_fns=combinations.NamedObject( 231 "not_equal_value_fns.{}".format(name), not_equal_value_fns)) 232 233 return functools.reduce(reduce_fn, cases, []) 234 235 236def _test_ragged_structure_inequality_combinations(): 237 cases = [ 238 (ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1), 239 ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 2)), 240 (ragged_tensor.RaggedTensorSpec([3, None], dtypes.int32, 1), 241 ragged_tensor.RaggedTensorSpec([5, None], dtypes.int32, 1)), 242 
(ragged_tensor.RaggedTensorSpec(None, dtypes.int32, 1), 243 ragged_tensor.RaggedTensorSpec(None, dtypes.float32, 1)), 244 ] 245 246 def reduce_fn(x, y): 247 spec1, spec2 = y 248 return x + combinations.combine(spec1=spec1, spec2=spec2) 249 250 return functools.reduce(reduce_fn, cases, []) 251 252 253def _test_hash_combinations(): 254 cases = [ 255 ("Tensor", lambda: constant_op.constant(37.0), 256 lambda: constant_op.constant(42.0), lambda: constant_op.constant([5])), 257 ("TensorArray", lambda: tensor_array_ops.TensorArray( 258 dtype=dtypes.float32, element_shape=(3,), size=0), 259 lambda: tensor_array_ops.TensorArray( 260 dtype=dtypes.float32, element_shape=(3,), size=0), 261 lambda: tensor_array_ops.TensorArray( 262 dtype=dtypes.int32, element_shape=(), size=0)), 263 ("SparseTensor", lambda: sparse_tensor.SparseTensor( 264 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), 265 lambda: sparse_tensor.SparseTensor( 266 indices=[[1, 2]], values=[42], dense_shape=[4, 5]), lambda: 267 sparse_tensor.SparseTensor(indices=[[3]], values=[-1], dense_shape=[5])), 268 ("Nested", lambda: { 269 "a": constant_op.constant(37.0), 270 "b": constant_op.constant([1, 2, 3]) 271 }, lambda: { 272 "a": constant_op.constant(42.0), 273 "b": constant_op.constant([4, 5, 6]) 274 }, lambda: { 275 "a": constant_op.constant([1, 2, 3]), 276 "b": constant_op.constant(37.0) 277 }), 278 ] 279 280 def reduce_fn(x, y): 281 name, value1_fn, value2_fn, value3_fn = y 282 return x + combinations.combine( 283 value1_fn=combinations.NamedObject("value1_fn.{}".format(name), 284 value1_fn), 285 value2_fn=combinations.NamedObject("value2_fn.{}".format(name), 286 value2_fn), 287 value3_fn=combinations.NamedObject("value3_fn.{}".format(name), 288 value3_fn)) 289 290 return functools.reduce(reduce_fn, cases, []) 291 292 293def _test_round_trip_conversion_combinations(): 294 cases = [ 295 ( 296 "Tensor", 297 lambda: constant_op.constant(37.0), 298 ), 299 ( 300 "SparseTensor", 301 lambda: 
sparse_tensor.SparseTensor( 302 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), 303 ), 304 ("TensorArray", lambda: tensor_array_ops.TensorArray( 305 dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)), 306 ( 307 "RaggedTensor", 308 lambda: ragged_factory_ops.constant([[1, 2], [], [3]]), 309 ), 310 ( 311 "Nested_0", 312 lambda: { 313 "a": constant_op.constant(37.0), 314 "b": constant_op.constant([1, 2, 3]) 315 }, 316 ), 317 ( 318 "Nested_1", 319 lambda: { 320 "a": 321 constant_op.constant(37.0), 322 "b": (sparse_tensor.SparseTensor( 323 indices=[[0, 0]], values=[1], dense_shape=[1, 1]), 324 sparse_tensor.SparseTensor( 325 indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) 326 }, 327 ), 328 ] 329 330 def reduce_fn(x, y): 331 name, value_fn = y 332 return x + combinations.combine( 333 value_fn=combinations.NamedObject("value_fn.{}".format(name), value_fn)) 334 335 return functools.reduce(reduce_fn, cases, []) 336 337 338def _test_convert_legacy_structure_combinations(): 339 cases = [ 340 (dtypes.float32, tensor_shape.TensorShape([]), ops.Tensor, 341 tensor_spec.TensorSpec([], dtypes.float32)), 342 (dtypes.int32, tensor_shape.TensorShape([2, 343 2]), sparse_tensor.SparseTensor, 344 sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32)), 345 (dtypes.int32, tensor_shape.TensorShape([None, True, 2, 2]), 346 tensor_array_ops.TensorArray, 347 tensor_array_ops.TensorArraySpec([2, 2], 348 dtypes.int32, 349 dynamic_size=None, 350 infer_shape=True)), 351 (dtypes.int32, tensor_shape.TensorShape([True, None, 2, 2]), 352 tensor_array_ops.TensorArray, 353 tensor_array_ops.TensorArraySpec([2, 2], 354 dtypes.int32, 355 dynamic_size=True, 356 infer_shape=None)), 357 (dtypes.int32, tensor_shape.TensorShape([True, False, 2, 2]), 358 tensor_array_ops.TensorArray, 359 tensor_array_ops.TensorArraySpec([2, 2], 360 dtypes.int32, 361 dynamic_size=True, 362 infer_shape=False)), 363 (dtypes.int32, tensor_shape.TensorShape([2, None]), 364 ragged_tensor.RaggedTensorSpec([2, None], 
dtypes.int32, 1), 365 ragged_tensor.RaggedTensorSpec([2, None], dtypes.int32, 1)), 366 ({ 367 "a": dtypes.float32, 368 "b": (dtypes.int32, dtypes.string) 369 }, { 370 "a": tensor_shape.TensorShape([]), 371 "b": (tensor_shape.TensorShape([2, 2]), tensor_shape.TensorShape([])) 372 }, { 373 "a": ops.Tensor, 374 "b": (sparse_tensor.SparseTensor, ops.Tensor) 375 }, { 376 "a": 377 tensor_spec.TensorSpec([], dtypes.float32), 378 "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), 379 tensor_spec.TensorSpec([], dtypes.string)) 380 }) 381 ] 382 383 def reduce_fn(x, y): 384 output_types, output_shapes, output_classes, expected_structure = y 385 return x + combinations.combine( 386 output_types=output_types, 387 output_shapes=output_shapes, 388 output_classes=output_classes, 389 expected_structure=expected_structure) 390 391 return functools.reduce(reduce_fn, cases, []) 392 393 394def _test_batch_combinations(): 395 cases = [ 396 (tensor_spec.TensorSpec([], dtypes.float32), 32, 397 tensor_spec.TensorSpec([32], dtypes.float32)), 398 (tensor_spec.TensorSpec([], dtypes.float32), None, 399 tensor_spec.TensorSpec([None], dtypes.float32)), 400 (sparse_tensor.SparseTensorSpec([None], dtypes.float32), 32, 401 sparse_tensor.SparseTensorSpec([32, None], dtypes.float32)), 402 (sparse_tensor.SparseTensorSpec([4], dtypes.float32), None, 403 sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32)), 404 (ragged_tensor.RaggedTensorSpec([2, None], dtypes.float32, 1), 32, 405 ragged_tensor.RaggedTensorSpec([32, 2, None], dtypes.float32, 2)), 406 (ragged_tensor.RaggedTensorSpec([4, None], dtypes.float32, 1), None, 407 ragged_tensor.RaggedTensorSpec([None, 4, None], dtypes.float32, 2)), 408 ({ 409 "a": 410 tensor_spec.TensorSpec([], dtypes.float32), 411 "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), 412 tensor_spec.TensorSpec([], dtypes.string)) 413 }, 128, { 414 "a": 415 tensor_spec.TensorSpec([128], dtypes.float32), 416 "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], 
dtypes.int32), 417 tensor_spec.TensorSpec([128], dtypes.string)) 418 }), 419 ] 420 421 def reduce_fn(x, y): 422 element_structure, batch_size, expected_batched_structure = y 423 return x + combinations.combine( 424 element_structure=element_structure, 425 batch_size=batch_size, 426 expected_batched_structure=expected_batched_structure) 427 428 return functools.reduce(reduce_fn, cases, []) 429 430 431def _test_unbatch_combinations(): 432 cases = [ 433 (tensor_spec.TensorSpec([32], dtypes.float32), 434 tensor_spec.TensorSpec([], dtypes.float32)), 435 (tensor_spec.TensorSpec([None], dtypes.float32), 436 tensor_spec.TensorSpec([], dtypes.float32)), 437 (sparse_tensor.SparseTensorSpec([32, None], dtypes.float32), 438 sparse_tensor.SparseTensorSpec([None], dtypes.float32)), 439 (sparse_tensor.SparseTensorSpec([None, 4], dtypes.float32), 440 sparse_tensor.SparseTensorSpec([4], dtypes.float32)), 441 (ragged_tensor.RaggedTensorSpec([32, None, None], dtypes.float32, 2), 442 ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)), 443 (ragged_tensor.RaggedTensorSpec([None, None, None], dtypes.float32, 2), 444 ragged_tensor.RaggedTensorSpec([None, None], dtypes.float32, 1)), 445 ({ 446 "a": 447 tensor_spec.TensorSpec([128], dtypes.float32), 448 "b": (sparse_tensor.SparseTensorSpec([128, 2, 2], dtypes.int32), 449 tensor_spec.TensorSpec([None], dtypes.string)) 450 }, { 451 "a": 452 tensor_spec.TensorSpec([], dtypes.float32), 453 "b": (sparse_tensor.SparseTensorSpec([2, 2], dtypes.int32), 454 tensor_spec.TensorSpec([], dtypes.string)) 455 }), 456 ] 457 458 def reduce_fn(x, y): 459 element_structure, expected_unbatched_structure = y 460 return x + combinations.combine( 461 element_structure=element_structure, 462 expected_unbatched_structure=expected_unbatched_structure) 463 464 return functools.reduce(reduce_fn, cases, []) 465 466 467def _test_to_batched_tensor_list_combinations(): 468 cases = [ 469 ("Tensor", lambda: constant_op.constant([[1.0, 2.0], [3.0, 4.0]]), 470 
lambda: constant_op.constant([1.0, 2.0])), 471 ("SparseTensor", lambda: sparse_tensor.SparseTensor( 472 indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2]), 473 lambda: sparse_tensor.SparseTensor( 474 indices=[[0]], values=[13], dense_shape=[2])), 475 ("RaggedTensor", lambda: ragged_factory_ops.constant([[[1]], [[2]]]), 476 lambda: ragged_factory_ops.constant([[1]])), 477 ("Nest", lambda: 478 (constant_op.constant([[1.0, 2.0], [3.0, 4.0]]), 479 sparse_tensor.SparseTensor( 480 indices=[[0, 0], [1, 1]], values=[13, 27], dense_shape=[2, 2])), 481 lambda: 482 (constant_op.constant([1.0, 2.0]), 483 sparse_tensor.SparseTensor(indices=[[0]], values=[13], dense_shape=[2])) 484 ), 485 ] 486 487 def reduce_fn(x, y): 488 name, value_fn, element_0_fn = y 489 return x + combinations.combine( 490 value_fn=combinations.NamedObject("value_fn.{}".format(name), value_fn), 491 element_0_fn=combinations.NamedObject("element_0_fn.{}".format(name), 492 element_0_fn)) 493 494 return functools.reduce(reduce_fn, cases, []) 495 496# TODO(jsimsa): Add tests for OptionalStructure and DatasetStructure. 
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for the `tf.data` structure (type-spec) utilities."""

  # pylint: disable=g-long-lambda,protected-access
  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_flat_structure_combinations()))
  def testFlatStructure(self, value_fn, expected_structure_fn,
                        expected_types_fn, expected_shapes_fn):
    """Checks the spec class and the flat tensor types/shapes of a value."""
    value = value_fn()
    expected_structure = expected_structure_fn()
    expected_types = expected_types_fn()
    expected_shapes = expected_shapes_fn()
    s = structure.type_spec_from_value(value)
    self.assertIsInstance(s, expected_structure)
    flat_types = structure.get_flat_tensor_types(s)
    self.assertEqual(expected_types, flat_types)
    flat_shapes = structure.get_flat_tensor_shapes(s)
    self.assertLen(flat_shapes, len(expected_shapes))
    for expected, actual in zip(expected_shapes, flat_shapes):
      # An expected shape of `None` means the flat shape has unknown rank.
      if expected is None:
        self.assertIsNone(actual.ndims)
      else:
        self.assertEqual(actual.as_list(), expected)

  @combinations.generate(
      combinations.times(test_base.graph_only_combinations(),
                         _test_is_compatible_with_structure_combinations()))
  def testIsCompatibleWithStructure(self, original_value_fn,
                                    compatible_values_fn,
                                    incompatible_values_fn):
    """Checks `are_compatible` against compatible and incompatible values."""
    # NOTE(review): presumably graph-only because several cases build
    # `array_ops.placeholder` / `array_ops.sparse_placeholder` values.
    original_value = original_value_fn()
    compatible_values = compatible_values_fn()
    incompatible_values = incompatible_values_fn()

    s = structure.type_spec_from_value(original_value)
    for compatible_value in compatible_values:
      self.assertTrue(
          structure.are_compatible(
              s, structure.type_spec_from_value(compatible_value)))
    for incompatible_value in incompatible_values:
      self.assertFalse(
          structure.are_compatible(
              s, structure.type_spec_from_value(incompatible_value)))

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_structure_from_value_equality_combinations()))
  def testStructureFromValueEquality(self, value1_fn, value2_fn,
                                     not_equal_value_fns):
    """Checks `==`, `!=` and `hash` on specs inferred from values."""
    # pylint: disable=g-generic-assert
    # The list of "unequal" value functions arrives wrapped in a NamedObject.
    not_equal_value_fns = not_equal_value_fns._obj
    s1 = structure.type_spec_from_value(value1_fn())
    s2 = structure.type_spec_from_value(value2_fn())
    self.assertEqual(s1, s1)  # check __eq__ operator.
    self.assertEqual(s1, s2)  # check __eq__ operator.
    self.assertFalse(s1 != s1)  # check __ne__ operator.
    self.assertFalse(s1 != s2)  # check __ne__ operator.
    for c1, c2 in zip(nest.flatten(s1), nest.flatten(s2)):
      self.assertEqual(hash(c1), hash(c1))
      self.assertEqual(hash(c1), hash(c2))
    for value_fn in not_equal_value_fns:
      s3 = structure.type_spec_from_value(value_fn())
      self.assertNotEqual(s1, s3)  # check __ne__ operator.
      self.assertNotEqual(s2, s3)  # check __ne__ operator.
      self.assertFalse(s1 == s3)  # check __eq__ operator.
      self.assertFalse(s2 == s3)  # check __eq__ operator.

  @combinations.generate(
      combinations.times(test_base.default_test_combinations(),
                         _test_ragged_structure_inequality_combinations()))
  def testRaggedStructureInequality(self, spec1, spec2):
    """Checks that differing `RaggedTensorSpec`s compare unequal."""
    # pylint: disable=g-generic-assert
    self.assertNotEqual(spec1, spec2)  # check __ne__ operator.
    self.assertFalse(spec1 == spec2)  # check __eq__ operator.
571 572 @combinations.generate( 573 combinations.times(test_base.default_test_combinations(), 574 _test_hash_combinations())) 575 def testHash(self, value1_fn, value2_fn, value3_fn): 576 s1 = structure.type_spec_from_value(value1_fn()) 577 s2 = structure.type_spec_from_value(value2_fn()) 578 s3 = structure.type_spec_from_value(value3_fn()) 579 for c1, c2, c3 in zip(nest.flatten(s1), nest.flatten(s2), nest.flatten(s3)): 580 self.assertEqual(hash(c1), hash(c1)) 581 self.assertEqual(hash(c1), hash(c2)) 582 self.assertNotEqual(hash(c1), hash(c3)) 583 self.assertNotEqual(hash(c2), hash(c3)) 584 585 @combinations.generate( 586 combinations.times(test_base.default_test_combinations(), 587 _test_round_trip_conversion_combinations())) 588 def testRoundTripConversion(self, value_fn): 589 value = value_fn() 590 s = structure.type_spec_from_value(value) 591 592 def maybe_stack_ta(v): 593 if isinstance(v, tensor_array_ops.TensorArray): 594 return v.stack() 595 return v 596 597 before = self.evaluate(maybe_stack_ta(value)) 598 after = self.evaluate( 599 maybe_stack_ta( 600 structure.from_tensor_list(s, structure.to_tensor_list(s, value)))) 601 602 flat_before = nest.flatten(before) 603 flat_after = nest.flatten(after) 604 for b, a in zip(flat_before, flat_after): 605 if isinstance(b, sparse_tensor.SparseTensorValue): 606 self.assertAllEqual(b.indices, a.indices) 607 self.assertAllEqual(b.values, a.values) 608 self.assertAllEqual(b.dense_shape, a.dense_shape) 609 elif isinstance( 610 b, 611 (ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)): 612 self.assertAllEqual(b, a) 613 else: 614 self.assertAllEqual(b, a) 615 616 # pylint: enable=g-long-lambda 617 618 def preserveStaticShape(self): 619 rt = ragged_factory_ops.constant([[1, 2], [], [3]]) 620 rt_s = structure.type_spec_from_value(rt) 621 rt_after = structure.from_tensor_list(rt_s, 622 structure.to_tensor_list(rt_s, rt)) 623 self.assertEqual(rt_after.row_splits.shape.as_list(), 624 
rt.row_splits.shape.as_list()) 625 self.assertEqual(rt_after.values.shape.as_list(), [None]) 626 627 st = sparse_tensor.SparseTensor( 628 indices=[[3, 4]], values=[-1], dense_shape=[4, 5]) 629 st_s = structure.type_spec_from_value(st) 630 st_after = structure.from_tensor_list(st_s, 631 structure.to_tensor_list(st_s, st)) 632 self.assertEqual(st_after.indices.shape.as_list(), [None, 2]) 633 self.assertEqual(st_after.values.shape.as_list(), [None]) 634 self.assertEqual(st_after.dense_shape.shape.as_list(), 635 st.dense_shape.shape.as_list()) 636 637 @combinations.generate(test_base.default_test_combinations()) 638 def testPreserveTensorArrayShape(self): 639 ta = tensor_array_ops.TensorArray( 640 dtype=dtypes.int32, size=1, element_shape=(3,)) 641 ta_s = structure.type_spec_from_value(ta) 642 ta_after = structure.from_tensor_list(ta_s, 643 structure.to_tensor_list(ta_s, ta)) 644 self.assertEqual(ta_after.element_shape.as_list(), [3]) 645 646 @combinations.generate(test_base.default_test_combinations()) 647 def testPreserveInferredTensorArrayShape(self): 648 ta = tensor_array_ops.TensorArray(dtype=dtypes.int32, size=1) 649 # Shape is inferred from the write. 650 ta = ta.write(0, [1, 2, 3]) 651 ta_s = structure.type_spec_from_value(ta) 652 ta_after = structure.from_tensor_list(ta_s, 653 structure.to_tensor_list(ta_s, ta)) 654 self.assertEqual(ta_after.element_shape.as_list(), [3]) 655 656 @combinations.generate(test_base.default_test_combinations()) 657 def testIncompatibleStructure(self): 658 # Define three mutually incompatible values/structures, and assert that: 659 # 1. Using one structure to flatten a value with an incompatible structure 660 # fails. 661 # 2. Using one structure to restructure a flattened value with an 662 # incompatible structure fails. 
663 value_tensor = constant_op.constant(42.0) 664 s_tensor = structure.type_spec_from_value(value_tensor) 665 flat_tensor = structure.to_tensor_list(s_tensor, value_tensor) 666 667 value_sparse_tensor = sparse_tensor.SparseTensor( 668 indices=[[0, 0]], values=[1], dense_shape=[1, 1]) 669 s_sparse_tensor = structure.type_spec_from_value(value_sparse_tensor) 670 flat_sparse_tensor = structure.to_tensor_list(s_sparse_tensor, 671 value_sparse_tensor) 672 673 value_nest = { 674 "a": constant_op.constant(37.0), 675 "b": constant_op.constant([1, 2, 3]) 676 } 677 s_nest = structure.type_spec_from_value(value_nest) 678 flat_nest = structure.to_tensor_list(s_nest, value_nest) 679 680 with self.assertRaisesRegex( 681 ValueError, r"SparseTensor.* is not convertible to a tensor with " 682 r"dtype.*float32.* and shape \(\)"): 683 structure.to_tensor_list(s_tensor, value_sparse_tensor) 684 with self.assertRaisesRegex( 685 ValueError, "The two structures don't have the same nested structure."): 686 structure.to_tensor_list(s_tensor, value_nest) 687 688 with self.assertRaisesRegex(TypeError, 689 "Neither a SparseTensor nor SparseTensorValue"): 690 structure.to_tensor_list(s_sparse_tensor, value_tensor) 691 692 with self.assertRaisesRegex( 693 ValueError, "The two structures don't have the same nested structure."): 694 structure.to_tensor_list(s_sparse_tensor, value_nest) 695 696 with self.assertRaisesRegex( 697 ValueError, "The two structures don't have the same nested structure."): 698 structure.to_tensor_list(s_nest, value_tensor) 699 700 with self.assertRaisesRegex( 701 ValueError, "The two structures don't have the same nested structure."): 702 structure.to_tensor_list(s_nest, value_sparse_tensor) 703 704 with self.assertRaisesRegex(ValueError, r"Incompatible input:"): 705 structure.from_tensor_list(s_tensor, flat_sparse_tensor) 706 707 with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."): 708 structure.from_tensor_list(s_tensor, flat_nest) 709 710 with 
self.assertRaisesRegex(ValueError, "Incompatible input: "): 711 structure.from_tensor_list(s_sparse_tensor, flat_tensor) 712 713 with self.assertRaisesRegex(ValueError, "Expected 1 tensors but got 2."): 714 structure.from_tensor_list(s_sparse_tensor, flat_nest) 715 716 with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."): 717 structure.from_tensor_list(s_nest, flat_tensor) 718 719 with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 1."): 720 structure.from_tensor_list(s_nest, flat_sparse_tensor) 721 722 @combinations.generate(test_base.default_test_combinations()) 723 def testIncompatibleNestedStructure(self): 724 # Define three mutually incompatible nested values/structures, and assert 725 # that: 726 # 1. Using one structure to flatten a value with an incompatible structure 727 # fails. 728 # 2. Using one structure to restructure a flattened value with an 729 # incompatible structure fails. 730 731 value_0 = { 732 "a": constant_op.constant(37.0), 733 "b": constant_op.constant([1, 2, 3]) 734 } 735 s_0 = structure.type_spec_from_value(value_0) 736 flat_s_0 = structure.to_tensor_list(s_0, value_0) 737 738 # `value_1` has compatible nested structure with `value_0`, but different 739 # classes. 740 value_1 = { 741 "a": 742 constant_op.constant(37.0), 743 "b": 744 sparse_tensor.SparseTensor( 745 indices=[[0, 0]], values=[1], dense_shape=[1, 1]) 746 } 747 s_1 = structure.type_spec_from_value(value_1) 748 flat_s_1 = structure.to_tensor_list(s_1, value_1) 749 750 # `value_2` has incompatible nested structure with `value_0` and `value_1`. 
751 value_2 = { 752 "a": 753 constant_op.constant(37.0), 754 "b": (sparse_tensor.SparseTensor( 755 indices=[[0, 0]], values=[1], dense_shape=[1, 1]), 756 sparse_tensor.SparseTensor( 757 indices=[[3, 4]], values=[-1], dense_shape=[4, 5])) 758 } 759 s_2 = structure.type_spec_from_value(value_2) 760 flat_s_2 = structure.to_tensor_list(s_2, value_2) 761 762 with self.assertRaisesRegex( 763 ValueError, r"SparseTensor.* is not convertible to a tensor with " 764 r"dtype.*int32.* and shape \(3,\)"): 765 structure.to_tensor_list(s_0, value_1) 766 767 with self.assertRaisesRegex( 768 ValueError, "The two structures don't have the same nested structure."): 769 structure.to_tensor_list(s_0, value_2) 770 771 with self.assertRaisesRegex(TypeError, 772 "Neither a SparseTensor nor SparseTensorValue"): 773 structure.to_tensor_list(s_1, value_0) 774 775 with self.assertRaisesRegex( 776 ValueError, "The two structures don't have the same nested structure."): 777 structure.to_tensor_list(s_1, value_2) 778 779 # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp 780 # needs to account for "a" coming before or after "b". It might be worth 781 # adding a deterministic repr for these error messages (among other 782 # improvements). 
783 with self.assertRaisesRegex( 784 ValueError, "The two structures don't have the same nested structure."): 785 structure.to_tensor_list(s_2, value_0) 786 787 with self.assertRaisesRegex( 788 ValueError, "The two structures don't have the same nested structure."): 789 structure.to_tensor_list(s_2, value_1) 790 791 with self.assertRaisesRegex(ValueError, r"Incompatible input:"): 792 structure.from_tensor_list(s_0, flat_s_1) 793 794 with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3."): 795 structure.from_tensor_list(s_0, flat_s_2) 796 797 with self.assertRaisesRegex(ValueError, "Incompatible input: "): 798 structure.from_tensor_list(s_1, flat_s_0) 799 800 with self.assertRaisesRegex(ValueError, "Expected 2 tensors but got 3."): 801 structure.from_tensor_list(s_1, flat_s_2) 802 803 with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2."): 804 structure.from_tensor_list(s_2, flat_s_0) 805 806 with self.assertRaisesRegex(ValueError, "Expected 3 tensors but got 2."): 807 structure.from_tensor_list(s_2, flat_s_1) 808 809 @combinations.generate( 810 combinations.times(test_base.default_test_combinations(), 811 _test_convert_legacy_structure_combinations())) 812 def testConvertLegacyStructure(self, output_types, output_shapes, 813 output_classes, expected_structure): 814 actual_structure = structure.convert_legacy_structure( 815 output_types, output_shapes, output_classes) 816 self.assertEqual(actual_structure, expected_structure) 817 818 @combinations.generate(test_base.default_test_combinations()) 819 def testNestedNestedStructure(self): 820 s = (tensor_spec.TensorSpec([], dtypes.int64), 821 (tensor_spec.TensorSpec([], dtypes.float32), 822 tensor_spec.TensorSpec([], dtypes.string))) 823 824 int64_t = constant_op.constant(37, dtype=dtypes.int64) 825 float32_t = constant_op.constant(42.0) 826 string_t = constant_op.constant("Foo") 827 828 nested_tensors = (int64_t, (float32_t, string_t)) 829 830 tensor_list = structure.to_tensor_list(s, 
nested_tensors) 831 for expected, actual in zip([int64_t, float32_t, string_t], tensor_list): 832 self.assertIs(expected, actual) 833 834 (actual_int64_t, 835 (actual_float32_t, 836 actual_string_t)) = structure.from_tensor_list(s, tensor_list) 837 self.assertIs(int64_t, actual_int64_t) 838 self.assertIs(float32_t, actual_float32_t) 839 self.assertIs(string_t, actual_string_t) 840 841 (actual_int64_t, (actual_float32_t, actual_string_t)) = ( 842 structure.from_compatible_tensor_list(s, tensor_list)) 843 self.assertIs(int64_t, actual_int64_t) 844 self.assertIs(float32_t, actual_float32_t) 845 self.assertIs(string_t, actual_string_t) 846 847 @combinations.generate( 848 combinations.times(test_base.default_test_combinations(), 849 _test_batch_combinations())) 850 def testBatch(self, element_structure, batch_size, 851 expected_batched_structure): 852 batched_structure = nest.map_structure( 853 lambda component_spec: component_spec._batch(batch_size), 854 element_structure) 855 self.assertEqual(batched_structure, expected_batched_structure) 856 857 @combinations.generate( 858 combinations.times(test_base.default_test_combinations(), 859 _test_unbatch_combinations())) 860 def testUnbatch(self, element_structure, expected_unbatched_structure): 861 unbatched_structure = nest.map_structure( 862 lambda component_spec: component_spec._unbatch(), element_structure) 863 self.assertEqual(unbatched_structure, expected_unbatched_structure) 864 865 # pylint: disable=g-long-lambda 866 @combinations.generate( 867 combinations.times(test_base.default_test_combinations(), 868 _test_to_batched_tensor_list_combinations())) 869 def testToBatchedTensorList(self, value_fn, element_0_fn): 870 batched_value = value_fn() 871 s = structure.type_spec_from_value(batched_value) 872 batched_tensor_list = structure.to_batched_tensor_list(s, batched_value) 873 874 # The batch dimension is 2 for all of the test cases. 
875 # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT 876 # tensors in which we store sparse tensors. 877 for t in batched_tensor_list: 878 if t.dtype != dtypes.variant: 879 self.assertEqual(2, self.evaluate(array_ops.shape(t)[0])) 880 881 # Test that the 0th element from the unbatched tensor is equal to the 882 # expected value. 883 expected_element_0 = self.evaluate(element_0_fn()) 884 unbatched_s = nest.map_structure( 885 lambda component_spec: component_spec._unbatch(), s) 886 actual_element_0 = structure.from_tensor_list( 887 unbatched_s, [t[0] for t in batched_tensor_list]) 888 889 for expected, actual in zip( 890 nest.flatten(expected_element_0), nest.flatten(actual_element_0)): 891 self.assertValuesEqual(expected, actual) 892 893 # pylint: enable=g-long-lambda 894 895 @combinations.generate(test_base.default_test_combinations()) 896 def testDatasetSpecConstructor(self): 897 rt_spec = ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32) 898 st_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32) 899 t_spec = tensor_spec.TensorSpec([10, 8], dtypes.string) 900 element_spec = {"rt": rt_spec, "st": st_spec, "t": t_spec} 901 ds_struct = dataset_ops.DatasetSpec(element_spec, [5]) 902 self.assertEqual(ds_struct._element_spec, element_spec) 903 # Note: shape was automatically converted from a list to a TensorShape. 
904 self.assertEqual(ds_struct._dataset_shape, tensor_shape.TensorShape([5])) 905 906 @combinations.generate(test_base.default_test_combinations()) 907 def testCustomMapping(self): 908 elem = CustomMap(foo=constant_op.constant(37.)) 909 spec = structure.type_spec_from_value(elem) 910 self.assertIsInstance(spec, CustomMap) 911 self.assertEqual(spec["foo"], tensor_spec.TensorSpec([], dtypes.float32)) 912 913 @combinations.generate(test_base.default_test_combinations()) 914 def testObjectProxy(self): 915 nt_type = collections.namedtuple("A", ["x", "y"]) 916 proxied = wrapt.ObjectProxy(nt_type(1, 2)) 917 proxied_spec = structure.type_spec_from_value(proxied) 918 self.assertEqual(structure.type_spec_from_value(nt_type(1, 2)), 919 proxied_spec) 920 921 922class CustomMap(collections_abc.Mapping): 923 """Custom, immutable map.""" 924 925 def __init__(self, *args, **kwargs): 926 self.__dict__.update(dict(*args, **kwargs)) 927 928 def __getitem__(self, x): 929 return self.__dict__[x] 930 931 def __iter__(self): 932 return iter(self.__dict__) 933 934 def __len__(self): 935 return len(self.__dict__) 936 937 938if __name__ == "__main__": 939 test.main() 940