• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf.framework.extension_type."""
16
17import contextlib
18import tempfile
19import typing
20
21from absl.testing import parameterized
22
23from tensorflow.python.eager import context
24from tensorflow.python.eager import def_function
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import extension_type
28from tensorflow.python.framework import extension_type_field
29from tensorflow.python.framework import immutable_dict
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import test_util
33from tensorflow.python.framework import type_spec
34from tensorflow.python.module import module
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops.ragged import ragged_factory_ops
39from tensorflow.python.ops.ragged import ragged_tensor
40from tensorflow.python.platform import googletest
41from tensorflow.python.platform import test
42from tensorflow.python.saved_model import load
43from tensorflow.python.saved_model import save
44from tensorflow.python.util import dispatch
45from tensorflow.python.util import nest
46from tensorflow.python.util import tf_inspect
47
48
class MaskedTensorV1(extension_type.ExtensionType):
  """Minimal ExtensionType test fixture: a value tensor plus a boolean mask.

  Only fields are declared here (no methods and no `__name__`), so this type
  relies entirely on what ExtensionType generates, and it cannot be packed.
  """
  values: ops.Tensor
  mask: tensor_spec.TensorSpec(dtype=dtypes.bool, shape=None)
53
54
class MaskedTensorV2(extension_type.ExtensionType):
  """ExtensionType test fixture with behavior on top of its fields.

  Compared with MaskedTensorV1 this adds a classmethod constructor, a
  staticmethod, read-only properties, custom `__repr__` / `__validate__`
  hooks and arithmetic operators.  It also sets `__name__`, which is what
  allows instances to be packed and serialized.
  """
  __name__ = 'tf.test.MaskedTensorV2'

  values: ops.Tensor
  mask: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)

  def __repr__(self):
    # Only tensors that expose `.numpy()` can be rendered element by element;
    # everything else falls back to the generated repr.
    if not (hasattr(self.values, 'numpy') and hasattr(self.mask, 'numpy')):
      return super().__repr__()
    rendered = _masked_array_repr(self.values.numpy(), self.mask.numpy())
    return '<MaskedTensorV2 %s>' % rendered

  @property
  def shape(self):
    """Static shape of the `values` tensor."""
    values_tensor = self.values
    return values_tensor.shape

  @property
  def dtype(self):
    """Dtype of the `values` tensor."""
    values_tensor = self.values
    return values_tensor.dtype

  @classmethod
  def from_full_tensor(cls, values):
    """Builds an instance whose mask is all `True` (nothing masked out)."""
    all_true = array_ops.ones_like(values, dtype=dtypes.bool)
    return cls(values, all_true)

  # A dummy example to test support of staticmethod
  @staticmethod
  def doc_link():
    """Returns a fixed documentation URL."""
    link = 'http://example.com/masked_tensor'
    return link

  def __validate__(self):
    # Reject construction when values and mask have incompatible shapes.
    mask_shape = self.mask.shape
    self.values.shape.assert_is_compatible_with(mask_shape)

  def with_default(self, default):
    """Returns `values` where `mask` is true, and `default` elsewhere."""
    kept, fallback = self.values, default
    return array_ops.where_v2(self.mask, kept, fallback)

  # Operators delegate straight to the math ops (self becomes the first arg).
  __add__ = math_ops.add
  __sub__ = math_ops.subtract
99
100
101def _masked_array_repr(values, mask):
102  """Returns a string representation for a masked numpy array."""
103  assert len(values) == len(mask)
104  if len(values.shape) == 1:
105    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]
106  else:
107    items = [_masked_array_repr(v, m) for (v, m) in zip(values, mask)]
108  return '[%s]' % ', '.join(items)
109
110
class ForwardRefA(extension_type.ExtensionType):
  """Test fixture whose field annotations use string forward references.

  'ForwardRefB' is defined after this class.  testForwardReferences checks
  that these strings resolve to the real classes, and that the constructor
  signature keeps the annotations in their original (string) form.
  """
  # A tuple of any length whose items are ForwardRefA or ForwardRefB.
  x: typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'], ...]
  y: 'ForwardRefB'
114
115
class ForwardRefB(extension_type.ExtensionType):
  """Test fixture with a self-referential forward reference.

  The `z` field names this class itself via the string 'ForwardRefB'.
  """
  z: 'ForwardRefB'
  n: ops.Tensor
119
120
@test_util.run_all_in_graph_and_eager_modes
class ExtensionTypeTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for `extension_type.ExtensionType` subclasses.

  Covers field access and immutability, constructors and validation,
  equality, tf.function / cond / while_loop integration, nesting,
  SavedModel support and the packed encoding.  Tests whose expectations
  differ between graph and eager execution branch on
  `context.executing_eagerly()`.
  """

  def testAttributeAccessors(self):
    mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
    mt2 = extension_type.pack(mt1)

    for mt in [mt1, mt2]:
      self.assertIsInstance(mt.values, ops.Tensor)
      self.assertAllEqual(mt.values, [1, 2, 3, 4])
      self.assertIsInstance(mt.mask, ops.Tensor)
      self.assertAllEqual(mt.mask, [True, True, False, True])

  def testAttributesAreImmutable(self):
    mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
    mt2 = extension_type.pack(mt1)

    for mt in [mt1, mt2]:
      with self.assertRaisesRegex(AttributeError,
                                  "cannot assign to field 'score'"):
        mt.score = 12
      with self.assertRaisesRegex(AttributeError,
                                  "cannot assign to field 'values'"):
        mt.values = constant_op.constant([4, 3, 2, 1])
      with self.assertRaisesRegex(AttributeError,
                                  "cannot delete field 'values'"):
        del mt.values

  def testClassAndStaticMethod(self):
    mt = MaskedTensorV2.from_full_tensor([1, 2, 3, 4])
    self.assertAllEqual(mt.mask, [True, True, True, True])
    self.assertEqual(mt.doc_link(), 'http://example.com/masked_tensor')

  def testRepr(self):
    values = constant_op.constant([1, 2, 3, 4])
    mask = constant_op.constant([True, True, False, True])
    mt = MaskedTensorV1(values, mask)
    expected = f'MaskedTensorV1(values={values!r}, mask={mask!r})'
    self.assertEqual(expected, repr(mt))

  def testEagerRepr(self):
    values = constant_op.constant([1, 2, 3, 4])
    mask = constant_op.constant([True, True, False, True])
    mt = MaskedTensorV2(values, mask)
    # MaskedTensorV2.__repr__ only uses its custom format when the tensors
    # expose `.numpy()`; otherwise it falls back to the generated repr.
    if context.executing_eagerly():
      expected = '<MaskedTensorV2 [1, 2, _, 4]>'
    else:
      expected = f'MaskedTensorV2(values={values!r}, mask={mask!r})'

    self.assertEqual(expected, repr(mt))

  def testConstructorSignature(self):

    class MyType(extension_type.ExtensionType):
      x: ops.Tensor
      y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
      z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]

    # The list default for `z` is expected to show up as a tuple in the
    # generated signature.
    expected_parameters = [
        tf_inspect.Parameter('self',
                             tf_inspect.Parameter.POSITIONAL_OR_KEYWORD),
        tf_inspect.Parameter(
            'x',
            tf_inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=ops.Tensor),
        tf_inspect.Parameter(
            'y',
            tf_inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)),
        tf_inspect.Parameter(
            'z',
            tf_inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=typing.Tuple[typing.Union[int, str], ...],
            default=(1, 'two', 3)),
    ]
    expected_sig = tf_inspect.Signature(
        expected_parameters, return_annotation=MyType)
    self.assertEqual(expected_sig, tf_inspect.signature(MyType.__init__))

  def testEmptyType(self):

    class EmptyType(extension_type.ExtensionType):
      pass

    self.assertEmpty(EmptyType._tf_extension_type_fields())
    x = EmptyType()
    self.assertEqual(repr(x), 'EmptyType()')

  def testCustomConstrutor(self):

    class SummarizedTensor(extension_type.ExtensionType):
      values: ops.Tensor
      mean: ops.Tensor
      max: ops.Tensor

      def __init__(self, values):
        # Fields may be assigned inside a user-defined __init__.
        self.values = ops.convert_to_tensor(values)
        self.mean = math_ops.reduce_mean(values)
        self.max = math_ops.reduce_max(values)

    x = SummarizedTensor([[1.0, 2, 3], [4, 5, 6]])
    self.assertAllEqual(x.values, [[1.0, 2, 3], [4, 5, 6]])
    self.assertAllEqual(x.mean, 3.5)
    self.assertAllEqual(x.max, 6)

  class Node(extension_type.ExtensionType):
    # `children` refers to this class through a qualified string forward
    # reference, since the class is nested inside ExtensionTypeTest.
    x: ops.Tensor
    y: typing.Optional[str] = None
    children: typing.Tuple['ExtensionTypeTest.Node', ...] = ()

  def testCustomConstructorWithDefaultValues(self):
    a = ExtensionTypeTest.Node(5)
    self.assertAllEqual(a.x, 5)
    self.assertIsNone(a.y)
    self.assertEqual(a.children, ())

    b = ExtensionTypeTest.Node(6, 'blue')
    self.assertAllEqual(b.x, 6)
    self.assertEqual(b.y, 'blue')
    self.assertEqual(b.children, ())

    c = ExtensionTypeTest.Node(7, children=(a, b))
    self.assertAllEqual(c.x, 7)
    self.assertIsNone(c.y)
    self.assertEqual(c.children, (a, b))

  def testCustomConstructorNondefaultCanotFollowDefault(self):
    with self.assertRaisesRegex(
        ValueError, "Field without default 'd' follows field with default 'c'"):

      class MyType(extension_type.ExtensionType):
        a: int
        b: str = 'Hello world'
        c: typing.Optional[ops.Tensor] = None
        d: ops.Tensor

      del MyType

  def testCustomConstrutorCantMutateNestedValues(self):

    class Foo(extension_type.ExtensionType):
      x: int

    class Bar(extension_type.ExtensionType):
      foo: Foo

      def __init__(self, foo):
        foo.x = 33  # This raises an exception

    with self.assertRaisesRegex(AttributeError, "cannot assign to field 'x'"):
      Bar(Foo(12))

  def testCustomValidate(self):

    class AlignedTensors(extension_type.ExtensionType):
      x: ops.Tensor
      y: ops.Tensor

      def __validate__(self):
        self.x.shape.assert_is_compatible_with(self.y.shape)

    aligned = AlignedTensors([1, 2, 3], ['a', 'b', 'c'])
    self.assertAllEqual(aligned.x, [1, 2, 3])
    self.assertAllEqual(aligned.y, [b'a', b'b', b'c'])

    with self.assertRaises(ValueError):
      AlignedTensors([1, 2, 3], ['a', 'b', 'c', 'd'])

  def testEquals(self):

    class MyType(extension_type.ExtensionType):
      values: ops.Tensor
      score: ops.Tensor
      flavor: str

    x1 = MyType([1, 2], 8, 'blue')
    x2 = MyType([1, 2], 8, 'blue')
    y = MyType([1, 2], 8, 'red')
    z = MyType([1, 2], 7, 'blue')
    self.assertAllEqual(x1 == x2, True)
    self.assertAllEqual(x1 != x2, False)
    self.assertAllEqual(x1 == y, False)
    self.assertAllEqual(x1 != y, True)
    self.assertAllEqual(x1 == z, False)
    self.assertAllEqual(y == z, False)

    # These are not equal, even though their values are broadcast-compatible
    # and elements are all equal when we broadcast.  Shapes must match.
    a = MyType([1, 1, 1, 1], 0, 'x')
    b = MyType([[1, 1, 1, 1]], 0, 'x')
    c = MyType([[1, 1], [1, 1]], 0, 'x')
    self.assertAllEqual(a == b, False)
    self.assertAllEqual(a == c, False)
    self.assertAllEqual(b == c, False)

    # Test with unknown shapes (executes a different codepath).
    a_ph = replace_tensors_with_placeholders(a)
    b_ph = replace_tensors_with_placeholders(b)
    c_ph = replace_tensors_with_placeholders(c)
    self.assertAllEqual(a_ph == b_ph, False)
    self.assertAllEqual(a_ph == c_ph, False)
    self.assertAllEqual(b_ph == c_ph, False)

  def testPassIntoTfFunction(self):

    @def_function.function
    def fn(x):
      return x.with_default(99)

    mt = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
    self.assertAllEqual([1, 2, 99, 4], fn(mt))
    self.assertAllEqual([1, 2, 99, 4], fn(extension_type.pack(mt)))

  def testReturnFromTfFunction(self):

    @def_function.function
    def mask_neg_values(x):
      return MaskedTensorV2(x, x > 0)

    @def_function.function
    def mask_neg_values_packed(x):
      return extension_type.pack(MaskedTensorV2(x, x > 0))

    expected = MaskedTensorV2([5, 8, -3, 9], [True, True, False, True])

    actual1 = mask_neg_values(constant_op.constant([5, 8, -3, 9]))
    self.assertIsInstance(actual1, MaskedTensorV2)
    self.assertAllEqual(expected.values, actual1.values)
    self.assertAllEqual(expected.mask, actual1.mask)

    actual2 = mask_neg_values_packed(constant_op.constant([5, 8, -3, 9]))
    self.assertIsInstance(actual2, MaskedTensorV2)
    self.assertTrue(extension_type.is_packed(actual2))
    self.assertAllEqual(expected.values, actual2.values)
    self.assertAllEqual(expected.mask, actual2.mask)

  def testCaptureByTfFunction(self):
    x = MaskedTensorV2(
        values=[[1, 2, 3], [4, 5, 6]],
        mask=[[True, True, True], [True, False, True]])

    @def_function.function
    def add_to_x(y):
      return MaskedTensorV2(x.values + y.values, x.mask & y.mask)

    actual = add_to_x(MaskedTensorV2([10, 20, 30], [False, True, True]))
    expected = MaskedTensorV2(
        values=[[11, 22, 33], [14, 25, 36]],
        mask=[[False, True, True], [False, False, True]])
    self.assertIsInstance(actual, MaskedTensorV2)
    self.assertAllEqual(expected.values, actual.values)
    self.assertAllEqual(expected.mask, actual.mask)

  def testTfFunctionArgMutationError(self):

    @def_function.function
    def fn_with_side_effect(mts):
      mts.append(MaskedTensorV1(mts[0].values * 2, mts[0].mask))

    with self.assertRaisesRegex(ValueError, 'should not modify'):
      fn_with_side_effect([MaskedTensorV1([10, 20, 30], [False, True, True])])

  def testNestPackUnpack(self):

    class CandyStore(extension_type.ExtensionType):
      name: ops.Tensor
      prices: typing.Mapping[str, ops.Tensor]

    store = CandyStore('Yum', {'gum': [0.42, 0.48], 'chocolate': [0.83, 1.02]})
    components = nest.flatten(store, expand_composites=True)
    repacked_1 = nest.pack_sequence_as(
        store, components, expand_composites=True)
    repacked_2 = nest.pack_sequence_as(
        store._type_spec, components, expand_composites=True)

    # Note: dicts get sorted by key.
    self.assertLen(components, 3)
    self.assertAllEqual(components[0], b'Yum')
    self.assertAllClose(components[1], [0.83, 1.02])
    self.assertAllClose(components[2], [0.42, 0.48])

    for repacked in [repacked_1, repacked_2]:
      self.assertAllEqual(repacked.name, b'Yum')
      self.assertAllClose(repacked.prices['gum'], [0.42, 0.48])
      self.assertAllClose(repacked.prices['chocolate'], [0.83, 1.02])

  def testSimpleCond(self):
    x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])
    y = MaskedTensorV1([5, 6, 7, 8], [False, True, True, False])

    x_2 = control_flow_ops.cond(
        constant_op.constant(True), lambda: x, lambda: y)
    y_2 = control_flow_ops.cond(
        constant_op.constant(False), lambda: x, lambda: y)

    self.assertAllEqual(x.values, x_2.values)
    self.assertAllEqual(x.mask, x_2.mask)
    self.assertAllEqual(y.values, y_2.values)
    self.assertAllEqual(y.mask, y_2.mask)

  def testComplexCond(self):
    mt = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])

    def true_fn():
      return MaskedTensorV1(
          array_ops.where_v2(mt.mask, mt.values, -1), mt.values > 3)

    def false_fn():
      return MaskedTensorV1(
          array_ops.where_v2(mt.mask, 100, mt.values * 2),
          math_ops.logical_not(mt.mask))

    x = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn)
    y = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)

    self.assertAllEqual(x.values, [1, -1, 3, -1])
    self.assertAllEqual(x.mask, [False, False, False, True])
    self.assertAllEqual(y.values, [100, 4, 100, 8])
    self.assertAllEqual(y.mask, [False, True, False, True])

  def testCondAutograph(self):

    @def_function.function
    def fn(mt):
      if mt.values[3] > 3:
        return MaskedTensorV1(
            array_ops.where_v2(mt.mask, mt.values, -1), mt.values > 3)
      else:
        return MaskedTensorV1(
            array_ops.where_v2(mt.mask, 100, mt.values * 2), not mt.mask)

    x = fn(MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]))
    self.assertAllEqual(x.values, [1, -1, 3, -1])
    self.assertAllEqual(x.mask, [False, False, False, True])

  def testCondTypeMismatch(self):
    # `executing_eagerly` must be *called*: the bare function object is always
    # truthy, which would make this test return early in graph mode as well.
    if context.executing_eagerly():
      # In eager mode, tf.cond eagerly runs either true_fn or false_fn, and
      # ignores the other one; so it doesn't detect any type mismatches
      # between the two outcomes.  (See _eager_cond_implementation in
      # control_flow_ops.py.)
      return

    a = lambda: MaskedTensorV1([1, 2, 3], [True, True, False])
    b = lambda: MaskedTensorV1(['a', 'b', 'c'], [False, True, True])
    c = lambda: MaskedTensorV2([4, 5, 6], [True, True, False])
    d = lambda: constant_op.constant([7, 8, 9])

    with self.assertRaisesRegex(
        ValueError,
        'Incompatible return values of true_fn and false_fn: The two '
        "structures don't have the same nested structure"):
      control_flow_ops.cond(constant_op.constant(True), a, b)
    with self.assertRaisesRegex(
        TypeError, 'Incompatible return types of true_fn and false_fn: The two '
        "structures don't have the same nested structure"):
      control_flow_ops.cond(constant_op.constant(True), a, c)
    with self.assertRaisesRegex(
        ValueError,
        'Incompatible return values of true_fn and false_fn: The two '
        "structures don't have the same nested structure"):
      control_flow_ops.cond(constant_op.constant(True), a, d)

  def testCondPacked(self):
    x = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
    y = MaskedTensorV2([5, 6, 7, 8], [False, True, True, False])
    x = extension_type.pack(x)
    y = extension_type.pack(y)

    x_2 = control_flow_ops.cond(
        constant_op.constant(True), lambda: x, lambda: y)
    y_2 = control_flow_ops.cond(
        constant_op.constant(False), lambda: x, lambda: y)

    self.assertAllEqual(x.values, x_2.values)
    self.assertAllEqual(x.mask, x_2.mask)
    self.assertAllEqual(y.values, y_2.values)
    self.assertAllEqual(y.mask, y_2.mask)

    # Read fields of the *packed* value inside the branches (previously the
    # packed copy was discarded and the unpacked `a` was used instead).
    a = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
    b = extension_type.pack(a)
    b_size = control_flow_ops.cond(
        constant_op.constant(True), lambda: array_ops.size(b.mask),
        lambda: array_ops.size(b.values))
    self.assertAllEqual(b_size, 4)

    # Note: the following example would fail (with `Retval[0] does not have a
    # value`) if `ExtensionType.__getattr__` cached the results of unpacking
    # the value.  See the comment in `ExtensionType.__getattr__` for details.
    c = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
    c = extension_type.pack(c)
    d = control_flow_ops.cond(
        constant_op.constant(False), lambda: array_ops.size(c.mask),
        lambda: array_ops.size(c.values))
    self.assertAllEqual(d, 4)

  def testWhileLoop(self):
    x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])

    cond = lambda i, x: i < 10
    body = lambda i, x: (i + 1, MaskedTensorV1(x.values * 2, x.mask))
    _, y = control_flow_ops.while_loop_v2(cond, body, [0, x])

    self.assertIsInstance(y, MaskedTensorV1)
    self.assertAllEqual(y.values, [1024, 2048, 3072, 4096])
    self.assertAllEqual(y.mask, [True, False, True, False])

  def testWhileLoopAutograph(self):

    @def_function.function
    def fn(x, n):
      for _ in math_ops.range(n):
        x = MaskedTensorV1(x.values * 2, x.mask)
      return x

    y = fn(MaskedTensorV1([1, 2, 3, 4], [True, False, True, False]), 10)
    self.assertIsInstance(y, MaskedTensorV1)
    self.assertAllEqual(y.values, [1024, 2048, 3072, 4096])
    self.assertAllEqual(y.mask, [True, False, True, False])

  def testWhileLoopTypeMismatch(self):
    x = MaskedTensorV1([1, 2, 3, 4], [True, False, True, False])

    cond = lambda i, x: i < 10

    # Deliberately returns a structure that does not match the loop vars.
    def body(i, x):
      if isinstance(x, MaskedTensorV1):
        return x.values * 2
      else:
        return MaskedTensorV1(x, x > i)

    with self.assertRaisesRegex(
        ValueError, "The two structures don't have the same nested structure"):
      control_flow_ops.while_loop_v2(cond, body, [0, x])

  def testWhileLoopPacked(self):
    x = MaskedTensorV2([1, 2, 3, 4], [True, False, True, False])
    x = extension_type.pack(x)
    cond = lambda i, x: i < 10

    def body(i, x):
      return i + 1, extension_type.pack(MaskedTensorV2(x.values * 2, x.mask))

    _, y = control_flow_ops.while_loop_v2(cond, body, [0, x])
    self.assertIsInstance(y, MaskedTensorV2)
    self.assertAllEqual(y.values, [1024, 2048, 3072, 4096])
    self.assertAllEqual(y.mask, [True, False, True, False])

  def testNestedFields(self):
    PossiblyRaggedTensor = typing.Union[ops.Tensor, ragged_tensor.RaggedTensor]
    ToyFeatures = typing.Mapping[str, PossiblyRaggedTensor]

    class ToyInfo(extension_type.ExtensionType):
      version: str
      toys: typing.Tuple[typing.Tuple[str, ops.Tensor, ToyFeatures], ...]
      boxes: typing.Mapping[str, ops.Tensor]

    authors = [[b'A', b'Aardvark'], [b'Z', b'Zhook']]
    toys = [('car', 1.0, {
        'size': [8, 3, 2],
        'color': [0.3, 0.2, 0.8]
    }), ('book', 3.7, {
        'authors': ragged_factory_ops.constant(authors)
    })]
    boxes = {'green': ['car'], 'blue': ['car', 'book', 'book']}
    toy_info = ToyInfo(version='1.0 alpha', toys=toys, boxes=boxes)

    self.assertEqual(toy_info.version, '1.0 alpha')
    self.assertEqual(toy_info.toys[0][0], 'car')
    self.assertIsInstance(toy_info.toys[0][1], ops.Tensor)
    self.assertAllEqual(toy_info.toys[0][1], 1.0)
    self.assertEqual(set(toy_info.toys[0][2].keys()), {'size', 'color'})
    self.assertIsInstance(toy_info.toys[0][2]['size'], ops.Tensor)
    self.assertAllEqual(toy_info.toys[0][2]['size'], [8, 3, 2])
    self.assertIsInstance(toy_info.toys[1][2]['authors'],
                          ragged_tensor.RaggedTensor)
    self.assertAllEqual(toy_info.toys[1][2]['authors'], authors)
    self.assertAllEqual(toy_info.boxes['green'], [b'car'])
    self.assertAllEqual(toy_info.boxes['blue'], ['car', 'book', 'book'])

    expected_repr = (
        r"ToyInfo\(version='1.0 alpha', toys=\("
        r"\('car', <tf.Tensor[^>]*>, ImmutableDict\("
        r"{'size': <tf.Tensor[^>]*>, 'color': <tf.Tensor[^>]*>}\)\), "
        r"\('book', <tf.Tensor[^>]*>, ImmutableDict\("
        r"{'authors': (<tf.RaggedTensor[^>]*>|tf.RaggedTensor\(.*\))}\)\)\), "
        r'boxes=ImmutableDict\('
        r"{'green': <tf.Tensor[^>]*>, 'blue': <tf.Tensor[^>]*>}\)\)")

    self.assertRegex(repr(toy_info), expected_repr)

  def testNestedExtensionTypes(self):
    PossiblyMaskedTensor = typing.Union[ops.Tensor, MaskedTensorV1]

    class Toy(extension_type.ExtensionType):
      name: str
      price: ops.Tensor
      features: typing.Mapping[str, PossiblyMaskedTensor]

    class Box(extension_type.ExtensionType):
      contents: ops.Tensor

    class ToyInfo(extension_type.ExtensionType):
      version: str
      toys: typing.Tuple[Toy, ...]
      boxes: typing.Mapping[str, Box]

    authors = MaskedTensorV1(
        values=[[b'A', b'Quincy', b'Aardvark'], [b'Z', b'Zhook', b'']],
        mask=[[True, True, True], [True, True, False]])
    toys = [
        Toy('car', 1.0, {
            'size': [8, 3, 2],
            'color': [0.3, 0.2, 0.8]
        }),
        Toy(name='book', price=3.7, features={'authors': authors})
    ]
    boxes = {
        'green': Box(['car']),
        'blue': Box(contents=['car', 'book', 'book'])
    }
    toy_info = ToyInfo(version='1.0 alpha', toys=toys, boxes=boxes)

    @def_function.function
    def fn(info):
      prices = [toy.price for toy in info.toys]
      return math_ops.reduce_sum(array_ops.stack(prices))

    self.assertAllClose(fn(toy_info), 4.7)

  def testNestedCustomConstructor(self):

    class Toy(extension_type.ExtensionType):
      name: str
      price: ops.Tensor

      def __init__(self, name, price, discount=0):
        if discount:
          name += ' (discounted)'
          price *= (1 - discount)
        self.name = name
        self.price = price

    class ToyBox(extension_type.ExtensionType):
      toys: typing.Tuple[Toy, ...]

      def __init__(self, name_to_price, name_to_discount):
        self.toys = [
            Toy(name, price, name_to_discount.get(name, 0))
            for (name, price) in name_to_price.items()
        ]

    toy_box = ToyBox({
        'car': 8.3,
        'truck': 5.9,
        'puzzle': 5.3,
        'jacks': 2.8
    }, {
        'puzzle': .2,
        'truck': .3
    })
    self.assertLen(toy_box.toys, 4)
    self.assertEqual(
        set(toy.name for toy in toy_box.toys),
        {'car', 'truck (discounted)', 'puzzle (discounted)', 'jacks'})

  def testExtensionTypeWithMathOperators(self):

    def masked_add(x, y, name=None):
      del name
      # Decline unless *both* operands are MaskedTensorV2.  The parentheses
      # matter: without them `not` applies only to the first isinstance().
      if not (isinstance(x, MaskedTensorV2) and isinstance(y, MaskedTensorV2)):
        return dispatch.OpDispatcher.NOT_SUPPORTED
      return MaskedTensorV2(x.values + y.values, x.mask & y.mask)

    with temporarily_add_dispatch(math_ops.add, MaskedTensorV2, masked_add):
      x = MaskedTensorV2([[1, 2], [3, 4]], [[True, False], [True, True]])
      y = MaskedTensorV2([[3, 4], [5, 6]], [[True, True], [False, True]])
      z = x + y
      self.assertAllEqual(z.values, [[4, 6], [8, 10]])
      self.assertAllEqual(z.mask, [[True, False], [False, True]])

  def testGetExtensionTypeFields(self):

    # Can be called on a type or an instance:
    fields_1 = MaskedTensorV1._tf_extension_type_fields()
    fields_2 = MaskedTensorV1([0], [True])._tf_extension_type_fields()

    for fields in [fields_1, fields_2]:
      self.assertLen(fields, 2)
      self.assertEqual(fields[0].name, 'values')
      self.assertEqual(fields[0].value_type, ops.Tensor)
      self.assertEqual(fields[0].default, fields[0].NO_DEFAULT)
      self.assertEqual(fields[1].name, 'mask')
      self.assertEqual(fields[1].value_type,
                       tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool))
      self.assertEqual(fields[1].default, fields[1].NO_DEFAULT)

  def testHasExtensionTypeField(self):

    self.assertTrue(MaskedTensorV1._tf_extension_type_has_field('values'))
    self.assertTrue(MaskedTensorV1._tf_extension_type_has_field('mask'))
    self.assertFalse(MaskedTensorV1._tf_extension_type_has_field('labels'))

    mt = MaskedTensorV1([0], [True])
    self.assertTrue(mt._tf_extension_type_has_field('values'))
    self.assertTrue(mt._tf_extension_type_has_field('mask'))
    self.assertFalse(mt._tf_extension_type_has_field('labels'))

  def testForwardReferences(self):
    A, B = ForwardRefA, ForwardRefB

    # The field types are resolved to the real classes...
    self.assertEqual(A._tf_extension_type_fields(),
                     (extension_type_field.ExtensionTypeField(
                         'x', typing.Tuple[typing.Union[A, B], ...]),
                      extension_type_field.ExtensionTypeField('y', B)))
    self.assertEqual(B._tf_extension_type_fields(),
                     (extension_type_field.ExtensionTypeField('z', B),
                      extension_type_field.ExtensionTypeField('n', ops.Tensor)))

    # ...while the constructor signature keeps the original string annotations.
    expected_parameters = [
        tf_inspect.Parameter('self',
                             tf_inspect.Parameter.POSITIONAL_OR_KEYWORD),
        tf_inspect.Parameter(
            'x',
            tf_inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=typing.Tuple[typing.Union['ForwardRefA', 'ForwardRefB'],
                                    ...]),
        tf_inspect.Parameter(
            'y',
            tf_inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation='ForwardRefB'),
    ]
    expected_sig = tf_inspect.Signature(
        expected_parameters, return_annotation=A)
    self.assertEqual(tf_inspect.signature(A.__init__), expected_sig)

  def testUnresolvedForwardReference(self):

    class Broken(extension_type.ExtensionType):
      x: 'Cra'  # note: intentional typo for Car.

    class Car(extension_type.ExtensionType):
      speed: float

    with self.assertRaises(TypeError):
      Broken(x=Car(3.8))

  def testUnsupportedAnnotations(self):
    with self.assertRaisesRegex(
        TypeError, "In field 'values': Unsupported type annotation"):

      class MyType1(extension_type.ExtensionType):  # pylint: disable=unused-variable
        values: typing.List[ops.Tensor]

    with self.assertRaisesRegex(TypeError,
                                "In field 'xyz': Unsupported type annotation"):

      class MyType2(extension_type.ExtensionType):  # pylint: disable=unused-variable
        xyz: typing.Union[typing.Tuple[complex, ...], int]

  def testExtensionTypeBaseClassHasNoSpec(self):
    self.assertFalse(hasattr(extension_type.ExtensionType, 'Spec'))

  def testExtensionTypeBaseConstructorRaisesException(self):
    with self.assertRaisesRegex(AssertionError,
                                'ExtensionType is an abstract base class.'):
      extension_type.ExtensionType()

  class ExtensionTypeWithName(extension_type.ExtensionType):
    __name__ = 'tf.__test__.ExtensionTypeWithName'  # For SavedModel
    x: typing.Tuple[ops.Tensor, int]
    y: ops.Tensor

  def testSavedModelSupport(self):

    class TestModule(module.Module):

      @def_function.function
      def f(self, s):
        return s.x[0] + s.x[1] + s.y

    s1 = self.ExtensionTypeWithName((1, 2), 3)
    s2 = self.ExtensionTypeWithName((1.0, 2), [3.0, 4.0])

    m = TestModule()
    m.f.get_concrete_function(s1)
    m.f.get_concrete_function(s2)

    # `dir=` (not `prefix=`) so the directory is created *inside* the test's
    # temp dir rather than as a sibling path next to it.
    path = tempfile.mkdtemp(dir=test.get_temp_dir())
    save.save(m, path)
    loaded = load.load(path)

    self.assertAllEqual(loaded.f(s1), 6)
    self.assertAllEqual(loaded.f(s2), [6.0, 7.0])

  def testPackedEncoding(self):
    # Unpacked: one component per field; packed: a single component.
    mt1 = MaskedTensorV2([1, 2, 3, 4], [True, True, False, True])
    self.assertLen(nest.flatten(mt1, expand_composites=True), 2)

    mt2 = extension_type.pack(mt1)
    self.assertLen(nest.flatten(mt2, expand_composites=True), 1)
    self.assertIsInstance(mt2.values, ops.Tensor)
    self.assertAllEqual(mt2.values, [1, 2, 3, 4])
    self.assertIsInstance(mt2.mask, ops.Tensor)
    self.assertAllEqual(mt2.mask, [True, True, False, True])

    mt3 = extension_type.unpack(mt2)
    self.assertLen(nest.flatten(mt3, expand_composites=True), 2)
    self.assertIsInstance(mt3.values, ops.Tensor)
    self.assertAllEqual(mt3.values, [1, 2, 3, 4])
    self.assertIsInstance(mt3.mask, ops.Tensor)
    self.assertAllEqual(mt3.mask, [True, True, False, True])

    nest.assert_same_structure(mt1, mt3, expand_composites=True)
    with self.assertRaisesRegex(ValueError, "don't have the same"):  # pylint: disable=g-error-prone-assert-raises
      nest.assert_same_structure(mt1, mt2, expand_composites=True)

    mt4 = MaskedTensorV1([1, 2, 3, 4], [True, True, False, True])
    with self.assertRaisesRegex(
        ValueError,
        'ExtensionTypes must have a __name__ field in order to be packed.'):
      extension_type.pack(mt4)
