# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.framework.extension_type."""

import contextlib
import tempfile
import typing

from absl.testing import parameterized

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import extension_type
from tensorflow.python.framework import extension_type_field
from tensorflow.python.framework import immutable_dict
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect


class MaskedTensorV1(extension_type.ExtensionType):
  """Example subclass of ExtensionType, used for testing."""
  values: ops.Tensor
  mask: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)


class MaskedTensorV2(extension_type.ExtensionType):
  """Example subclass of ExtensionType, used for testing.

  This version adds methods, classmethod, staticmethod, and properties, and
  customizes `__repr__` and `__validate__`.  It also adds a `__name__` field,
  which enables serialization.
  """
  __name__ = 'tf.test.MaskedTensorV2'

  values: ops.Tensor
  mask: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)

  def __repr__(self):
    # Use a compact "[1, 2, _, 4]" rendering when both fields expose
    # `.numpy()`; otherwise fall back to the default ExtensionType repr.
    if hasattr(self.values, 'numpy') and hasattr(self.mask, 'numpy'):
      return '<MaskedTensorV2 %s>' % _masked_array_repr(self.values.numpy(),
                                                        self.mask.numpy())
    else:
      return super(MaskedTensorV2, self).__repr__()

  @property
  def shape(self):
    return self.values.shape

  @property
  def dtype(self):
    return self.values.dtype

  @classmethod
  def from_full_tensor(cls, values):
    """Builds a MaskedTensorV2 whose mask is all-True."""
    return cls(values, array_ops.ones_like(values, dtype=dtypes.bool))

  # A dummy example to test support of staticmethod
  @staticmethod
  def doc_link():
    return 'http://example.com/masked_tensor'

  def __validate__(self):
    self.values.shape.assert_is_compatible_with(self.mask.shape)

  def with_default(self, default):
    """Returns `values` with masked-out positions replaced by `default`."""
    return array_ops.where_v2(self.mask, self.values, default)

  # `x + y` / `x - y` call the TF ops directly; for MaskedTensorV2 operands
  # this presumably needs a registered dispatcher (as the math-operator test
  # installs) -- confirm against extension_type dispatch support.
  __add__ = math_ops.add
  __sub__ = math_ops.subtract


def _masked_array_repr(values, mask):
  """Returns a string representation for a masked numpy array.

  Masked-out elements are rendered as `_`; nested dimensions are rendered
  recursively as nested brackets.

  Args:
    values: A numpy array of values.
    mask: A boolean numpy array with the same leading dimension as `values`.

  Returns:
    A string such as `'[1, 2, _, 4]'`.
  """
  assert len(values) == len(mask)
  if len(values.shape) == 1:
    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
  else:
    items = [_masked_array_repr(v, m) for (v, m) in zip(values, mask)]
  return '[%s]' % ', '.join(items)