845
846
@test_util.run_all_in_graph_and_eager_modes
class ExtensionTypeSpecTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):
  """Tests for the `Spec` classes that ExtensionType subclasses provide."""

  def testSpecConstructor(self):
    # A Spec built from per-field specs equals the spec of a matching value.
    float_vector_spec = tensor_spec.TensorSpec([4], dtypes.float32)
    bool_vector_spec = tensor_spec.TensorSpec([4], dtypes.bool)
    spec = MaskedTensorV1.Spec(float_vector_spec, bool_vector_spec)
    self.assertEqual(spec.values, float_vector_spec)
    self.assertEqual(spec.mask, bool_vector_spec)

    value = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True])
    self.assertEqual(value._type_spec, spec)

  def testSpecConstructorSignature(self):

    class MyType(extension_type.ExtensionType):
      x: ops.Tensor
      y: tensor_spec.TensorSpec(shape=None, dtype=dtypes.bool)
      z: typing.Tuple[typing.Union[int, str], ...] = [1, 'two', 3]

    # `self` plus each field, in declaration order, all as
    # positional-or-keyword parameters returning MyType.Spec.
    positional_or_keyword = tf_inspect.Parameter.POSITIONAL_OR_KEYWORD
    expected_sig = tf_inspect.Signature(
        [
            tf_inspect.Parameter(param_name, positional_or_keyword)
            for param_name in ('self', 'x', 'y', 'z')
        ],
        return_annotation=MyType.Spec)
    self.assertEqual(expected_sig, tf_inspect.signature(MyType.Spec.__init__))

  def testSpecAttributesAreImmutable(self):
    value = MaskedTensorV1([1, 2, 3, 4], [True, True, False, True])
    spec = MaskedTensorV1.Spec.from_value(value)
    # Assigning a new or existing attribute, and deleting a field, must fail.
    with self.assertRaisesRegex(AttributeError,
                                "cannot assign to field 'score'"):
      spec.score = 12
    with self.assertRaisesRegex(AttributeError,
                                "cannot assign to field 'values'"):
      spec.values = constant_op.constant([4, 3, 2, 1])
    with self.assertRaisesRegex(AttributeError, "cannot delete field 'values'"):
      del spec.values

  def testSpecFromValue(self):
    value = MaskedTensorV1([1.0, 2.0, 3.0, 4.0], [True, True, False, True])
    spec = MaskedTensorV1.Spec.from_value(value)

    self.assertEqual(spec.values, tensor_spec.TensorSpec([4], dtypes.float32))
    self.assertEqual(spec.mask, tensor_spec.TensorSpec([4], dtypes.bool))

  def testSpecSerialize(self):

    class Zoo(extension_type.ExtensionType):
      zookeepers: typing.Tuple[str, ...]
      animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]]

    animal_spec = dict(
        size=tensor_spec.TensorSpec([3]), weight=tensor_spec.TensorSpec([]))
    zoo_spec = Zoo.Spec(
        zookeepers=['Zoey', 'Zack'],
        animals=dict(tiger=animal_spec, elephant=animal_spec))

    serialized = zoo_spec._serialize()
    expected_serialized = (
        ('zookeepers', ('Zoey', 'Zack')),
        ('animals', dict(tiger=animal_spec, elephant=animal_spec)),
    )
    self.assertEqual(serialized, expected_serialized)
    self.assertEqual(zoo_spec, Zoo.Spec._deserialize(serialized))

    # The spec stores the mapping field as an ImmutableDict, while the
    # serialized form carries a plain dict.
    self.assertIsInstance(zoo_spec.animals, immutable_dict.ImmutableDict)
    field_name, field_value = serialized[1]
    self.assertEqual(field_name, 'animals')
    self.assertIsInstance(field_value, dict)

  def testSpecComponents(self):

    class Zoo(extension_type.ExtensionType):
      zookeepers: typing.Tuple[str, ...]
      animals: typing.Mapping[str, typing.Mapping[str, ops.Tensor]]

    zoo = Zoo(
        ['Zoey', 'Zack'],
        dict(
            elephant=dict(size=[25, 30, 20], weight=2000.0),
            tiger=dict(hunger=3.2, size=[3, 8, 2], weight=87.3)))
    zoo_spec = Zoo.Spec.from_value(zoo)

    # The five leaf tensors, in the order the spec is expected to emit them.
    expected_components = [[25, 30, 20], 2000.0, 3.2, [3, 8, 2], 87.3]
    components = zoo_spec._to_components(zoo)
    self.assertLen(components, 5)
    for component, expected in zip(components, expected_components):
      self.assertAllClose(component, expected)

    round_trip = zoo_spec._from_components(components)
    self.assertAllEqual(zoo == round_trip, True)

    int_vector = tensor_spec.TensorSpec([3], dtypes.int32)
    float_scalar = tensor_spec.TensorSpec([], dtypes.float32)
    self.assertEqual(
        zoo_spec._component_specs,
        (int_vector, float_scalar, float_scalar, int_vector, float_scalar))
969
970
@test_util.run_all_in_graph_and_eager_modes
class AnonymousExtensionTypeTest(test_util.TensorFlowTestCase,
                                 parameterized.TestCase):
  """Tests for `extension_type.AnonymousExtensionType` and `reinterpret`."""

  # Cases containing tensors are wrapped in lambdas and only evaluated inside
  # the test body (`if callable(fields)`), presumably so the tensors are
  # created under the graph/eager mode active for each decorated run.
  @parameterized.parameters([
      [dict(i=5, f=3.2, b=True, n=None)],
      [dict(x=(1, 2), y={
          3: 4,
          5: 6
      })],
      [lambda: dict(t=constant_op.constant(123))],
      [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))],
  ])
  def testConstruction(self, fields):
    if callable(fields):
      fields = fields()
    extension_type.AnonymousExtensionType(**fields)

  @parameterized.parameters([
      [dict(x=[1, 2, 3]), 'Unsupported field value'],
      [dict(x=set([1, 2])), 'Unsupported field value'],
      [dict(x=(1, dict([(2, [])]))), 'Unsupported field value'],
      [
          dict(_tf_extension_type_xyz=5),
          "The field name '_tf_extension_type_xyz' is reserved"
      ],
  ])
  def testConstructionErrors(self, fields, error):
    # Lists and sets (including a list nested in a dict) are rejected, as is
    # the reserved field name `_tf_extension_type_xyz`.
    with self.assertRaisesRegex(ValueError, error):
      extension_type.AnonymousExtensionType(**fields)

  @parameterized.parameters([
      [dict(i=5, f=3.2, b=True, n=None)],
      [dict(x=(1, 2), y={
          3: 4,
          5: 6
      })],
      [lambda: dict(t=constant_op.constant(123))],
      [lambda: dict(r=ragged_factory_ops.constant([[1, 2], [3]]))],
  ])
  def testAttributeAccessors(self, fields):
    if callable(fields):
      fields = fields()
    s = extension_type.AnonymousExtensionType(**fields)
    for (name, value) in fields.items():
      actual = getattr(s, name)
      if isinstance(actual, (ops.Tensor, ragged_tensor.RaggedTensor)):
        self.assertAllEqual(actual, value)
      else:
        self.assertEqual(actual, value)

  def testAttributeAccessorsAreImmutable(self):
    s = extension_type.AnonymousExtensionType(x=12, y={'x': 55})
    with self.assertRaisesRegex(AttributeError, "cannot assign to field 'x'"):
      s.x = 22
    with self.assertRaisesRegex(AttributeError, "cannot delete field 'y'"):
      del s.y
    # Mapping-valued fields are stored in a form that rejects item assignment.
    with self.assertRaisesRegex(TypeError, 'does not support item assignment'):
      s.y['x'] = 66

  def testReinterpret(self):
    x = MaskedTensorV2([4, 5], [True, False])
    anon_x = extension_type.reinterpret(x,
                                        extension_type.AnonymousExtensionType)
    self.assertAllEqual(anon_x.values, [4, 5])
    self.assertAllEqual(anon_x.mask, [True, False])

    round_trip_x = extension_type.reinterpret(anon_x, MaskedTensorV2)
    self.assertAllEqual(round_trip_x.values, [4, 5])
    self.assertAllEqual(round_trip_x.mask, [True, False])

    converted_x = extension_type.reinterpret(anon_x, MaskedTensorV1)
    self.assertAllEqual(converted_x.values, [4, 5])
    self.assertAllEqual(converted_x.mask, [True, False])

  # pylint: disable=g-long-lambda
  @parameterized.parameters([
      [
          lambda: extension_type.AnonymousExtensionType(
              values=constant_op.constant([1, 2, 3])), MaskedTensorV2,
          "Missing required fields: {'mask'}"
      ],
      [
          lambda: extension_type.AnonymousExtensionType(
              values=(1, 2, 3), mask=None), MaskedTensorV2,
          'mask: expected a tf.bool Tensor, got None'
      ],
      [
          lambda: extension_type.AnonymousExtensionType(
              values=constant_op.constant([[1, 2], [3, 4]]),
              mask=ragged_factory_ops.constant([[1, 2], [3]])), MaskedTensorV2,
          'mask: expected a tf.bool Tensor'
      ],
      [
          lambda: extension_type.AnonymousExtensionType(
              values=constant_op.constant([1, 2, 3]),
              mask=constant_op.constant([True, False])), MaskedTensorV2,
          'Shapes .* are incompatible'
      ],
      [
          lambda: extension_type.AnonymousExtensionType(
              values=constant_op.constant([1, 2, 3])), ops.Tensor,
          'Expected `new_type` to be a subclass of tf.ExtensionType'
      ],
      [
          lambda: constant_op.constant([1, 2, 3]),
          extension_type.AnonymousExtensionType,
          'Expected `value` to be a tf.ExtensionType'
      ],
  ])
  def testReinterpretErrors(self, value, new_type, error):
    if callable(value):
      value = value()
    with self.assertRaisesRegex((TypeError, ValueError), error):
      extension_type.reinterpret(value, new_type)

  def testLoadSavedModelWithUnregisteredExtensionType(self):
    """A SavedModel using an unregistered ExtensionType loads as anonymous.

    MaskedTensorV1.Spec is registered only while saving. After the
    registration is removed, the loaded functions return
    AnonymousExtensionType values, which can be reinterpreted back into
    MaskedTensorV1.
    """

    def f(x, y):
      x_values = x.values if isinstance(x, MaskedTensorV1) else x
      y_values = y.values if isinstance(y, MaskedTensorV1) else y
      x_mask = x.mask if isinstance(x, MaskedTensorV1) else True
      y_mask = y.mask if isinstance(y, MaskedTensorV1) else True
      return MaskedTensorV1(x_values + y_values, x_mask & y_mask)

    t_spec = tensor_spec.TensorSpec(None, dtypes.int32)
    b_spec = tensor_spec.TensorSpec(None, dtypes.bool)
    mt_spec = MaskedTensorV1.Spec(values=t_spec, mask=b_spec)
    model = module.Module()
    model.f = def_function.function(f)
    # Trace every combination of plain-tensor / MaskedTensorV1 arguments.
    model.f.get_concrete_function(t_spec, t_spec)
    model.f.get_concrete_function(t_spec, mt_spec)
    model.f.get_concrete_function(mt_spec, t_spec)
    model.f.get_concrete_function(mt_spec, mt_spec)

    spec_name = 'tf.test.MaskedTensorV1.Spec'
    # `dir=` (not `prefix=`) places the export directory inside the per-test
    # temp dir.
    path = tempfile.mkdtemp(dir=self.get_temp_dir())
    with temporarily_register_type_spec(spec_name, MaskedTensorV1.Spec):
      save.save(model, path)
    loaded_model = load.load(path)

    # The registration must be gone by load time. Check the exact name that
    # was registered above; a name that was never registered would make this
    # assertion pass regardless.
    with self.assertRaises(ValueError):
      type_spec.lookup(spec_name)

    t = constant_op.constant([10, 20, 30])
    v1 = loaded_model.f(t, t)
    self.assertIsInstance(v1, extension_type.AnonymousExtensionType)
    self.assertAllEqual(v1.values, [20, 40, 60])
    self.assertAllEqual(v1.mask, True)

    v2 = loaded_model.f(v1, v1)
    self.assertIsInstance(v2, extension_type.AnonymousExtensionType)
    self.assertAllEqual(v2.values, [40, 80, 120])
    self.assertAllEqual(v2.mask, True)

    mt = MaskedTensorV1([1, 2, 3], [True, True, False])
    v3 = loaded_model.f(
        t, extension_type.reinterpret(mt,
                                      extension_type.AnonymousExtensionType))
    self.assertIsInstance(v3, extension_type.AnonymousExtensionType)
    self.assertAllEqual(v3.values, [11, 22, 33])
    self.assertAllEqual(v3.mask, [True, True, False])

    v4 = extension_type.reinterpret(v3, MaskedTensorV1)
    self.assertIsInstance(v4, MaskedTensorV1)
    self.assertAllEqual(v4.values, [11, 22, 33])
    self.assertAllEqual(v4.mask, [True, True, False])