109 110 111class ForwardRefA(extension_type.ExtensionType): 112 x: typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'], ...] 113 y: 'ForwardRefB' 114 115 116class ForwardRefB(extension_type.ExtensionType): 117 z: 'ForwardRefB' 118 n: ops.Tensor 119 120 121@test_util.run_all_in_graph_and_eager_modes 122class ExtensionTypeTest(test_util.TensorFlowTestCase, parameterized.TestCase): 123 124 def testAttributeAccessors(self): 125 mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True]) 126 mt2 = extension_type.pack(mt1) 127 128 for mt in [mt1, mt2]: 129 self.assertIsInstance(mt.values, ops.Tensor) 130 self.assertAllEqual(mt.values, [1, 2, 3, 4]) 131 self.assertIsInstance(mt.mask, ops.Tensor) 132 self.assertAllEqual(mt.mask, [True, True, False, True]) 133 134 def testAttributesAreImmutable(self): 135 mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True]) 136 mt2 = extension_type.pack(mt1) 137 138 for mt in [mt1, mt2]: 139 with self.assertRaisesRegex(AttributeError, 140 "cannot assign to field 'score'"): 141 mt.score = 12 142 with self.assertRaisesRegex(AttributeError, 143 "cannot assign to field 'values'"): 144 mt.values = constant_op.constant([4, 3, 2, 1]) 145 with self.assertRaisesRegex(AttributeError, 146 "cannot delete field 'values'"): 147 del mt.values 148 149 def testClassAndStaticMethod(self): 150 mt = MaskedTensorV2.from_full_tensor([1, 2, 3, 4]) 151 self.assertAllEqual(mt.mask, [True, True, True, True]) 152 self.assertEqual(mt.doc_link(), 'http://example.com/masked_tensor') 153 154 def testRepr(self): 155 values = constant_op.constant([1, 2, 3, 4]) 156 mask = constant_op.constant([True, True, False, True]) 157 mt = MaskedTensorV1(values, mask) 158 expected = f'MaskedTensorV1(values={values!r}, mask={mask!r})' 159 self.assertEqual(expected, repr(mt)) 160 161 def testEagerRepr(self): 162 values = constant_op.constant([1, 2, 3, 4]) 163 mask = constant_op.constant([True, True, False, True]) 164 mt = MaskedTensorV2(values, mask) 165 if 
context.executing_eagerly(): 166 expected = '<MaskedTensorV2 [1, 2, _, 4]>' 167 else: 168 expected = f'MaskedTensorV2(values={values!r}, mask={mask!r})' 169 170 self.assertEqual(expected, repr(mt)) 171 self.assertEqual(expected, repr(mt)) 172 173 def testConstructorSignature(self): 174 175 class MyType(extension_type.ExtensionType): 176 x: ops.Tensor 177 y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool) 178 z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3] 179 180 expected_parameters = [ 181 tf_inspect.Parameter('self', 182 tf_inspect.Parameter.POSITIONAL_OR_KEYWORD), 183 tf_inspect.Parameter( 184 'x', 185 tf_inspect.Parameter.POSITIONAL_OR_KEYWORD, 186 annotation=ops.Tensor), 187 tf_inspect.Parameter( 188 'y', 189 tf_inspect.Parameter.POSITIONAL_OR_KEYWORD, 190 annotation=tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)), 191 tf_inspect.Parameter( 192 'z', 193 tf_inspect.Parameter.POSITIONAL_OR_KEYWORD, 194 annotation=typing.Tuple[typing.Union[int, str], ...], 195 default=(1, 'two', 3)), 196 ] 197 expected_sig = tf_inspect.Signature( 198 expected_parameters, return_annotation=MyType) 199 self.assertEqual(expected_sig, tf_inspect.signature(MyType.__init__)) 200 201 def testEmptyType(self): 202 203 class EmptyType(extension_type.ExtensionType): 204 pass 205 206 self.assertEmpty(EmptyType._tf_extension_type_fields()) 207 x = EmptyType() 208 self.assertEqual(repr(x), 'EmptyType()') 209 210 def testCustomConstrutor(self): 211 212 class SummarizedTensor(extension_type.ExtensionType): 213 values: ops.Tensor 214 mean: ops.Tensor 215 max: ops.Tensor 216 217 def __init__(self, values): 218 self.values = ops.convert_to_tensor(values) 219 self.mean = math_ops.reduce_mean(values) 220 self.max = math_ops.reduce_max(values) 221 222 x = SummarizedTensor([[1.0, 2, 3], [4, 5, 6]]) 223 self.assertAllEqual(x.values, [[1.0, 2, 3], [4, 5, 6]]) 224 self.assertAllEqual(x.mean, 3.5) 225 self.assertAllEqual(x.max, 6) 226 227 class Node(extension_type.ExtensionType): 
228 x: ops.Tensor 229 y: typing.Optional[str] = None 230 children: typing.Tuple['ExtensionTypeTest.Node', ...] = () 231 232 def testCustomConstructorWithDefaultValues(self): 233 a = ExtensionTypeTest.Node(5) 234 self.assertAllEqual(a.x, 5) 235 self.assertIsNone(a.y) 236 self.assertEqual(a.children, ()) 237 238 b = ExtensionTypeTest.Node(6, 'blue') 239 self.assertAllEqual(b.x, 6) 240 self.assertEqual(b.y, 'blue') 241 self.assertEqual(b.children, ()) 242 243 c = ExtensionTypeTest.Node(7, children=(a, b)) 244 self.assertAllEqual(c.x, 7) 245 self.assertIsNone(c.y) 246 self.assertEqual(c.children, (a, b)) 247 248 def testCustomConstructorNondefaultCanotFollowDefault(self): 249 with self.assertRaisesRegex( 250 ValueError, "Field without default 'd' follows field with default 'c'"): 251 252 class MyType(extension_type.ExtensionType): 253 a: int 254 b: str = 'Hello world' 255 c: typing.Optional[ops.Tensor] = None 256 d: ops.Tensor 257 258 del MyType 259 260 def testCustomConstrutorCantMutateNestedValues(self): 261 262 class Foo(extension_type.ExtensionType): 263 x: int 264 265 class Bar(extension_type.ExtensionType): 266 foo: Foo 267 268 def __init__(self, foo): 269 foo.x = 33 # This raises an exception 270 271 with self.assertRaisesRegex(AttributeError, "cannot assign to field 'x'"): 272 Bar(Foo(12)) 273 274 def testCustomValidate(self): 275 276 class AlignedTensors(extension_type.ExtensionType): 277 x: ops.Tensor 278 y: ops.Tensor 279 280 def __validate__(self): 281 self.x.shape.assert_is_compatible_with(self.y.shape) 282 283 aligned = AlignedTensors([1, 2, 3], ['a', 'b', 'c']) 284 self.assertAllEqual(aligned.x, [1, 2, 3]) 285 self.assertAllEqual(aligned.y, [b'a', b'b', b'c']) 286 287 with self.assertRaises(ValueError): 288 AlignedTensors([1, 2, 3], ['a', 'b', 'c', 'd']) 289 290 def testEquals(self): 291 292 class MyType(extension_type.ExtensionType): 293 values: ops.Tensor 294 score: ops.Tensor 295 flavor: str 296 297 x1 = MyType([1, 2], 8, 'blue') 298 x2 = MyType([1, 
2], 8, 'blue') 299 y = MyType([1, 2], 8, 'red') 300 z = MyType([1, 2], 7, 'blue') 301 self.assertAllEqual(x1 == x2, True) 302 self.assertAllEqual(x1 != x2, False) 303 self.assertAllEqual(x1 == y, False) 304 self.assertAllEqual(x1 != y, True) 305 self.assertAllEqual(x1 == z, False) 306 self.assertAllEqual(y == z, False) 307 308 # These are not equal, even though their values are broadcast-compatible 309 # and elements are all equal when we broadcast. Shapes must match. 310 a = MyType([1, 1, 1, 1], 0, 'x') 311 b = MyType([[1, 1, 1, 1]], 0, 'x') 312 c = MyType([[1, 1], [1, 1]], 0, 'x') 313 self.assertAllEqual(a == b, False) 314 self.assertAllEqual(a == c, False) 315 self.assertAllEqual(b == c, False) 316 317 # Test with unknown shapes (executes a different codepath). 318 a_ph = replace_tensors_with_placeholders(a) 319 b_ph = replace_tensors_with_placeholders(b) 320 c_ph = replace_tensors_with_placeholders(c) 321 self.assertAllEqual(a_ph == b_ph, False) 322 self.assertAllEqual(a_ph == c_ph, False) 323 self.assertAllEqual(b_ph == c_ph, False) 324 325 def testPassIntoTfFunction(self): 326 327 @def_function.function 328 def fn(x): 329 return x.with_default(99) 330 331 mt = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True]) 332 self.assertAllEqual([1, 2, 99, 4], fn(mt)) 333 self.assertAllEqual([1, 2, 99, 4], fn(extension_type.pack(mt))) 334 335 def testReturnFromTfFunction(self): 336 337 @def_function.function 338 def mask_neg_values(x): 339 return MaskedTensorV2(x, x > 0) 340 341 @def_function.function 342 def mask_neg_values_packed(x): 343 return extension_type.pack(MaskedTensorV2(x, x > 0)) 344 345 expected = MaskedTensorV2([5, 8, -3, 9], [True, True, False, True]) 346 347 actual1 = mask_neg_values(constant_op.constant([5, 8, -3, 9])) 348 self.assertIsInstance(actual1, MaskedTensorV2) 349 self.assertAllEqual(expected.values, actual1.values) 350 self.assertAllEqual(expected.mask, actual1.mask) 351 352 actual2 = mask_neg_values_packed(constant_op.constant([5, 8, -3, 
9])) 353 self.assertIsInstance(actual2, MaskedTensorV2) 354 self.assertTrue(extension_type.is_packed(actual2)) 355 self.assertAllEqual(expected.values, actual2.values) 356 self.assertAllEqual(expected.mask, actual2.mask) 357 358 def testCaptureByTfFunction(self): 359 x = MaskedTensorV2( 360 values=[[1, 2, 3], [4, 5, 6]], 361 mask=[[True, True, True], [True, False, True]]) 362 363 @def_function.function 364 def add_to_x(y): 365 return MaskedTensorV2(x.values + y.values, x.mask & y.mask) 366 367 actual = add_to_x(MaskedTensorV2([10, 20, 30], [False, True, True])) 368 expected = MaskedTensorV2( 369 values=[[11, 22, 33], [14, 25, 36]], 370 mask=[[False, True, True], [False, False, True]]) 371 self.assertIsInstance(actual, MaskedTensorV2) 372 self.assertAllEqual(expected.values, actual.values) 373 self.assertAllEqual(expected.mask, actual.mask) 374 375 def testTfFunctionArgMutationError(self): 376 377 @def_function.function 378 def fn_with_side_effect(mts): 379 mts.append(MaskedTensorV1(mts[0].values * 2, mts[0].mask)) 380 381 with self.assertRaisesRegex(ValueError, 'should not modify'): 382 fn_with_side_effect([MaskedTensorV1([10, 20, 30], [False, True, True])]) 383 384 def testNestPackUnpack(self): 385 386 class CandyStore(extension_type.ExtensionType): 387 name: ops.Tensor 388 prices: typing.Mapping[str, ops.Tensor] 389 390 store = CandyStore('Yum', {'gum': [0.42, 0.48], 'chocolate': [0.83, 1.02]}) 391 components = nest.flatten(store, expand_composites=True) 392 repacked_1 = nest.pack_sequence_as( 393 store, components, expand_composites=True) 394 repacked_2 = nest.pack_sequence_as( 395 store._type_spec, components, expand_composites=True) 396 397 # Note: dicts get sorted by key. 
398 self.assertLen(components, 3) 399 self.assertAllEqual(components[0], b'Yum') 400 self.assertAllClose(components[1], [0.83, 1.02]) 401 self.assertAllClose(components[2], [0.42, 0.48]) 402 403 for repacked in [repacked_1, repacked_2]: 404 self.assertAllEqual(repacked.name, b'Yum') 405 self.assertAllClose(repacked.prices['gum'], [0.42, 0.48]) 406 self.assertAllClose(repacked.prices['chocolate'], [0.83, 1.02]) 407 408 def testSimpleCond(self): 409 x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]) 410 y = MaskedTensorV1([5, 6, 7, 8], [False, True, True, False]) 411 412 x_2 = control_flow_ops.cond( 413 constant_op.constant(True), lambda: x, lambda: y) 414 y_2 = control_flow_ops.cond( 415 constant_op.constant(False), lambda: x, lambda: y) 416 417 self.assertAllEqual(x.values, x_2.values) 418 self.assertAllEqual(x.mask, x_2.mask) 419 self.assertAllEqual(y.values, y_2.values) 420 self.assertAllEqual(y.mask, y_2.mask) 421 422 def testComplexCond(self): 423 mt = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]) 424 425 def true_fn(): 426 return MaskedTensorV1( 427 array_ops.where_v2(mt.mask, mt.values, -1), mt.values > 3) 428 429 def false_fn(): 430 return MaskedTensorV1( 431 array_ops.where_v2(mt.mask, 100, mt.values * 2), 432 math_ops.logical_not(mt.mask)) 433 434 x = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn) 435 y = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn) 436 437 self.assertAllEqual(x.values, [1, -1, 3, -1]) 438 self.assertAllEqual(x.mask, [False, False, False, True]) 439 self.assertAllEqual(y.values, [100, 4, 100, 8]) 440 self.assertAllEqual(y.mask, [False, True, False, True]) 441 442 def testCondAutograph(self): 443 444 @def_function.function 445 def fn(mt): 446 if mt.values[3] > 3: 447 return MaskedTensorV1( 448 array_ops.where_v2(mt.mask, mt.values, -1), mt.values > 3) 449 else: 450 return MaskedTensorV1( 451 array_ops.where_v2(mt.mask, 100, mt.values * 2), not mt.mask) 452 453 x = 
fn(MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])) 454 self.assertAllEqual(x.values, [1, -1, 3, -1]) 455 self.assertAllEqual(x.mask, [False, False, False, True]) 456 457 def testCondTypeMismatch(self): 458 if context.executing_eagerly: 459 # In eager mode, tf.cond eagerly runs either true_fn or false_fn, and 460 # ignores the other one; so it doesn't detect any type mismatches 461 # between the two outcomes. (See _eager_cond_implementation in 462 # control_flow_ops.py.) 463 return 464 465 a = lambda: MaskedTensorV1([1, 2, 3], [True, True, False]) 466 b = lambda: MaskedTensorV1(['a', 'b', 'c'], [False, True, True]) 467 c = lambda: MaskedTensorV2([4, 5, 6], [True, True, False]) 468 d = lambda: constant_op.constant([7, 8, 9]) 469 470 with self.assertRaisesRegex( 471 ValueError, 472 'Incompatible return values of true_fn and false_fn: The two ' 473 "structures don't have the same nested structure"): 474 control_flow_ops.cond(constant_op.constant(True), a, b) 475 with self.assertRaisesRegex( 476 TypeError, 'Incompatible return types of true_fn and false_fn: The two ' 477 "structures don't have the same nested structure"): 478 control_flow_ops.cond(constant_op.constant(True), a, c) 479 with self.assertRaisesRegex( 480 ValueError, 481 'Incompatible return values of true_fn and false_fn: The two ' 482 "structures don't have the same nested structure"): 483 control_flow_ops.cond(constant_op.constant(True), a, d) 484 485 def testCondPacked(self): 486 x = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False]) 487 y = MaskedTensorV2([5, 6, 7, 8], [False, True, True, False]) 488 x = extension_type.pack(x) 489 y = extension_type.pack(y) 490 491 x_2 = control_flow_ops.cond( 492 constant_op.constant(True), lambda: x, lambda: y) 493 y_2 = control_flow_ops.cond( 494 constant_op.constant(False), lambda: x, lambda: y) 495 496 self.assertAllEqual(x.values, x_2.values) 497 self.assertAllEqual(x.mask, x_2.mask) 498 self.assertAllEqual(y.values, y_2.values) 499 
self.assertAllEqual(y.mask, y_2.mask) 500 501 a = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False]) 502 b = extension_type.pack(a) 503 b = control_flow_ops.cond( 504 constant_op.constant(True), lambda: array_ops.size(a.mask), 505 lambda: array_ops.size(a.values)) 506 self.assertAllEqual(b, 4) 507 508 # Note: the following example would fail (with `Retval[0] does not have a 509 # value`) if `ExtensionType.__getattr__` cached the results of unpacking 510 # the value. See the comment in `ExtensionType.__getattr__` for details. 511 c = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False]) 512 c = extension_type.pack(c) 513 d = control_flow_ops.cond( 514 constant_op.constant(False), lambda: array_ops.size(c.mask), 515 lambda: array_ops.size(c.values)) 516 self.assertAllEqual(d, 4) 517 518 def testWhileLoop(self): 519 x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]) 520 521 cond = lambda i, x: i < 10 522 body = lambda i, x: (i + 1, MaskedTensorV1(x.values * 2, x.mask)) 523 _, y = control_flow_ops.while_loop_v2(cond, body, [0, x]) 524 525 self.assertIsInstance(y, MaskedTensorV1) 526 self.assertAllEqual(y.values, [1024, 2048, 3072, 4096]) 527 self.assertAllEqual(y.mask, [True, False, True, False]) 528 529 def testWhileLoopAutograph(self): 530 531 @def_function.function 532 def fn(x, n): 533 for _ in math_ops.range(n): 534 x = MaskedTensorV1(x.values * 2, x.mask) 535 return x 536 537 y = fn(MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]), 10) 538 self.assertIsInstance(y, MaskedTensorV1) 539 self.assertAllEqual(y.values, [1024, 2048, 3072, 4096]) 540 self.assertAllEqual(y.mask, [True, False, True, False]) 541 542 def testWhileLoopTypeMismatch(self): 543 x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]) 544 545 cond = lambda i, x: i < 10 546 547 def body(i, x): 548 if isinstance(x, MaskedTensorV1): 549 return x.values * 2 550 else: 551 return MaskedTensorV1(x, x > i) 552 553 with self.assertRaisesRegex( 554 ValueError, "The two 
structures don't have the same nested structure"): 555 control_flow_ops.while_loop_v2(cond, body, [0, x]) 556 557 def testWhileLoopPacked(self): 558 x = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False]) 559 x = extension_type.pack(x) 560 cond = lambda i, x: i < 10 561 562 def body(i, x): 563 return i + 1, extension_type.pack(MaskedTensorV2(x.values * 2, x.mask)) 564 565 _, y = control_flow_ops.while_loop_v2(cond, body, [0, x]) 566 self.assertIsInstance(y, MaskedTensorV2) 567 self.assertAllEqual(y.values, [1024, 2048, 3072, 4096]) 568 self.assertAllEqual(y.mask, [True, False, True, False]) 569 570 def testNestedFields(self): 571 PossiblyRaggedTensor = typing.Union[ops.Tensor, ragged_tensor.RaggedTensor] 572 ToyFeatures = typing.Mapping[str, PossiblyRaggedTensor] 573 574 class ToyInfo(extension_type.ExtensionType): 575 version: str 576 toys: typing.Tuple[typing.Tuple[str, ops.Tensor, ToyFeatures], ...] 577 boxes: typing.Mapping[str, ops.Tensor] 578 579 authors = [[b'A', b'Aardvark'], [b'Z', b'Zhook']] 580 toys = [('car', 1.0, { 581 'size': [8, 3, 2], 582 'color': [0.3, 0.2, 0.8] 583 }), ('book', 3.7, { 584 'authors': ragged_factory_ops.constant(authors) 585 })] 586 boxes = {'green': ['car'], 'blue': ['car', 'book', 'book']} 587 toy_info = ToyInfo(version='1.0 alpha', toys=toys, boxes=boxes) 588 589 self.assertEqual(toy_info.version, '1.0 alpha') 590 self.assertEqual(toy_info.toys[0][0], 'car') 591 self.assertIsInstance(toy_info.toys[0][1], ops.Tensor) 592 self.assertAllEqual(toy_info.toys[0][1], 1.0) 593 self.assertEqual(set(toy_info.toys[0][2].keys()), {'size', 'color'}) 594 self.assertIsInstance(toy_info.toys[0][2]['size'], ops.Tensor) 595 self.assertAllEqual(toy_info.toys[0][2]['size'], [8, 3, 2]) 596 self.assertIsInstance(toy_info.toys[1][2]['authors'], 597 ragged_tensor.RaggedTensor) 598 self.assertAllEqual(toy_info.toys[1][2]['authors'], authors) 599 self.assertAllEqual(toy_info.boxes['green'], [b'car']) 600 self.assertAllEqual(toy_info.boxes['blue'], 
['car', 'book', 'book']) 601 602 expected_repr = ( 603 r"ToyInfo\(version='1.0 alpha', toys=\(" 604 r"\('car', <tf.Tensor[^>]*>, ImmutableDict\(" 605 r"{'size': <tf.Tensor[^>]*>, 'color': <tf.Tensor[^>]*>}\)\), " 606 r"\('book', <tf.Tensor[^>]*>, ImmutableDict\(" 607 r"{'authors': (<tf.RaggedTensor[^>]*>|tf.RaggedTensor\(.*\))}\)\)\), " 608 r'boxes=ImmutableDict\(' 609 r"{'green': <tf.Tensor[^>]*>, 'blue': <tf.Tensor[^>]*>}\)\)") 610 611 self.assertRegex(repr(toy_info), expected_repr) 612 613 def testNestedExtensionTypes(self): 614 PossiblyMaskedTensor = typing.Union[ops.Tensor, MaskedTensorV1] 615 616 class Toy(extension_type.ExtensionType): 617 name: str 618 price: ops.Tensor 619 features: typing.Mapping[str, PossiblyMaskedTensor] 620 621 class Box(extension_type.ExtensionType): 622 contents: ops.Tensor 623 624 class ToyInfo(extension_type.ExtensionType): 625 version: str 626 toys: typing.Tuple[Toy, ...] 627 boxes: typing.Mapping[str, Box] 628 629 authors = MaskedTensorV1( 630 values=[[b'A', b'Quincy', b'Aardvark'], [b'Z', b'Zhook', b'']], 631 mask=[[True, True, True], [True, True, False]]) 632 toys = [ 633 Toy('car', 1.0, { 634 'size': [8, 3, 2], 635 'color': [0.3, 0.2, 0.8] 636 }), 637 Toy(name='book', price=3.7, features={'authors': authors}) 638 ] 639 boxes = { 640 'green': Box(['car']), 641 'blue': Box(contents=['car', 'book', 'book']) 642 } 643 toy_info = ToyInfo(version='1.0 alpha', toys=toys, boxes=boxes) 644 645 @def_function.function 646 def fn(info): 647 prices = [toy.price for toy in info.toys] 648 return math_ops.reduce_sum(array_ops.stack(prices)) 649 650 self.assertAllClose(fn(toy_info), 4.7) 651 652 def testNestedCustomConstructor(self): 653 654 class Toy(extension_type.ExtensionType): 655 name: str 656 price: ops.Tensor 657 658 def __init__(self, name, price, discount=0): 659 if discount: 660 name += ' (discounted)' 661 price *= (1 - discount) 662 self.name = name 663 self.price = price 664 665 class ToyBox(extension_type.ExtensionType): 666 
toys: typing.Tuple[Toy, ...] 667 668 def __init__(self, name_to_price, name_to_discount): 669 self.toys = [ 670 Toy(name, price, name_to_discount.get(name, 0)) 671 for (name, price) in name_to_price.items() 672 ] 673 674 toy_box = ToyBox({ 675 'car': 8.3, 676 'truck': 5.9, 677 'puzzle': 5.3, 678 'jacks': 2.8 679 }, { 680 'puzzle': .2, 681 'truck': .3 682 }) 683 self.assertLen(toy_box.toys, 4) 684 self.assertEqual( 685 set(toy.name for toy in toy_box.toys), 686 {'car', 'truck (discounted)', 'puzzle (discounted)', 'jacks'}) 687 688 def testExtensionTypeWithMathOperators(self): 689 690 def masked_add(x, y, name=None): 691 del name 692 if not isinstance(x, MaskedTensorV2) and isinstance(y, MaskedTensorV2): 693 return dispatch.OpDispatcher.NOT_SUPPORTED 694 return MaskedTensorV2(x.values + y.values, x.mask & y.mask) 695 696 with temporarily_add_dispatch(math_ops.add, MaskedTensorV2, masked_add): 697 x = MaskedTensorV2([[1, 2], [3, 4]], [[True, False], [True, True]]) 698 y = MaskedTensorV2([[3, 4], [5, 6]], [[True, True], [False, True]]) 699 z = x + y 700 self.assertAllEqual(z.values, [[4, 6], [8, 10]]) 701 self.assertAllEqual(z.mask, [[True, False], [False, True]]) 702 703 def testGetExtensionTypeFields(self): 704 705 # Can be called on a type or an instance: 706 fields_1 = MaskedTensorV1._tf_extension_type_fields() 707 fields_2 = MaskedTensorV1([0], [True])._tf_extension_type_fields() 708 709 for fields in [fields_1, fields_2]: 710 self.assertLen(fields, 2) 711 self.assertEqual(fields[0].name, 'values') 712 self.assertEqual(fields[0].value_type, ops.Tensor) 713 self.assertEqual(fields[0].default, fields[0].NO_DEFAULT) 714 self.assertEqual(fields[1].name, 'mask') 715 self.assertEqual(fields[1].value_type, 716 tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)) 717 self.assertEqual(fields[1].default, fields[0].NO_DEFAULT) 718 719 def testHasExtensionTypeField(self): 720 721 self.assertTrue(MaskedTensorV1._tf_extension_type_has_field('values')) 722 
self.assertTrue(MaskedTensorV1._tf_extension_type_has_field('mask')) 723 self.assertFalse(MaskedTensorV1._tf_extension_type_has_field('labels')) 724 725 mt = MaskedTensorV1([0], [True]) 726 self.assertTrue(mt._tf_extension_type_has_field('values')) 727 self.assertTrue(mt._tf_extension_type_has_field('mask')) 728 self.assertFalse(mt._tf_extension_type_has_field('labels')) 729 730 def testForwardReferences(self): 731 A, B = ForwardRefA, ForwardRefB 732 733 self.assertEqual(A._tf_extension_type_fields(), 734 (extension_type_field.ExtensionTypeField( 735 'x', typing.Tuple[typing.Union[A, B], ...]), 736 extension_type_field.ExtensionTypeField('y', B))) 737 self.assertEqual(B._tf_extension_type_fields(), 738 (extension_type_field.ExtensionTypeField('z', B), 739 extension_type_field.ExtensionTypeField('n', ops.Tensor))) 740 741 # Check the signature. 742 expected_parameters = [ 743 tf_inspect.Parameter('self', 744 tf_inspect.Parameter.POSITIONAL_OR_KEYWORD), 745 tf_inspect.Parameter( 746 'x', 747 tf_inspect.Parameter.POSITIONAL_OR_KEYWORD, 748 annotation=typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'], 749 ...]), 750 tf_inspect.Parameter( 751 'y', 752 tf_inspect.Parameter.POSITIONAL_OR_KEYWORD, 753 annotation='ForwardRefB'), 754 ] 755 expected_sig = tf_inspect.Signature( 756 expected_parameters, return_annotation=A) 757 self.assertEqual(tf_inspect.signature(A.__init__), expected_sig) 758 759 def testUnresolvedForwardReference(self): 760 761 class Broken(extension_type.ExtensionType): 762 x: 'Cra' # note: intentional typo for Car. 
763 764 class Car(extension_type.ExtensionType): 765 speed: float 766 767 with self.assertRaises(TypeError): 768 Broken(x=Car(3.8)) 769 770 def testUnsupportedAnnotations(self): 771 with self.assertRaisesRegex( 772 TypeError, "In field 'values': Unsupported type annotation"): 773 774 class MyType1(extension_type.ExtensionType): # pylint: disable=unused-variable 775 values: typing.List[ops.Tensor] 776 777 with self.assertRaisesRegex(TypeError, 778 "In field 'xyz': Unsupported type annotation"): 779 780 class MyType2(extension_type.ExtensionType): # pylint: disable=unused-variable 781 xyz: typing.Union[typing.Tuple[complex, ...], int] 782 783 def testExtensionTypeBaseClassHasNoSpec(self): 784 self.assertFalse(hasattr(extension_type.ExtensionType, 'Spec')) 785 786 def testExtensionTypeBaseConstructorRaisesException(self): 787 with self.assertRaisesRegex(AssertionError, 788 'ExtensionType is an abstract base class.'): 789 extension_type.ExtensionType() 790 791 class ExtensionTypeWithName(extension_type.ExtensionType): 792 __name__ = 'tf.__test__.ExtensionTypeWithName' # For SavedModel 793 x: typing.Tuple[ops.Tensor, int] 794 y: ops.Tensor 795 796 def testSavedModelSupport(self): 797 798 class TestModule(module.Module): 799 800 @def_function.function 801 def f(self, s): 802 return s.x[0] + s.x[1] + s.y 803 804 s1 = self.ExtensionTypeWithName((1, 2), 3) 805 s2 = self.ExtensionTypeWithName((1.0, 2), [3.0, 4.0]) 806 807 m = TestModule() 808 m.f.get_concrete_function(s1) 809 m.f.get_concrete_function(s2) 810 811 path = tempfile.mkdtemp(prefix=test.get_temp_dir()) 812 save.save(m, path) 813 loaded = load.load(path) 814 815 self.assertAllEqual(loaded.f(s1), 6) 816 self.assertAllEqual(loaded.f(s2), [6.0, 7.0]) 817 818 def testPackedEncoding(self): 819 mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True]) 820 self.assertLen(nest.flatten(mt1, expand_composites=True), 2) 821 822 mt2 = extension_type.pack(mt1) 823 self.assertLen(nest.flatten(mt2, expand_composites=True), 
1) 824 self.assertIsInstance(mt2.values, ops.Tensor) 825 self.assertAllEqual(mt2.values, [1, 2, 3, 4]) 826 self.assertIsInstance(mt2.mask, ops.Tensor) 827 self.assertAllEqual(mt2.mask, [True, True, False, True]) 828 829 mt3 = extension_type.unpack(mt2) 830 self.assertLen(nest.flatten(mt3, expand_composites=True), 2) 831 self.assertIsInstance(mt3.values, ops.Tensor) 832 self.assertAllEqual(mt3.values, [1, 2, 3, 4]) 833 self.assertIsInstance(mt3.mask, ops.Tensor) 834 self.assertAllEqual(mt3.mask, [True, True, False, True]) 835 836 nest.assert_same_structure(mt1, mt3, expand_composites=True) 837 with self.assertRaisesRegex(ValueError, "don't have the same"): # pylint: disable=g-error-prone-assert-raises 838 nest.assert_same_structure(mt1, mt2, expand_composites=True) 839 840 mt4 = MaskedTensorV1([1, 2, 3, 4], [True, True, False, True]) 841 with self.assertRaisesRegex( 842 ValueError, 843 'ExtensionTypes must have a __name__ field in order to be packed.'): 844 extension_type.pack(mt4) 845 846 847@test_util.run_all_in_graph_and_eager_modes 848class ExtensionTypeSpecTest(test_util.TensorFlowTestCase, 849 parameterized.TestCase): 850 851 def testSpecConstructor(self): 852 values_spec = tensor_spec.TensorSpec([4], dtypes.float32) 853 mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) 854 mt_spec = MaskedTensorV1.Spec(values_spec, mask_spec) 855 self.assertEqual(mt_spec.values, values_spec) 856 self.assertEqual(mt_spec.mask, mask_spec) 857 858 mt = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True]) 859 self.assertEqual(mt._type_spec, mt_spec) 860 861 def testSpecConstructorSignature(self): 862 863 class MyType(extension_type.ExtensionType): 864 x: ops.Tensor 865 y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool) 866 z: typing.Tuple[typing.Union[int, str], ...] 
= [1, 'two', 3] 867 868 expected_parameters = [ 869 tf_inspect.Parameter('self', 870 tf_inspect.Parameter.POSITIONAL_OR_KEYWORD), 871 tf_inspect.Parameter('x', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD), 872 tf_inspect.Parameter('y', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD), 873 tf_inspect.Parameter('z', tf_inspect.Parameter.POSITIONAL_OR_KEYWORD), 874 ] 875 expected_sig = tf_inspect.Signature( 876 expected_parameters, return_annotation=MyType.Spec) 877 self.assertEqual(expected_sig, tf_inspect.signature(MyType.Spec.__init__)) 878 879 def testSpecAttributesAreImmutable(self): 880 mt = MaskedTensorV1([1, 2, 3, 4], [True, True, False, True]) 881 mt_spec = MaskedTensorV1.Spec.from_value(mt) 882 with self.assertRaisesRegex(AttributeError, 883 "cannot assign to field 'score'"): 884 mt_spec.score = 12 885 with self.assertRaisesRegex(AttributeError, 886 "cannot assign to field 'values'"): 887 mt_spec.values = constant_op.constant([4, 3, 2, 1]) 888 with self.assertRaisesRegex(AttributeError, "cannot delete field 'values'"): 889 del mt_spec.values 890 891 def testSpecFromValue(self): 892 mt = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True]) 893 mt_spec = MaskedTensorV1.Spec.from_value(mt) 894 895 expected_values_spec = tensor_spec.TensorSpec([4], dtypes.float32) 896 expected_mask_spec = tensor_spec.TensorSpec([4], dtypes.bool) 897 self.assertEqual(mt_spec.values, expected_values_spec) 898 self.assertEqual(mt_spec.mask, expected_mask_spec) 899 900 def testSpecSerialize(self): 901 902 class Zoo(extension_type.ExtensionType): 903 zookeepers: typing.Tuple[str, ...] 
904 animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]] 905 906 featurespec = { 907 'size': tensor_spec.TensorSpec([3]), 908 'weight': tensor_spec.TensorSpec([]) 909 } 910 zoo_spec = Zoo.Spec( 911 zookeepers=['Zoey', 'Zack'], 912 animals={ 913 'tiger': featurespec, 914 'elephant': featurespec 915 }) 916 917 serialized = zoo_spec._serialize() 918 self.assertEqual(serialized, 919 (('zookeepers', ('Zoey', 'Zack')), ('animals', { 920 'tiger': featurespec, 921 'elephant': featurespec 922 }))) 923 restored = Zoo.Spec._deserialize(serialized) 924 self.assertEqual(zoo_spec, restored) 925 926 # ImmutableDict is used for the field, but dict for the serialization: 927 self.assertIsInstance(zoo_spec.animals, immutable_dict.ImmutableDict) 928 serialized_field_name, serialized_field_value = serialized[1] 929 self.assertEqual(serialized_field_name, 'animals') 930 self.assertIsInstance(serialized_field_value, dict) 931 932 def testSpecComponents(self): 933 934 class Zoo(extension_type.ExtensionType): 935 zookeepers: typing.Tuple[str, ...] 
936 animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]] 937 938 zoo = Zoo( 939 ['Zoey', 'Zack'], { 940 'elephant': { 941 'size': [25, 30, 20], 942 'weight': 2000.0 943 }, 944 'tiger': { 945 'hunger': 3.2, 946 'size': [3, 8, 2], 947 'weight': 87.3 948 } 949 }) 950 zoo_spec = Zoo.Spec.from_value(zoo) 951 952 components = zoo_spec._to_components(zoo) 953 self.assertLen(components, 5) 954 self.assertAllClose(components[0], [25, 30, 20]) 955 self.assertAllClose(components[1], 2000.0) 956 self.assertAllClose(components[2], 3.2) 957 self.assertAllClose(components[3], [3, 8, 2]) 958 self.assertAllClose(components[4], 87.3) 959 960 restored = zoo_spec._from_components(components) 961 self.assertAllEqual(zoo == restored, True) 962 963 self.assertEqual(zoo_spec._component_specs, 964 (tensor_spec.TensorSpec([3], dtypes.int32), 965 tensor_spec.TensorSpec([], dtypes.float32), 966 tensor_spec.TensorSpec([], dtypes.float32), 967 tensor_spec.TensorSpec([3], dtypes.int32), 968 tensor_spec.TensorSpec([], dtypes.float32))) 969 970 971@test_util.run_all_in_graph_and_eager_modes 972class AnonymousExtensionTypeTest(test_util.TensorFlowTestCase, 973 parameterized.TestCase): 974 975 @parameterized.parameters([ 976 [dict(i=5, f=3.2, b=True, n=None)], 977 [dict(x=(1, 2), y={ 978 3: 4, 979 5: 6 980 })], 981 [lambda: dict(t=constant_op.constant(123))], 982 [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))], 983 ]) 984 def testConstruction(self, fields): 985 if callable(fields): 986 fields = fields() 987 extension_type.AnonymousExtensionType(**fields) 988 989 @parameterized.parameters([ 990 [dict(x=[1, 2, 3]), 'Unsupported field value'], 991 [dict(x=set([1, 2])), 'Unsupported field value'], 992 [dict(x=(1, dict([(2, [])]))), 'Unsupported field value'], 993 [ 994 dict(_tf_extension_type_xyz=5), 995 "The field name '_tf_extension_type_xyz' is reserved" 996 ], 997 ]) 998 def testConstructionErrors(self, fields, error): 999 with self.assertRaisesRegex(ValueError, error): 1000 
extension_type.AnonymousExtensionType(**fields) 1001 1002 @parameterized.parameters([ 1003 [dict(i=5, f=3.2, b=True, n=None)], 1004 [dict(x=(1, 2), y={ 1005 3: 4, 1006 5: 6 1007 })], 1008 [lambda: dict(t=constant_op.constant(123))], 1009 [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))], 1010 ]) 1011 def testAttributeAccessors(self, fields): 1012 if callable(fields): 1013 fields = fields() 1014 s = extension_type.AnonymousExtensionType(**fields) 1015 for (name, value) in fields.items(): 1016 actual = getattr(s, name) 1017 if isinstance(actual, (ops.Tensor, ragged_tensor.RaggedTensor)): 1018 self.assertAllEqual(actual, value) 1019 else: 1020 self.assertEqual(actual, value) 1021 1022 def testAttributeAccessorsAreImmutable(self): 1023 s = extension_type.AnonymousExtensionType(x=12, y={'x': 55}) 1024 with self.assertRaisesRegex(AttributeError, "cannot assign to field 'x'"): 1025 s.x = 22 1026 with self.assertRaisesRegex(AttributeError, "cannot delete field 'y'"): 1027 del s.y 1028 with self.assertRaisesRegex(TypeError, 'does not support item assignment'): 1029 s.y['x'] = 66 1030 1031 def testReinterpret(self): 1032 x = MaskedTensorV2([4, 5], [True, False]) 1033 anon_x = extension_type.reinterpret(x, 1034 extension_type.AnonymousExtensionType) 1035 self.assertAllEqual(anon_x.values, [4, 5]) 1036 self.assertAllEqual(anon_x.mask, [True, False]) 1037 1038 round_trip_x = extension_type.reinterpret(anon_x, MaskedTensorV2) 1039 self.assertAllEqual(round_trip_x.values, [4, 5]) 1040 self.assertAllEqual(round_trip_x.mask, [True, False]) 1041 1042 converted_x = extension_type.reinterpret(anon_x, MaskedTensorV1) 1043 self.assertAllEqual(converted_x.values, [4, 5]) 1044 self.assertAllEqual(converted_x.mask, [True, False]) 1045 1046 # pylint: disable=g-long-lambda 1047 @parameterized.parameters([ 1048 [ 1049 lambda: extension_type.AnonymousExtensionType( 1050 values=constant_op.constant([1, 2, 3])), MaskedTensorV2, 1051 "Missing required fields: {'mask'}" 1052 ], 1053 [ 
1054 lambda: extension_type.AnonymousExtensionType( 1055 values=(1, 2, 3), mask=None), MaskedTensorV2, 1056 'mask: expected a tf.bool Tensor, got None' 1057 ], 1058 [ 1059 lambda: extension_type.AnonymousExtensionType( 1060 values=constant_op.constant([[1, 2], [3, 4]]), 1061 mask=ragged_factory_ops.constant([[1, 2], [3]])), MaskedTensorV2, 1062 'mask: expected a tf.bool Tensor' 1063 ], 1064 [ 1065 lambda: extension_type.AnonymousExtensionType( 1066 values=constant_op.constant([1, 2, 3]), 1067 mask=constant_op.constant([True, False])), MaskedTensorV2, 1068 'Shapes .* are incompatible' 1069 ], 1070 [ 1071 lambda: extension_type.AnonymousExtensionType( 1072 values=constant_op.constant([1, 2, 3])), ops.Tensor, 1073 'Expected `new_type` to be a subclass of tf.ExtensionType' 1074 ], 1075 [ 1076 lambda: constant_op.constant([1, 2, 3]), 1077 extension_type.AnonymousExtensionType, 1078 'Expected `value` to be a tf.ExtensionType' 1079 ], 1080 ]) 1081 def testReinterpretErrors(self, value, new_type, error): 1082 if callable(value): 1083 value = value() 1084 with self.assertRaisesRegex((TypeError, ValueError), error): 1085 extension_type.reinterpret(value, new_type) 1086 1087 def testLoadSavedModelWithUnregisteredExtensionType(self): 1088 1089 def f(x, y): 1090 x_values = x.values if isinstance(x, MaskedTensorV1) else x 1091 y_values = y.values if isinstance(y, MaskedTensorV1) else y 1092 x_mask = x.mask if isinstance(x, MaskedTensorV1) else True 1093 y_mask = y.mask if isinstance(y, MaskedTensorV1) else True 1094 return MaskedTensorV1(x_values + y_values, x_mask & y_mask) 1095 1096 t_spec = tensor_spec.TensorSpec(None, dtypes.int32) 1097 b_spec = tensor_spec.TensorSpec(None, dtypes.bool) 1098 mt_spec = MaskedTensorV1.Spec(values=t_spec, mask=b_spec) 1099 model = module.Module() 1100 model.f = def_function.function(f) 1101 model.f.get_concrete_function(t_spec, t_spec) 1102 model.f.get_concrete_function(t_spec, mt_spec) 1103 model.f.get_concrete_function(mt_spec, t_spec) 1104 
model.f.get_concrete_function(mt_spec, mt_spec) 1105 1106 path = tempfile.mkdtemp(prefix=test.get_temp_dir()) 1107 with temporarily_register_type_spec('tf.test.MaskedTensorV1.Spec', 1108 MaskedTensorV1.Spec): 1109 save.save(model, path) 1110 loaded_model = load.load(path) 1111 1112 with self.assertRaises(ValueError): 1113 type_spec.lookup('tf.test.MaskedTensorV1') 1114 1115 t = constant_op.constant([10, 20, 30]) 1116 v1 = loaded_model.f(t, t) 1117 self.assertIsInstance(v1, extension_type.AnonymousExtensionType) 1118 self.assertAllEqual(v1.values, [20, 40, 60]) 1119 self.assertAllEqual(v1.mask, True) 1120 1121 v2 = loaded_model.f(v1, v1) 1122 self.assertIsInstance(v2, extension_type.AnonymousExtensionType) 1123 self.assertAllEqual(v2.values, [40, 80, 120]) 1124 self.assertAllEqual(v2.mask, True) 1125 1126 mt = MaskedTensorV1([1, 2, 3], [True, True, False]) 1127 v3 = loaded_model.f( 1128 t, extension_type.reinterpret(mt, 1129 extension_type.AnonymousExtensionType)) 1130 self.assertIsInstance(v3, extension_type.AnonymousExtensionType) 1131 self.assertAllEqual(v3.values, [11, 22, 33]) 1132 self.assertAllEqual(v3.mask, [True, True, False]) 1133 1134 v4 = extension_type.reinterpret(v3, MaskedTensorV1) 1135 self.assertIsInstance(v4, MaskedTensorV1) 1136 self.assertAllEqual(v4.values, [11, 22, 33]) 1137 self.assertAllEqual(v4.mask, [True, True, False]) 1138 1139 1140def replace_tensors_with_placeholders(value): 1141 1142 def repl(x): 1143 if isinstance(x, ops.Tensor): 1144 return array_ops.placeholder_with_default(x, shape=None) 1145 else: 1146 return x 1147 1148 return nest.map_structure(repl, value, expand_composites=True) 1149 1150 1151@contextlib.contextmanager 1152def temporarily_add_dispatch(op, typ, fn): 1153 n = len(op._tf_dispatchers) 1154 dispatch.dispatch_for_types(op, typ)(fn) 1155 yield 1156 assert len(op._tf_dispatchers) == n + 1 1157 del op._tf_dispatchers[-1] 1158 1159 1160@contextlib.contextmanager 1161def temporarily_register_type_spec(name, cls): 1162 
"""Context manager for making temporary changes to the TypeSpec registry.""" 1163 type_spec.register(name)(cls) 1164 yield 1165 assert type_spec._TYPE_SPEC_TO_NAME.pop(cls) == name 1166 assert type_spec._NAME_TO_TYPE_SPEC.pop(name) is cls 1167 1168 1169if __name__ == '__main__': 1170 googletest.main() 1171