1138
1139
def replace_tensors_with_placeholders(value):
  """Returns `value` with every `ops.Tensor` leaf wrapped in a placeholder.

  The structure is traversed with `expand_composites=True`, so tensors inside
  composite values are replaced as well. Each tensor `t` becomes
  `array_ops.placeholder_with_default(t, shape=None)`; all other leaves are
  returned unchanged.
  """

  def _wrap_if_tensor(leaf):
    if not isinstance(leaf, ops.Tensor):
      return leaf
    return array_ops.placeholder_with_default(leaf, shape=None)

  return nest.map_structure(_wrap_if_tensor, value, expand_composites=True)
1149
1150
@contextlib.contextmanager
def temporarily_add_dispatch(op, typ, fn):
  """Temporarily registers `fn` as a dispatch target for `op` on `typ`.

  On entry, registers `fn` via `dispatch.dispatch_for_types(op, typ)`. On
  exit, including when the `with` body raises, removes the dispatcher that
  registration appended to `op._tf_dispatchers`.

  Args:
    op: The op whose `_tf_dispatchers` list receives the new dispatcher.
    typ: Passed through to `dispatch.dispatch_for_types`.
    fn: The function to dispatch to.

  Yields:
    None.
  """
  num_dispatchers = len(op._tf_dispatchers)
  dispatch.dispatch_for_types(op, typ)(fn)
  try:
    yield
  finally:
    # Registration is expected to append exactly one dispatcher; that is the
    # entry removed here, so a failing test cannot leak it into later tests.
    assert len(op._tf_dispatchers) == num_dispatchers + 1
    del op._tf_dispatchers[-1]
1158
1159
@contextlib.contextmanager
def temporarily_register_type_spec(name, cls):
  """Context manager for making temporary changes to the TypeSpec registry.

  Registers `cls` under `name` via `type_spec.register` on entry. On exit,
  including when the `with` body raises, removes both registry entries.

  Args:
    name: The registry name to register `cls` under.
    cls: The TypeSpec class to register (e.g. `MaskedTensorV1.Spec`).

  Yields:
    None.

  Raises:
    AssertionError: On exit, if the registry no longer maps `cls` to `name`
      and `name` to `cls`.
  """
  type_spec.register(name)(cls)
  try:
    yield
  finally:
    # Pop as plain statements: side effects inside an `assert` are dropped
    # entirely under `python -O`, which would leave the registration behind.
    registered_name = type_spec._TYPE_SPEC_TO_NAME.pop(cls, None)
    registered_cls = type_spec._NAME_TO_TYPE_SPEC.pop(name, None)
    assert registered_name == name
    assert registered_cls is cls
1167
1168
if __name__ == '__main__':
  # When executed as a script, hand control to the googletest runner.
  googletest.main()
1171