1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Test utils for tensorflow."""
18import collections
19from collections import OrderedDict
20import contextlib
21import functools
22import gc
23import itertools
24import math
25import os
26import random
27import re
28import tempfile
29import threading
30import time
31import unittest
32
33from absl.testing import parameterized
34import numpy as np
35
36from google.protobuf import descriptor_pool
37from google.protobuf import text_format
38
39from tensorflow.core.config import flags
40from tensorflow.core.framework import graph_pb2
41from tensorflow.core.protobuf import rewriter_config_pb2
42from tensorflow.python import pywrap_sanitizers
43from tensorflow.python import tf2
44from tensorflow.python.client import device_lib
45from tensorflow.python.client import pywrap_tf_session
46from tensorflow.python.client import session
47from tensorflow.python.compat.compat import forward_compatibility_horizon
48from tensorflow.python.eager import backprop
49from tensorflow.python.eager import context
50from tensorflow.python.eager import def_function
51from tensorflow.python.eager import tape
52from tensorflow.python.framework import _test_metrics_util
53from tensorflow.python.framework import config
54from tensorflow.python.framework import device as pydev
55from tensorflow.python.framework import dtypes
56from tensorflow.python.framework import errors
57from tensorflow.python.framework import errors_impl
58from tensorflow.python.framework import gpu_util
59from tensorflow.python.framework import importer
60from tensorflow.python.framework import indexed_slices
61from tensorflow.python.framework import ops
62from tensorflow.python.framework import random_seed
63from tensorflow.python.framework import sparse_tensor
64from tensorflow.python.framework import tensor_shape
65from tensorflow.python.framework import tensor_util
66from tensorflow.python.framework import tfrt_utils
67from tensorflow.python.framework import versions
68from tensorflow.python.ops import array_ops
69from tensorflow.python.ops import control_flow_util
70from tensorflow.python.ops import control_flow_util_v2
71from tensorflow.python.ops import gradients_impl
72from tensorflow.python.ops import math_ops
73from tensorflow.python.ops import script_ops
74from tensorflow.python.ops import summary_ops_v2
75from tensorflow.python.ops import variables
76from tensorflow.python.ops.ragged import ragged_ops  # pylint: disable=unused-import
77from tensorflow.python.ops.ragged import ragged_tensor
78from tensorflow.python.ops.ragged import ragged_tensor_value
79from tensorflow.python.platform import _pywrap_stacktrace_handler
80from tensorflow.python.platform import googletest
81from tensorflow.python.platform import tf_logging as logging
82from tensorflow.python.training import server_lib
83from tensorflow.python.util import _pywrap_util_port
84from tensorflow.python.util import compat
85from tensorflow.python.util import deprecation
86from tensorflow.python.util import nest
87from tensorflow.python.util import tf_decorator
88from tensorflow.python.util import tf_inspect
89from tensorflow.python.util import traceback_utils
90from tensorflow.python.util.compat import collections_abc
91from tensorflow.python.util.protobuf import compare
92from tensorflow.python.util.tf_export import tf_export
93
94
# If the BUILD rule makes the import below available, this definition is
# replaced by one returning True, which causes TensorFlow graphs to be
# compiled with XLA.
def is_xla_enabled():
  """Returns True iff tests should compile TensorFlow graphs with XLA."""
  return False


try:
  from tensorflow.python.framework.is_xla_test_true import is_xla_enabled  # pylint: disable=g-import-not-at-top, unused-import
except Exception:  # pylint: disable=broad-except
  pass
106
107
# Uses the same override-by-BUILD-dep mechanism as is_xla_enabled to
# selectively enable/disable MLIR compilation.
def is_mlir_bridge_enabled():
  """Returns the MLIR bridge setting: None unless a BUILD dep overrides it."""
  return None


try:
  from tensorflow.python.framework.is_mlir_bridge_test_false import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
except ImportError:
  try:
    from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
  except ImportError:
    pass
121
122
def is_asan_enabled():
  """Check if ASAN is enabled; delegates to the pywrap_sanitizers binding."""
  return pywrap_sanitizers.is_asan_enabled()


def is_msan_enabled():
  """Check if MSAN is enabled; delegates to the pywrap_sanitizers binding."""
  return pywrap_sanitizers.is_msan_enabled()


def is_tsan_enabled():
  """Check if TSAN is enabled; delegates to the pywrap_sanitizers binding."""
  return pywrap_sanitizers.is_tsan_enabled()


def is_ubsan_enabled():
  """Check if UBSAN is enabled; delegates to the pywrap_sanitizers binding."""
  return pywrap_sanitizers.is_ubsan_enabled()
141
142
def _get_object_count_by_type(exclude=()):
  """Returns a Counter of live objects keyed by type name.

  Args:
    exclude: objects whose type counts are subtracted from the result.

  Returns:
    `collections.Counter` mapping type name to a (positive) count of live
    objects of that type, minus the counts contributed by `exclude`.
  """
  live_counts = collections.Counter(
      type(obj).__name__ for obj in gc.get_objects())
  excluded_counts = collections.Counter(
      type(obj).__name__ for obj in exclude)
  # Counter subtraction drops zero/negative entries automatically.
  return live_counts - excluded_counts
147
148
@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of a GPU device if available or an empty string.

  This method should only be used in tests written with `tf.test.TestCase`.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_gpu_support():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device(tf.test.gpu_device_name()):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  """
  # Pick the first locally visible device of type "GPU", if any.
  first_gpu = next(
      (d for d in device_lib.list_local_devices() if d.device_type == "GPU"),
      None)
  return compat.as_str(first_gpu.name) if first_gpu is not None else ""
169
170
def assert_ops_in_graph(expected_ops, graph):
  """Asserts that every expected operation is found in `graph`.

  Args:
    expected_ops: `dict<string, string>` of op name to op type.
    graph: Graph to check.

  Returns:
    `dict<string, node>` of node name to node, for the matched nodes.

  Raises:
    ValueError: If the expected ops are not present in the graph.
  """
  found_nodes = {}
  for node in graph.as_graph_def().node:
    expected_type = expected_ops.get(node.name)
    if expected_type is None:
      continue
    if expected_type != node.op:
      raise ValueError("Expected op for node %s is different. %s vs %s" %
                       (node.name, expected_type, node.op))
    found_nodes[node.name] = node
  if set(expected_ops.keys()) != set(found_nodes.keys()):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), found_nodes.keys()))
  return found_nodes
196
197
@tf_export("test.assert_equal_graph_def", v1=[])
def assert_equal_graph_def_v2(expected, actual):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent. This function
  ignores randomized attribute values that may appear in V2 checkpoints.

  Args:
    expected: The `GraphDef` we expected.
    actual: The `GraphDef` we have.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  # Note: this V2 API takes (expected, actual) while the underlying helper
  # takes (actual, expected) — the swap below is intentional.
  assert_equal_graph_def(actual, expected, checkpoint_v2=True,
                         hash_table_shared_name=True)
217
218
@tf_export(v1=["test.assert_equal_graph_def"])
def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False,
                              hash_table_shared_name=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
      values that appear in V2 checkpoints.
    hash_table_shared_name: boolean determining whether to ignore randomized
      shared_names that appear in HashTableV2 op defs.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  # Thin forwarding wrapper; all the work happens in assert_equal_graph_def.
  assert_equal_graph_def(actual, expected, checkpoint_v2,
                         hash_table_shared_name)
242
243
def assert_equal_graph_def(actual, expected, checkpoint_v2=False,
                           hash_table_shared_name=False):
  """Asserts that two `GraphDef`s are (mostly) the same via the C++ comparator.

  Note: when `checkpoint_v2` or `hash_table_shared_name` is set, both input
  protos are mutated in place (randomized attr values are stripped) before
  comparison.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to strip randomized attribute
      values that appear in V2 checkpoints before comparing.
    hash_table_shared_name: boolean determining whether to strip randomized
      shared_names that appear in HashTableV2 op defs before comparing.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  if checkpoint_v2:
    _strip_checkpoint_v2_randomized(actual)
    _strip_checkpoint_v2_randomized(expected)

  if hash_table_shared_name:
    _strip_hash_table_shared_name(actual)
    _strip_hash_table_shared_name(expected)

  # The C++ comparator returns a human-readable diff string, empty on match.
  diff = pywrap_tf_session.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))
265
266
def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  Note: mutates `a` and `b` in place — the `collection_def` and `graph_def`
  fields are cleared after they have been compared, so the final proto
  comparison does not re-check them.

  Args:
    tester: A test-case object providing `assertEqual` and
      `assertProtoEquals` (e.g. a `tf.test.TestCase` instance).
    a: A `MetaGraphDef` proto.
    b: A `MetaGraphDef` proto to compare against `a`.
  """
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # Bug fix: was `assertEquals`, a deprecated unittest alias that is
      # removed in Python 3.12; use the canonical `assertEqual`.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)
304
305
# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def):
  """Removes node attrs whose single string tensor value matches the
  randomized sharded-save temp pattern, mutating `graph_def` in place."""
  sharded_pattern = compat.as_bytes(_SHARDED_SAVE_OP_PATTERN)
  for node in graph_def.node:
    doomed_keys = []
    for key, attr_value in node.attr.items():
      tensor_value = attr_value.tensor
      if not tensor_value or len(tensor_value.string_val) != 1:
        continue
      string_value = tensor_value.string_val[0]
      if string_value and re.match(sharded_pattern, string_value):
        doomed_keys.append(key)
    # Delete after iteration so we never mutate the map mid-scan.
    for key in doomed_keys:
      del node.attr[key]
324
325
_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+"


def _strip_hash_table_shared_name(graph_def):
  """Clears randomized `shared_name` attrs from HashTableV2 nodes in place."""
  shared_name_pattern = compat.as_bytes(_TABLE_SHARED_NAME_PATTERN)
  for node in graph_def.node:
    if node.op != "HashTableV2" or "shared_name" not in node.attr:
      continue
    if re.match(shared_name_pattern, node.attr["shared_name"].s):
      del node.attr["shared_name"]
338
339
def IsGoogleCudaEnabled():
  """Returns whether Google CUDA is enabled; delegates to _pywrap_util_port."""
  return _pywrap_util_port.IsGoogleCudaEnabled()


def IsBuiltWithROCm():
  """Returns whether this build has ROCm; delegates to _pywrap_util_port."""
  return _pywrap_util_port.IsBuiltWithROCm()


def IsBuiltWithXLA():
  """Returns whether this build has XLA; delegates to _pywrap_util_port."""
  return _pywrap_util_port.IsBuiltWithXLA()


def IsBuiltWithNvcc():
  """Returns whether built with nvcc; delegates to _pywrap_util_port."""
  return _pywrap_util_port.IsBuiltWithNvcc()


def GpuSupportsHalfMatMulAndConv():
  """Returns whether the GPU supports half-precision matmul/conv; delegates
  to _pywrap_util_port."""
  return _pywrap_util_port.GpuSupportsHalfMatMulAndConv()


def IsMklEnabled():
  """Returns whether MKL/oneDNN is enabled.

  True when either the build reports MKL enabled or the TF_ENABLE_ONEDNN_OPTS
  environment variable is set to "true" or "1" (case-insensitive).
  """
  return (_pywrap_util_port.IsMklEnabled() or
          os.getenv("TF_ENABLE_ONEDNN_OPTS", "False").lower() in ["true", "1"])


def InstallStackTraceHandler():
  """Installs the native stacktrace handler via _pywrap_stacktrace_handler."""
  _pywrap_stacktrace_handler.InstallStacktraceHandler()
367
368
def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # Maps tensor rank -> permutation that moves the last (channel) axis to
  # position 1.
  channel_first_perms = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if isinstance(input_tensor, ops.Tensor):
    perm = channel_first_perms[input_tensor.shape.ndims]
    return array_ops.transpose(input_tensor, perm)
  perm = channel_first_perms[len(input_tensor)]
  return [input_tensor[axis] for axis in perm]
386
387
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
        divisible by 4.
  """
  perms_by_rank = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    vect_shape = input_shape_or_tensor.shape.as_list()
  else:
    vect_shape = input_shape_or_tensor
  if vect_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  # Split the channel dimension into (C // 4, 4), then permute.
  vect_shape[-1] //= 4
  vect_shape.append(4)
  perm = perms_by_rank[len(vect_shape)]
  if is_tensor:
    reshaped = array_ops.reshape(input_shape_or_tensor, vect_shape)
    return array_ops.transpose(reshaped, perm)
  return [vect_shape[axis] for axis in perm]
421
422
def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  perms_by_rank = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    vect_shape = input_shape_or_tensor.shape.as_list()
  else:
    vect_shape = input_shape_or_tensor
  if vect_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  perm = perms_by_rank[len(vect_shape)]
  # Drop the trailing vector axis from the permutation and fold its extent
  # back into the channel dimension.
  nhwc_shape = [vect_shape[axis] for axis in perm[:-1]]
  nhwc_shape[-1] *= vect_shape[-1]
  if is_tensor:
    transposed = array_ops.transpose(input_shape_or_tensor, perm)
    return array_ops.reshape(transposed, nhwc_shape)
  return nhwc_shape
453
454
def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array
  """
  # Maps tensor rank -> permutation that moves the channel axis (position 1)
  # to the end.
  channel_last_perms = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if isinstance(input_tensor, ops.Tensor):
    perm = channel_last_perms[input_tensor.shape.ndims]
    return array_ops.transpose(input_tensor, perm)
  perm = channel_last_perms[len(input_tensor)]
  return [input_tensor[axis] for axis in perm]
472
473
def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean.

  Returns:
    The wrapped function
  """

  def real_skip_if(fn):

    # functools.wraps preserves the wrapped test's __name__/__doc__ so test
    # discovery and reporting see the original identity (the bare wrapper
    # previously hid it).
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      skip = condition() if callable(condition) else condition
      if not skip:
        return fn(*args, **kwargs)
      # When skipped, fall through and implicitly return None.

    return wrapper

  return real_skip_if
498
499
@contextlib.contextmanager
def skip_if_error(test_obj, error_type, messages=None):
  """Context manager to skip cases not considered failures by the tests.

  Note that this does not work if used in setUpClass/tearDownClass.
  Usage in setUp/tearDown works fine just like regular test methods.

  Args:
    test_obj: A test object provided as `self` in the test methods; this object
      is usually an instance of `unittest.TestCase`'s subclass and should have
      `skipTest` method.
    error_type: The error type to skip. Note that if `messages` are given, both
      `error_type` and `messages` need to match for the test to be skipped.
    messages: Optional, a string or list of strings. If `None`, the test will be
      skipped if `error_type` matches what is raised; otherwise, the test is
      skipped if any of the `messages` is contained in the message of the error
      raised, and `error_type` matches the error raised.

  Yields:
    Nothing.
  """
  if messages:
    messages = nest.flatten(messages)
  try:
    yield
  except error_type as e:
    matches = not messages or any(m in str(e) for m in messages)
    if not matches:
      # Error type matched but no message matched: not skippable, re-raise.
      raise
    test_obj.skipTest("Skipping error: {}: {}".format(type(e), str(e)))
530
531
def enable_c_shapes(fn):
  """No-op. TODO(b/74620627): Remove this."""
  # Kept only so existing decorated tests keep working; returns fn unchanged.
  return fn


def with_c_shapes(cls):
  """No-op. TODO(b/74620627): Remove this."""
  # Kept only so existing decorated test classes keep working.
  return cls
540
541
def enable_control_flow_v2(fn):
  """Decorator for enabling CondV2 and WhileV2 on a test.

  Note this enables using CondV2 and WhileV2 after running the test class's
  setup/teardown methods.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    saved_flag = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
    try:
      return fn(*args, **kwargs)
    finally:
      # Restore the previous flag value even if the wrapped test raised.
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = saved_flag

  return wrapper
567
568
def with_control_flow_v2(cls):
  """Adds methods that call original methods with WhileV2 and CondV2 enabled.

  Note this enables CondV2 and WhileV2 in new methods after running the test
  class's setup method.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  If a test function has _disable_control_flow_v2 attr set to True (using the
  @disable_control_flow_v2 decorator), the v2 function is not generated for it.

  Example:

  @test_util.with_control_flow_v2
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    @test_util.disable_control_flow_v2("b/xyzabc")
    def testDisabledForV2(self):
      ...

  Generated class:
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    def testEnabledForV2WithControlFlowV2(self):
      // Enable V2 flags.
      testEnabledForV2(self)
      // Restore V2 flags.

    def testDisabledForV2(self):
      ...

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  # If v2 control flow is already globally enabled, duplicating every test
  # method would be redundant — return the class untouched.
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    return cls

  # Iterate over a copy: setattr below mutates cls.__dict__ while we scan it.
  for name, value in cls.__dict__.copy().items():
    if (callable(value) and
        name.startswith(unittest.TestLoader.testMethodPrefix) and
        not getattr(value, "_disable_control_flow_v2", False)):
      setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
  return cls
622
623
def disable_control_flow_v2(unused_msg):
  """Decorator for a function in a with_control_flow_v2 enabled test class.

  Blocks the function from being run with v2 control flow ops.

  Args:
    unused_msg: Reason for disabling.

  Returns:
    The wrapped function with _disable_control_flow_v2 attr set to True.
  """

  def mark_disabled(func):
    # with_control_flow_v2 checks this attribute before generating the
    # *WithControlFlowV2 variant of a test method.
    setattr(func, "_disable_control_flow_v2", True)
    return func

  return mark_disabled
641
642
def enable_output_all_intermediates(fn):
  """Force-enable outputting all intermediates from functional control flow ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    saved_override = (
        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
    control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
    try:
      return fn(*args, **kwargs)
    finally:
      # Restore the previous override even if the wrapped function raised.
      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = (
          saved_override)

  return wrapper
664
665
def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
  """Decorator for asserting that no new Python objects persist after a test.

  Runs the test multiple times executing eagerly, first as a warmup and then to
  let objects accumulate. The warmup helps ignore caches which do not grow as
  the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.

  Usable both bare (`@assert_no_new_pyobjects_executing_eagerly`) and with
  arguments (`@assert_no_new_pyobjects_executing_eagerly(warmup_iters=3)`);
  the `func is None` check at the bottom distinguishes the two forms.

  Args:
    func: The function to test.
    warmup_iters: The number of warmup iterations, excluded from measuring.

  Returns:
    The wrapped function performing the test.
  """

  def wrap_f(f):
    def decorator(self, *args, **kwargs):
      """Warms up, gets object counts, runs the test, checks for new objects."""
      with context.eager_mode():
        # GC is disabled for the whole measurement so collection timing cannot
        # perturb the object counts; re-enabled at the end.
        gc.disable()
        # Run the test 2 times as warmup, in an attempt to fill up caches, which
        # should not grow as the test is run repeatedly below.
        #
        # TODO(b/117156879): Running warmup twice is black magic; we have seen
        # tests that fail with 1 warmup run, and pass with 2, on various
        # versions of python2.7.x.
        for _ in range(warmup_iters):
          f(self, *args, **kwargs)
        # Since we aren't in the normal test lifecycle, we need to manually run
        # cleanups to clear out their object references.
        self.doCleanups()

        # Some objects are newly created by _get_object_count_by_type().  So
        # create and save as a dummy variable to include it as a baseline.
        obj_count_by_type = _get_object_count_by_type()
        gc.collect()

        # Make sure any registered functions are cleaned up in the C++ runtime.
        registered_function_names = context.context().list_function_names()

        # unittest.doCleanups adds to self._outcome with each unwound call.
        # These objects are retained across gc collections so we exclude them
        # from the object count calculation.
        obj_count_by_type = _get_object_count_by_type(
            exclude=gc.get_referents(self._outcome.errors,
                                     self._outcome.skipped))

        if ops.has_default_graph():
          collection_sizes_before = {
              collection: len(ops.get_collection(collection))
              for collection in ops.get_default_graph().collections
          }
        for _ in range(3):
          f(self, *args, **kwargs)
        # Since we aren't in the normal test lifecycle, we need to manually run
        # cleanups to clear out their object references.
        self.doCleanups()
        # Note that gc.get_objects misses anything that isn't subject to garbage
        # collection (C types). Collections are a common source of leaks, so we
        # test for collection sizes explicitly.
        if ops.has_default_graph():
          for collection_key in ops.get_default_graph().collections:
            collection = ops.get_collection(collection_key)
            size_before = collection_sizes_before.get(collection_key, 0)
            if len(collection) > size_before:
              raise AssertionError(
                  ("Collection %s increased in size from "
                   "%d to %d (current items %s).") %
                  (collection_key, size_before, len(collection), collection))
            # Make sure our collection checks don't show up as leaked memory by
            # removing references to temporary variables.
            del collection
            del collection_key
            del size_before
          del collection_sizes_before
        gc.collect()

        # There should be no new Python objects hanging around.
        obj_count_by_type = (
            _get_object_count_by_type(
                exclude=gc.get_referents(self._outcome.errors,
                                         self._outcome.skipped)) -
            obj_count_by_type)

        # There should be no newly registered functions hanging around.
        leftover_functions = (
            context.context().list_function_names() - registered_function_names)
        assert not leftover_functions, (
            "The following functions were newly created: %s" %
            leftover_functions)

        # In some cases (specifically on MacOS), new_count is somehow
        # smaller than previous_count.
        # Using plain assert because not all classes using this decorator
        # have assertLessEqual
        assert not obj_count_by_type, (
            "The following objects were newly created: %s" %
            str(obj_count_by_type))
        gc.enable()
    return decorator

  if func is None:
    return wrap_f
  else:
    return wrap_f(func)
774
775
def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Clears the caches that it knows about, runs the garbage collector, then checks
  that there are no Tensor or Tensor-like objects still around. This includes
  Tensors to which something still has a reference (e.g. from missing
  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
  of the objects has __del__ defined).

  Args:
    f: The test case to run.

  Returns:
    The decorated test case.
  """

  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensorflow_object(obj):
      try:
        return isinstance(obj,
                          (ops.Tensor, variables.Variable,
                           tensor_shape.Dimension, tensor_shape.TensorShape))
      except (ReferenceError, AttributeError):
        # If the object no longer exists, we don't care about it.
        return False

    # Snapshot the ids of pre-existing TF objects so only objects created
    # during the test are reported afterwards.
    tensors_before = set(
        id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
    outside_executed_eagerly = context.executing_eagerly()
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    outside_graph_key = ops.get_default_graph()._graph_key
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = outside_graph_key
      if outside_executed_eagerly:
        with context.eager_mode():
          result = f(self, **kwargs)
      else:
        result = f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    tensors_after = [
        obj for obj in gc.get_objects()
        if _is_tensorflow_object(obj) and id(obj) not in tensors_before
    ]
    if tensors_after:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(tensors_after),
          str(tensors_after),
      )))
    return result

  return decorator
836
837
def _find_reference_cycle(objects, idx):
  """Searches for a reference cycle reachable from `objects[idx]` and logs it.

  Builds a referrer graph rooted at `objects[idx]` using `gc.get_referrers`,
  then depth-first searches the graph for a cycle. When a cycle is found, a
  human-readable sample of it is written via `logging.error`.

  Args:
    objects: A list of candidate objects (typically `gc.garbage`).
    idx: Index into `objects` of the object to start the search from.

  Returns:
    True if a reference cycle was found (and logged), False otherwise.
  """

  def get_ignore_reason(obj, denylist):
    """Tests whether an object should be omitted from the dependency graph."""
    # `denylist` gains one entry per build_ref_graph recursion level, so its
    # length doubles as a recursion-depth bound.
    if len(denylist) > 100:
      return "<depth limit>"
    if tf_inspect.isframe(obj):
      if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
        return "<test code>"
    for b in denylist:
      if b is obj:
        return "<test code>"
    if obj is denylist:
      return "<test code>"
    return None

  # Note: this function is meant to help with diagnostics. Its output is purely
  # a human-readable representation, so you may freely modify it to suit your
  # needs.
  def describe(obj, denylist, leaves_only=False):
    """Returns a custom human-readable summary of obj.

    Args:
      obj: the value to describe.
      denylist: same as denylist in get_ignore_reason.
      leaves_only: boolean flag used when calling describe recursively. Useful
        for summarizing collections.
    """
    if get_ignore_reason(obj, denylist):
      return "{}{}".format(get_ignore_reason(obj, denylist), type(obj))
    if tf_inspect.isframe(obj):
      return "frame: {}".format(tf_inspect.getframeinfo(obj))
    elif tf_inspect.ismodule(obj):
      return "module: {}".format(obj.__name__)
    else:
      if leaves_only:
        return "{}, {}".format(type(obj), id(obj))
      elif isinstance(obj, list):
        return "list({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, tuple):
        return "tuple({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, dict):
        return "dict({}): {} keys".format(id(obj), len(obj.keys()))
      elif tf_inspect.isfunction(obj):
        return "function({}) {}; globals ID: {}".format(
            id(obj), obj.__name__, id(obj.__globals__))
      else:
        return "{}, {}".format(type(obj), id(obj))

  def build_ref_graph(obj, graph, reprs, denylist):
    """Builds a reference graph as <referrer> -> <list of referents>.

    Args:
      obj: The object to start from. The graph will be built by recursively
        adding its referrers.
      graph: Dict holding the graph to be built. To avoid creating extra
        references, the graph holds object IDs rather than actual objects.
      reprs: Auxiliary structure that maps object IDs to their human-readable
        description.
      denylist: List of objects to ignore.
    """
    referrers = gc.get_referrers(obj)
    # The referrers list itself must not be mistaken for a genuine referrer
    # when recursing, so add it to the denylist.
    denylist = denylist + (referrers,)

    obj_id = id(obj)
    for r in referrers:
      if get_ignore_reason(r, denylist) is None:
        r_id = id(r)
        if r_id not in graph:
          graph[r_id] = []
        if obj_id not in graph[r_id]:
          # Only recurse on first discovery of this edge, so each edge is
          # expanded at most once.
          graph[r_id].append(obj_id)
          build_ref_graph(r, graph, reprs, denylist)
          reprs[r_id] = describe(r, denylist)

  def find_cycle(el, graph, reprs, path):
    """Finds and prints a single cycle in the dependency graph."""
    if el not in graph:
      return
    for r in graph[el]:
      if r in path:
        logging.error("Reference cycle sample:")
        for p in path + (r,):
          logging.error(reprs.get(p, "unknown object " + str(p)))
        return True
      else:
        if find_cycle(r, graph, reprs, path + (r,)):
          return True
    return False

  obj = objects[idx]
  graph = {}  # referrer ID -> object ID
  reprs = {}  # object ID -> description
  # Seed the denylist with this function's own data structures and helpers so
  # the graph-building machinery does not appear in its own output.
  build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason,
                                      describe, build_ref_graph, find_cycle))
  for k in graph:
    if find_cycle(k, graph, reprs, ()):
      return True
  return False
939
940
def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  Args:
    f: The function to decorate.

  Returns:
    The decorated function.
  """

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    # Force-load `distribution_strategy_context` to prevent GC at
    # test time when using eager. Remove once b/117329403 is resolved.
    tape.distribution_strategy_context.get_strategy()

    gc.disable()
    previous_debug_flags = gc.get_debug()
    # DEBUG_SAVEALL keeps otherwise-collectable objects in gc.garbage instead
    # of freeing them, so the test can inspect what would have been collected.
    gc.set_debug(gc.DEBUG_SAVEALL)
    gc.collect()
    previous_garbage = len(gc.garbage)
    result = f(self, **kwargs)
    gc.collect()
    new_garbage = len(gc.garbage)
    if new_garbage > previous_garbage:

      for i, obj in enumerate(gc.garbage[previous_garbage:]):
        # Known false positive for ast.fix_missing_locations.
        # NOTE(review): presumably each such ast object accounts for three
        # garbage entries — confirm before changing the constant.
        if getattr(obj, "__module__", "") == "ast":
          new_garbage -= 3

    if new_garbage > previous_garbage:
      logging.error(
          "The decorated test created work for Python's garbage collector, "
          "likely due to a reference cycle. New objects in cycle(s):")
      for i, obj in enumerate(gc.garbage[previous_garbage:]):
        try:
          logging.error("Object %d of %d", i,
                        len(gc.garbage) - previous_garbage)

          # Avoids calling the object's own __str__/__repr__, which may
          # themselves raise while the object is in a partially-dead state.
          def _safe_object_str(obj):
            return "<%s %d>" % (obj.__class__.__name__, id(obj))

          logging.error("  Object type: %s", _safe_object_str(obj))
          logging.error(
              "  Referrer types: %s", ", ".join(
                  [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
          logging.error(
              "  Referent types: %s", ", ".join(
                  [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
          logging.error("  Object attribute names: %s", dir(obj))
          logging.error("  Object __str__:")
          logging.error(obj)
          logging.error("  Object __repr__:")
          logging.error(repr(obj))
        except Exception:  # pylint: disable=broad-except
          logging.error("(Exception while printing object)")

    # When garbage is created, this call can help identify reference cycles,
    # which are typically the cause of such garbage.
    if new_garbage > previous_garbage:
      for i in range(previous_garbage, new_garbage):
        if _find_reference_cycle(gc.garbage, i):
          break

    # This will fail if any garbage has been created, typically because of a
    # reference cycle.
    self.assertEqual(previous_garbage, new_garbage)
    # TODO(allenl): Figure out why this debug flag reset doesn't work. It would
    # be nice to be able to decorate arbitrary tests in a large test suite and
    # not hold on to every object in other tests.
    gc.set_debug(previous_debug_flags)
    gc.enable()
    return result

  return decorator
1021
1022
def _combine_named_parameters(**kwargs):
  """Generate combinations based on its keyword arguments.

  Two sets of returned combinations can be concatenated using +.  Their product
  can be computed using `times()`.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of dictionaries for each combination. Keys in the dictionaries are
    the keyword argument names.  Each key has one value - one of the
    corresponding keyword argument values.
  """
  # One sub-list of (name, option) pairs per keyword, ordered by keyword name
  # so the generated combinations are deterministic.
  per_key_pairs = [
      [(name, option)
       for option in (options if isinstance(options, list) else [options])]
      for name, options in sorted(kwargs.items())
  ]
  return [OrderedDict(chosen) for chosen in itertools.product(*per_key_pairs)]
1046
1047
def generate_combinations_with_testcase_name(**kwargs):
  """Generate combinations based on its keyword arguments using combine().

  This function calls combine() and appends a testcase name to the list of
  dictionaries returned. The 'testcase_name' key is a required for named
  parameterized tests.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of dictionaries for each combination. Keys in the dictionaries are
    the keyword argument names.  Each key has one value - one of the
    corresponding keyword argument values.
  """

  def _alnum_only(text):
    # Strip every non-alphanumeric character so the name is identifier-safe.
    return "".join(ch for ch in text if ch.isalnum())

  named_combinations = []
  for combination in _combine_named_parameters(**kwargs):
    assert isinstance(combination, OrderedDict)
    suffix = "".join(
        "_{}_{}".format(_alnum_only(key), _alnum_only(str(value)))
        for key, value in combination.items())
    named = OrderedDict(combination)
    named["testcase_name"] = "_test{}".format(suffix)
    named_combinations.append(named)

  return named_combinations
1079
1080
def run_all_in_graph_and_eager_modes(cls):
  """Execute all test methods in the given class with and without eager."""
  test_prefix = unittest.TestLoader.testMethodPrefix
  for name in dir(cls):
    if not name.startswith(test_prefix):
      continue
    # Explicitly opted-out tests keep their original single-mode behavior.
    if (name.startswith("testSkipEager") or
        name.startswith("test_skip_eager") or name == "test_session"):
      continue
    attr = getattr(cls, name, None)
    if callable(attr):
      setattr(cls, name, run_in_graph_and_eager_modes(attr))
  return cls
1094
1095
def enable_nested_function_shape_inference(fn):
  """Decorator for enabling nested_function_shape_inference on a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will set nested_function_shape_inference,
  reset the context, execute the test, then reset the context to the state
  it was in prior to this test.

  Example:

  class MyTest(test.TestCase):

    @enable_nested_function_shape_inference
    def testFoo(self):
      ...

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  # functools.wraps preserves `fn`'s name/docstring on the wrapper so test
  # reporting and introspection see the original test method (consistent with
  # `set_xla_env_flag` in this file).
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # If `nested_function_shape_inference` is already enabled do nothing.
    if flags.config().enable_nested_function_shape_inference.value():
      return fn(*args, **kwargs)

    flags.config().enable_nested_function_shape_inference.reset(True)
    try:
      return fn(*args, **kwargs)
    finally:
      # Restore the default so the flag change does not leak into other tests.
      flags.config().enable_nested_function_shape_inference.reset(False)

  return wrapper
1131
1132
def enable_eager_op_as_function(fn):
  """Returns the same fn. This will be removed once all usages are removed.

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function.
  """
  # The docstring promises "the same fn"; the previous pass-through wrapper
  # added nothing but discarded `fn`'s metadata (__name__, __doc__), which
  # breaks test-name reporting. Return `fn` unchanged instead.
  return fn
1147
1148
@tf_export("test.with_eager_op_as_function")
def with_eager_op_as_function(cls=None, only_as_function=False):  # pylint: disable=unused-argument
  """Returns the same class. This will be removed once all usages are removed.

  Args:
    cls: class to decorate.
    only_as_function: unused argument.

  Returns:
    cls
  """

  # Identity decorator: kept only so existing call sites continue to work.
  def decorator(decorated_cls):
    return decorated_cls

  return decorator if cls is None else decorator(cls)
1168
1169
def enable_graph_building_optimization(fn):
  """Decorator for enabling graph_building_optimization on a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will enable graph_building_optimization,
  execute the test, then reset the feature flag to its default value.

  Example:

  class MyTest(test.TestCase):

    @enable_graph_building_optimization
    def testFoo(self):
      ...

  Args:
    fn: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  # functools.wraps preserves `fn`'s name/docstring on the wrapper so test
  # reporting and introspection see the original test method (consistent with
  # `set_xla_env_flag` in this file).
  @functools.wraps(fn)
  def wrapper(*args, **kwargs):
    # If `graph_building_optimization` is already enabled do nothing.
    if flags.config().graph_building_optimization.value():
      return fn(*args, **kwargs)

    flags.config().graph_building_optimization.reset(True)
    try:
      return fn(*args, **kwargs)
    finally:
      # Restore the default so the flag change does not leak into other tests.
      flags.config().graph_building_optimization.reset(False)

  return wrapper
1204
1205
def add_graph_building_optimization_tests(cls=None):
  """Adds methods with graph_building_optimization enabled to the test suite.

  Example:

  @test_util.add_graph_building_optimization_tests
  class FooTest(test.TestCase):

    def testBar(self):
      ...

  Generated class:
  class FooTest(test.TestCase):

    def testBar(self):
      ...

    def testBarWithGraphBuildingOptimization(self):
      // Enable graph_building_optimization
      testBar(self)
      // Disable graph_building_optimization

  Args:
    cls: class to decorate.

  Returns:
    cls with new test methods added.
  """

  def decorator(target):
    # If the flag is already on globally, the extra variants would be
    # redundant duplicates of the plain tests.
    if flags.config().graph_building_optimization.value():
      return target

    def _is_test_or_benchmark(method_name):
      return (method_name.startswith(unittest.TestLoader.testMethodPrefix) or
              method_name.startswith("benchmark"))

    # Snapshot the members first: setattr below mutates target.__dict__.
    for name, value in list(target.__dict__.items()):
      if callable(value) and _is_test_or_benchmark(name):
        setattr(target, name + "WithGraphBuildingOptimization",
                enable_graph_building_optimization(value))
    return target

  return decorator if cls is None else decorator(cls)
1251
1252
def disable_eager_op_as_function(unused_msg):
  """Decorator for a function in a with_eager_op_as_function enabled test class.

  Blocks the function from being run with eager_op_as_function enabled.

  Args:
    unused_msg: Reason for disabling.

  Returns:
    The wrapped function with _disable_eager_op_as_function attr set to True.
  """
  # `unused_msg` documents the reason at the call site only; it is not
  # forwarded to the returned decorator.
  return _disable_test(execute_func=False)
1265
1266
def set_xla_env_flag(func=None, flag=""):
  """Decorator for setting XLA_FLAGS prior to running a test.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will allow users to set any xla flags
  exposed via the XLA_FLAGS environment variable, execute the test, then reset
  the XLA_FLAGS to the state it was in prior to this test.

  Example:

  class MyTest(test.TestCase):

    @set_xla_env_flag(flag='--xla_gpu_enable_fast_min_max=false')
    def testFoo(self):
      ...

  Args:
    func: The function to be wrapped.
    flag: The xla flag to be set in the XLA_FLAGS env variable.

  Returns:
    The wrapped function.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(*args, **kwargs):
      saved_flags = os.environ.get("XLA_FLAGS")
      # Prepend the new flag; keep any pre-existing (non-empty) flags after it.
      os.environ["XLA_FLAGS"] = (
          flag + " " + saved_flags if saved_flags else flag)
      try:
        return f(*args, **kwargs)
      finally:
        # Restore the exact prior state, including "unset".
        if saved_flags is None:
          del os.environ["XLA_FLAGS"]
        else:
          os.environ["XLA_FLAGS"] = saved_flags

    return decorated

  return decorator if func is None else decorator(func)
1314
1315
def build_as_function_and_v1_graph(func=None):
  """Run a test case in v1 graph mode and inside tf.function in eager mode.

  WARNING: This decorator can only be used in test cases that statically checks
  generated graph. Attempting to evaluate graph or function results via.
  session.run() or self.evaluate() will fail.

  WARNING: This decorator can only be used for test cases that inherit from
  absl.testing.parameterized.TestCase.

  Args:
    func: Test case function to be decorated.

  Returns:
    Decorated test case function.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      # Fixed: the message previously named a nonexistent decorator
      # (`run_in_graph_mode_and_function`), sending users to the wrong symbol.
      raise ValueError(
          "`build_as_function_and_v1_graph` only supports test methods.")

    @parameterized.named_parameters(("_v1_graph", "v1_graph"),
                                    ("_function", "function"))
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **kwargs):
      if run_mode == "v1_graph":
        with ops.Graph().as_default():
          f(self, *args, **kwargs)
      elif run_mode == "function":

        @def_function.function
        def function_in_eager():
          f(self, *args, **kwargs)

        # Create a new graph for the eagerly executed version of this test for
        # better isolation.
        graph_for_eager_test = ops.Graph()
        with graph_for_eager_test.as_default(), context.eager_mode():
          function_in_eager()
        ops.dismantle_graph(graph_for_eager_test)
      else:
        raise ValueError("Unknown run mode %s" % run_mode)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1366
1367
def run_in_async_and_sync_mode(f):
  """Execute the test in async mode and sync mode."""

  @parameterized.named_parameters([("Async", True), ("", False)])
  @functools.wraps(f)
  def decorator(self, async_mode, *args, **kwargs):
    # Select the execution mode once, then run the test body under it.
    mode = context.ASYNC if async_mode else context.SYNC
    with context.execution_mode(mode):
      f(self, *args, **kwargs)

  return decorator
1381
1382
def run_in_graph_and_eager_modes(func=None,
                                 config=None,
                                 use_gpu=True,
                                 assert_no_eager_garbage=False):
  """Execute the decorated test with and without enabling eager execution.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will cause the contents of the test
  method to be executed twice - once normally, and once with eager execution
  enabled. This allows unittests to confirm the equivalence between eager
  and graph execution (see `tf.compat.v1.enable_eager_execution`).

  For example, consider the following unittest:

  ```python
  class MyTests(tf.test.TestCase):

    @run_in_graph_and_eager_modes
    def test_foo(self):
      x = tf.constant([1, 2])
      y = tf.constant([3, 4])
      z = tf.add(x, y)
      self.assertAllEqual([4, 6], self.evaluate(z))

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test validates that `tf.add()` has the same behavior when computed with
  eager execution enabled as it does when constructing a TensorFlow graph and
  executing the `z` tensor in a session.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.


  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.
    config: An optional config_pb2.ConfigProto to use to configure the session
      when executing graphs.
    use_gpu: If True, attempt to run as many operations as possible on GPU.
    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
      collector and asserts that no extra garbage has been created when running
      the test with eager execution enabled. This will fail if there are
      reference cycles (e.g. a = []; a.append(a)). Off by default because some
      tests may create garbage for legitimate reasons (e.g. they define a class
      which inherits from `object`), and because DEBUG_SAVEALL is sticky in some
      Python interpreters (meaning that tests which rely on objects being
      collected elsewhere in the unit test file will not work). Additionally,
      checks that nothing still has a reference to Tensors that the test
      allocated.

  Returns:
    Returns a decorator that will run the decorated test method twice:
    once by constructing and executing a graph in a session and once with
    eager execution enabled.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`run_in_graph_and_eager_modes` only supports test methods. "
          "Did you mean to use `run_all_in_graph_and_eager_modes`?")

    def decorated(self, *args, **kwargs):
      # First pass: run the test body in graph mode inside a test session.
      logging.info("Running %s in GRAPH mode.", f.__name__)
      try:
        with context.graph_mode():
          with self.test_session(use_gpu=use_gpu, config=config):
            f(self, *args, **kwargs)
      except unittest.case.SkipTest:
        # A graph-mode skip must not prevent the eager-mode run below.
        pass

      def run_eagerly(self, **kwargs):
        logging.info("Running %s in EAGER mode.", f.__name__)
        if not use_gpu:
          with ops.device("/device:CPU:0"):
            f(self, *args, **kwargs)
        else:
          f(self, *args, **kwargs)

      if assert_no_eager_garbage:
        ops.reset_default_graph()
        # Wrap the eager run with the leak checkers; order matters — tensors
        # are checked outside the garbage check.
        run_eagerly = assert_no_new_tensors(
            assert_no_garbage_created(run_eagerly))

      # This decorator runs the wrapped test twice.
      # Reset the test environment between runs.
      self.tearDown()
      self._tempdir = None
      # Create a new graph for the eagerly executed version of this test for
      # better isolation.
      graph_for_eager_test = ops.Graph()
      with graph_for_eager_test.as_default(), context.eager_mode():
        self.setUp()
        run_eagerly(self, **kwargs)
      ops.dismantle_graph(graph_for_eager_test)

    return tf_decorator.make_decorator(f, decorated)

  if func is not None:
    return decorator(func)

  return decorator
1490
1491
def py_func_if_in_function(f):
  """Calls `f` directly, or through `script_ops.py_func` inside a function.

  Outside of function tracing, this is a plain pass-through. When called
  while building a function (`ops.inside_function()`), the tensor/variable
  positional arguments are routed through `script_ops.py_func` so that `f`
  runs as a py_func.
  """

  def decorated(*args, **kwds):
    if not ops.inside_function():
      return f(*args, **kwds)

    # Collect the positional args that must be fed through py_func, keeping
    # track of where they came from.
    indexed_tensors = [(pos, arg) for pos, arg in enumerate(args)
                       if isinstance(arg, (ops.Tensor, variables.Variable))]
    tensor_indices = [pos for pos, _ in indexed_tensors]
    tensor_args = [arg for _, arg in indexed_tensors]

    def inner_f(*inner_tensor_args):
      # Re-insert the py_func-provided values at their original positions.
      replaced_args = list(args)
      for pos, tensor in zip(tensor_indices, inner_tensor_args):
        replaced_args[pos] = tensor
      return f(*replaced_args, **kwds)

    return script_ops.py_func(inner_f, tensor_args, [])

  return tf_decorator.make_decorator(f, decorated)
1514
1515
def also_run_as_tf_function(f):
  """Runs the decorated test twice--once as is, once inside a tf.function.

  This allows you to run a test both in eager execution and inside a
  tf.function, exercising the two execution modes supported in tf 2.0. The test
  assertions are automatically done inside tf.py_funcs, and tf.function ensures
  that they run in the proper order and with the proper side effects.

  Currently variable creation is not supported in tests annotated with this
  decorator since it's tricky to ensure the variable doesn't get repeatedly
  created when retracing the tf.function.

  Args:
    f: the test method to be decorated

  Returns:
    The decorated test method, which will run both in eager and inside a
    tf.function.
  """

  def decorated(*args, **kwds):

    def call_original():
      f(*args, **kwds)

    with context.eager_mode():
      # First pass: plain eager execution.
      call_original()
      # Second pass: the same body traced as a TF function.
      # TODO(b/121143941): Remove the autograph override.
      def_function.function(call_original, autograph=False)()

  return decorated
1549
1550
def deprecated_graph_mode_only(func=None):
  """Execute the decorated test in graph mode.

  This function returns a decorator intended to be applied to tests that are not
  compatible with eager mode. When this decorator is applied, the test body will
  be run in an environment where API calls construct graphs instead of executing
  eagerly.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will run the decorated test method in graph mode.
  """

  def decorator(f):
    if not tf_inspect.isclass(f):
      # Plain test method: force graph mode around the call when needed.
      def decorated(self, *args, **kwargs):
        if not context.executing_eagerly():
          return f(self, *args, **kwargs)
        with context.graph_mode():
          return f(self, *args, **kwargs)

      return decorated

    # Class: decorate setUp (when overridden) and every test method in place.
    setup = f.__dict__.get("setUp")
    if setup is not None:
      setattr(f, "setUp", decorator(setup))

    for name, value in f.__dict__.copy().items():
      if (callable(value) and
          name.startswith(unittest.TestLoader.testMethodPrefix)):
        setattr(f, name, decorator(value))

    return f

  return decorator if func is None else decorator(func)
1598
1599
# Backward-compatible alias; prefer `deprecated_graph_mode_only` in new code.
run_deprecated_v1 = deprecated_graph_mode_only
1601
1602
def run_all_in_deprecated_graph_mode_only(cls):
  """Execute all tests in a class in graph mode."""
  test_prefix = unittest.TestLoader.testMethodPrefix
  for name in dir(cls):
    if name == "test_session" or not name.startswith(test_prefix):
      continue
    candidate = getattr(cls, name, None)
    if callable(candidate):
      setattr(cls, name, deprecated_graph_mode_only(candidate))
  return cls
1614
1615
def run_v1_only(reason, func=None):
  """Execute the decorated test only if running in v1 mode.

  This function is intended to be applied to tests that exercise v1 only
  functionality. If the test is run in v2 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    reason: string giving a reason for limiting the test to v1 only.
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """
  if not isinstance(reason, str):
    raise ValueError("'reason' should be string, got {}".format(type(reason)))

  def decorator(f):
    if not tf_inspect.isclass(f):
      # Plain test method: skip it whenever v2 behavior is enabled.
      def decorated(self, *args, **kwargs):
        if tf2.enabled():
          self.skipTest(reason)

        return f(self, *args, **kwargs)

      return decorated

    # To skip an entire test suite class, we only decorate the setUp method
    # to skip all tests. setUp may not be overridden on `f` itself (so it is
    # absent from f.__dict__); walk the method resolution order and decorate
    # the first setUp found — usually the one on the parent TestCase class.
    for klass in type.mro(f):
      setup = klass.__dict__.get("setUp")
      if setup is not None:
        setattr(f, "setUp", decorator(setup))
        break

    return f

  return decorator if func is None else decorator(func)
1667
1668
def run_v2_only(func=None):
  """Execute the decorated test only if running in v2 mode.

  This function is intended to be applied to tests that exercise v2 only
  functionality. If the test is run in v1 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_v2_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      if tf2.enabled():
        return f(self, *args, **kwargs)
      # skipTest raises, so the test body below is never reached in v1.
      self.skipTest("Test is only compatible with v2")

    return decorated

  return decorator if func is None else decorator(func)
1704
1705
def run_gpu_only(func=None):
  """Execute the decorated test only if a GPU is available.

  This function is intended to be applied to tests that require the presence
  of a GPU. If a GPU is absent, it will simply be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      if is_gpu_available():
        return f(self, *args, **kwargs)
      # skipTest raises, so the test body is never reached without a GPU.
      self.skipTest("Test requires GPU")

    return decorated

  return decorator if func is None else decorator(func)
1737
1738
def run_cuda_only(func=None):
  """Execute the decorated test only if a GPU is available.

  This function is intended to be applied to tests that require the presence
  of a CUDA GPU. If a CUDA GPU is absent, it will simply be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_cuda_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      if is_gpu_available(cuda_only=True):
        return f(self, *args, **kwargs)
      # skipTest raises, so the test body is never reached without CUDA.
      self.skipTest("Test requires CUDA GPU")

    return decorated

  return decorator if func is None else decorator(func)
1770
1771
def run_gpu_or_tpu(func=None):
  """Execute the decorated test only if a physical GPU or TPU is available.

  This function is intended to be applied to tests that require the presence
  of a physical GPU or TPU. It complies with the following rules:
  - If a GPU is available, the test will run on the GPU.
  - If a GPU is absent and a TPU is available, the test will run on the TPU.
  - If both GPU and TPU are absent, the test will be skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_or_tpu` only supports test methods.")

    def decorated(self, *args, **kwargs):
      # GPU takes precedence over TPU when both are present.
      for device_type in ("GPU", "TPU"):
        if config.list_physical_devices(device_type):
          return f(self, device_type, *args, **kwargs)
      self.skipTest("Test requires GPU or TPU")

    return decorated

  if func is not None:
    return decorator(func)
  return decorator
1806
1807
def with_forward_compatibility_horizons(*horizons):
  """Executes the decorated test with the specified forward-compat horizons.

  Args:
    *horizons: A list of (year, month, day) tuples.  If the list includes
      `None`, then the test will also be run with no forward-compatibility
      horizon set.

  Returns:
    A decorator that will execute the test with the specified horizons.
  """
  if not horizons:
    raise ValueError("Expected at least one horizon.")

  def _is_valid_horizon(h):
    # `None` means "no horizon"; otherwise require exactly three ints.
    return h is None or (len(h) == 3 and all(isinstance(v, int) for v in h))

  for horizon in horizons:
    if not _is_valid_horizon(horizon):
      raise ValueError("Bad horizon value: %r" % horizon)

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`with_forward_compatibility_horizons` only "
                       "supports test methods.")

    def decorated(self, *args, **kwargs):
      # Run the test once per horizon, under that horizon's compat window.
      for horizon in horizons:
        if horizon is None:
          f(self, *args, **kwargs)
        else:
          year, month, day = horizon
          with forward_compatibility_horizon(year, month, day):
            f(self, *args, **kwargs)

    return decorated

  return decorator
1841
1842
@deprecation.deprecated(None,
                        "Use `tf.config.list_physical_devices('GPU')` instead.")
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Warning: if a non-GPU version of the package is installed, the function would
  also return False. Use `tf.test.is_built_with_cuda` to validate if TensorFlow
  was built with CUDA support.

  For example,
  >>> gpu_available = tf.test.is_gpu_available()
  >>> is_cuda_gpu_available = tf.test.is_gpu_available(cuda_only=True)
  >>> is_cuda_gpu_min_3 = tf.test.is_gpu_available(True, (3,0))

  Args:
    cuda_only: limit the search to CUDA GPUs.
    min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
      CUDA compute capability required, or None if no requirement.

  Note that the keyword arg name "cuda_only" is misleading (since this routine
  will return true when a GPU device is available irrespective of whether TF
  was built with CUDA support or ROCm support). It is kept as-is because

  ++ Changing the name "cuda_only" to something more generic would break
     backward compatibility

  ++ Adding an equivalent "rocm_only" would require the implementation check
     the build type. This in turn would require doing the same for CUDA and
     thus potentially break backward compatibility

  ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility,
     but would require most (if not all) callers to update the call to use
     "cuda_or_rocm_only" instead of "cuda_only"

  Returns:
    True if a GPU device of the requested kind is available.
  """

  # This was needed earlier when we had support for SYCL in TensorFlow; kept
  # only for signature compatibility.
  del cuda_only

  try:
    for local_device in device_lib.list_local_devices():
      if local_device.device_type != "GPU":
        continue
      gpu_info = gpu_util.compute_capability_from_device_desc(local_device)
      compute_capability = gpu_info.compute_capability or (0, 0)
      if (not min_cuda_compute_capability or
          compute_capability >= min_cuda_compute_capability):
        return True
    return False
  except errors_impl.NotFoundError as e:
    message = str(e)
    if "CUDA" in message and "not find" in message:
      # A "could not find CUDA" error just means no GPU support is present.
      logging.error(message)
      return False
    raise e
1899
1900
@contextlib.contextmanager
def device(use_gpu):
  """Uses gpu when requested and available."""
  # Fall back to CPU whenever a GPU was not requested or is not present.
  dev = "/device:GPU:0" if use_gpu and is_gpu_available() else "/device:CPU:0"
  with ops.device(dev):
    yield
1910
1911
@contextlib.contextmanager
def use_gpu():
  """Runs the enclosed ops on GPU when one is available, else on CPU."""
  with device(True):
    yield
1917
1918
@contextlib.contextmanager
def force_gpu():
  """Pins all ops created in this context to `/device:GPU:0`."""
  gpu_device = "/device:GPU:0"
  with ops.device(gpu_device):
    yield
1924
1925
@contextlib.contextmanager
def force_cpu():
  """Pins all ops created in this context to `/device:CPU:0`."""
  cpu_device = "/device:CPU:0"
  with ops.device(cpu_device):
    yield
1931
1932
@contextlib.contextmanager
def deterministic_ops():
  """Enables deterministic ops for the duration of the context.

  Note: on exit this unconditionally calls `config.disable_op_determinism()`;
  it does not restore a previously-enabled determinism setting. The enable
  call sits inside the `try` so determinism is disabled even if enabling
  raises partway.
  """
  try:
    config.enable_op_determinism()
    yield
  finally:
    config.disable_op_determinism()
1941
1942
class CapturedWrites:
  """A utility class to load the captured writes made to a stream."""

  def __init__(self, capture_location):
    # Path of the file that received the redirected stream output.
    self.capture_location = capture_location

  def contents(self):
    """Get the captured writes as a single string."""
    with open(self.capture_location) as captured_file:
      return captured_file.read()
1954
1955
class FakeEagerSession:
  """Fake session so tests that conditionally use placeholders can use eager.

  There are a number of tests that conditionally use placeholders for shape
  inference. The pattern is demonstrated here:

  ```python
  with self.cached_session() as sess:
    if static_shape:
      y = math_ops.matmul(x, ...)
      feed_dict = {}
    else:
      x_ph = array_ops.placeholder(...)
      y = math_ops.matmul(x_ph, ...)
      feed_dict = {x_ph: x}
    val = sess.run(y, feed_dict=feed_dict)
  ```

  Since the feed_dict is empty when not using placeholders we should be able to
  call self.evaluate(), however this requires rewriting the test case.
  This class should be considered a stop-gap solution to get tests running with
  eager with minimal changes to the actual test.
  """

  def __init__(self, test_case):
    # Delegate evaluation to the owning test case's `evaluate` method.
    self._test_case = test_case

  def run(self, fetches, *args, **kwargs):
    """Evaluate `fetches`.

    Fail if additional args are specified.

    Args:
      fetches: A Tensor or a nested list/tuple of Tensors.
      *args: Positional arguments
      **kwargs: Keyword arguments

    Raises:
      RuntimeError: If args or kwargs are specified.

    Returns:
      Tensors as numpy values.
    """
    # A non-empty feed_dict has no eager equivalent; reject it explicitly.
    if kwargs.pop("feed_dict", {}):
      raise RuntimeError(
          "feed_dict is not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy()")

    if args or kwargs:
      raise RuntimeError(
          "Optional args are not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy()")

    return self._test_case.evaluate(fetches)
2011
2012
class ErrorLoggingSession(session.Session):
  """Wrapper around a Session that logs errors in run()."""

  def run(self, *args, **kwargs):
    try:
      return super().run(*args, **kwargs)
    except errors.OutOfRangeError:
      # OutOfRangeError signals normal completion of tf.data pipelines;
      # logging it would only make test output hard to read.
      raise
    except Exception as e:  # pylint: disable=broad-except
      logging.error(str(e))
      raise
2026
2027
def disable_cudnn_autotune(func):
  """Disable autotuning during the call to this function.

  Some tests want to base assertions on a graph being isomorphic with a copy.
  To ensure this, this decorator disables autotuning.

  Args:
    func: Function to run with CuDNN autotuning turned off.

  Returns:
    Decorated function.
  """

  def decorator(f):

    def decorated(self, *args, **kwargs):
      original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE")
      os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false"
      original_xla_flags = os.environ.get("XLA_FLAGS")
      new_xla_flags = "--xla_gpu_autotune_level=0"
      if original_xla_flags:
        new_xla_flags = original_xla_flags + " " + new_xla_flags
      os.environ["XLA_FLAGS"] = new_xla_flags

      # Restore the environment in a `finally` block so a failing test does
      # not leak the overridden settings into subsequently-run tests.
      try:
        result = f(self, *args, **kwargs)
      finally:
        if original_tf_cudnn_use_autotune is None:
          del os.environ["TF_CUDNN_USE_AUTOTUNE"]
        else:
          os.environ["TF_CUDNN_USE_AUTOTUNE"] = original_tf_cudnn_use_autotune
        if original_xla_flags is None:
          del os.environ["XLA_FLAGS"]
        else:
          os.environ["XLA_FLAGS"] = original_xla_flags

      return result

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
2071
2072
2073# The description is just for documentation purposes.
2074def enable_tf_xla_constant_folding(description):
2075
2076  if not isinstance(description, str):
2077    raise ValueError("'description' should be string, got {}".format(
2078        type(description)))
2079
2080  def enable_tf_xla_constant_folding_impl(func):
2081    """Enable constant folding during the call to this function.
2082
2083    Some tests fail without constant folding.
2084
2085    Args:
2086      func: Function to run with constant folding turned on.
2087
2088    Returns:
2089      Decorated function.
2090    """
2091
2092    def decorator(f):
2093
2094      def decorated(self, *args, **kwargs):
2095        original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled()
2096        pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False)
2097        result = f(self, *args, **kwargs)
2098        pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var)
2099        return result
2100
2101      return decorated
2102
2103    if func is not None:
2104      return decorator(func)
2105
2106    return decorator
2107
2108  return enable_tf_xla_constant_folding_impl
2109
2110
2111# Updates test function by selectively disabling it.
2112def _disable_test(execute_func):
2113
2114  def disable_test_impl(func):
2115
2116    def decorator(func):
2117
2118      def decorated(self, *args, **kwargs):
2119        if execute_func:
2120          return func(self, *args, **kwargs)
2121
2122      return tf_decorator.make_decorator(func, decorated)
2123
2124    if func is not None:
2125      return decorator(func)
2126
2127    return decorator
2128
2129  return disable_test_impl
2130
2131
2132# The description is just for documentation purposes.
def disable_xla(description):  # pylint: disable=unused-argument
  """Execute the test method only if xla is not enabled."""
  return _disable_test(execute_func=not is_xla_enabled())
2137
2138
2139# The description is just for documentation purposes.
def disable_mlir_bridge(description):  # pylint: disable=unused-argument
  """Execute the test method only if MLIR bridge is not enabled."""
  return _disable_test(execute_func=not is_mlir_bridge_enabled())
2144
2145
2146# The description is just for documentation purposes.
def disable_asan(description):  # pylint: disable=unused-argument
  """Execute the test method only if ASAN is not enabled."""
  return _disable_test(execute_func=not is_asan_enabled())
2151
2152
2153# The description is just for documentation purposes.
def disable_msan(description):  # pylint: disable=unused-argument
  """Execute the test method only if MSAN is not enabled."""
  return _disable_test(execute_func=not is_msan_enabled())
2158
2159
2160# The description is just for documentation purposes.
def disable_tsan(description):  # pylint: disable=unused-argument
  """Execute the test method only if TSAN is not enabled."""
  return _disable_test(execute_func=not is_tsan_enabled())
2165
2166
2167# The description is just for documentation purposes.
def disable_ubsan(description):  # pylint: disable=unused-argument
  """Execute the test method only if UBSAN is not enabled."""
  return _disable_test(execute_func=not is_ubsan_enabled())
2172
2173
2174# The description is just for documentation purposes.
def disable_tfrt(unused_description):
  """Returns a decorator that disables a test/class when TFRT is enabled."""

  def disable_tfrt_impl(cls_or_func):
    """Execute the test only if tfrt is not enabled."""

    if tf_inspect.isclass(cls_or_func):
      # For classes, drop the whole class when running under TFRT.
      return None if tfrt_utils.enabled() else cls_or_func

    def decorator(func):

      def decorated(self, *args, **kwargs):
        # For methods, skip the body (returning None) under TFRT.
        if tfrt_utils.enabled():
          return None
        return func(self, *args, **kwargs)

      return decorated

    if cls_or_func is not None:
      return decorator(cls_or_func)

    return decorator

  return disable_tfrt_impl
2202
2203
2204def for_all_test_methods(decorator, *args, **kwargs):
2205  """Generate class-level decorator from given method-level decorator.
2206
2207  It is expected for the given decorator to take some arguments and return
2208  a method that is then called on the test method to produce a decorated
2209  method.
2210
2211  Args:
2212    decorator: The decorator to apply.
2213    *args: Positional arguments
2214    **kwargs: Keyword arguments
2215  Returns: Function that will decorate a given classes test methods with the
2216    decorator.
2217  """
2218
2219  def all_test_methods_impl(cls):
2220    """Apply decorator to all test methods in class."""
2221    for name in dir(cls):
2222      value = getattr(cls, name)
2223      if callable(value) and name.startswith(
2224          "test") and (name != "test_session"):
2225        setattr(cls, name, decorator(*args, **kwargs)(value))
2226    return cls
2227
2228  return all_test_methods_impl
2229
2230
2231# The description is just for documentation purposes.
def no_xla_auto_jit(description):  # pylint: disable=unused-argument
  """This test is not intended to be run with XLA auto jit enabled."""
  return _disable_test(execute_func=not is_xla_enabled())
2236
2237
2238# The description is just for documentation purposes.
def xla_allow_fallback(description):  # pylint: disable=unused-argument

  def xla_allow_fallback_impl(func):
    """Allow fallback to TF even though testing xla."""

    def decorator(func):

      def decorated(self, *args, **kwargs):
        if not is_xla_enabled():
          return func(self, *args, **kwargs)
        # Update the global XLABuildOpsPassFlags to enable lazy compilation,
        # which allows the compiler to fall back to TF classic. Remember the
        # old value so that we can reset it. Restore in a `finally` block so
        # a failing test does not leak the flag into subsequent tests.
        old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True)
        try:
          return func(self, *args, **kwargs)
        finally:
          pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value)

      return decorated

    if func is not None:
      return decorator(func)

    return decorator

  return xla_allow_fallback_impl
2266
2267
2268# The description is just for documentation purposes.
def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute test with TensorFloat-32 disabled.

  While almost every real-world deep learning model runs fine with
  TensorFloat-32, many tests use assertAllClose or similar methods.
  TensorFloat-32 matmuls typically will cause such methods to fail with the
  default tolerances.

  Args:
    description: A description used for documentation purposes, describing why
      the test requires TensorFloat-32 to be disabled.

  Returns:
    Decorator which runs a test with TensorFloat-32 disabled.
  """

  def decorator(f):

    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      tf32_was_enabled = config.tensor_float_32_execution_enabled()
      try:
        # Disable inside the `try` so the finally clause always restores.
        config.enable_tensor_float_32_execution(False)
        f(self, *args, **kwargs)
      finally:
        config.enable_tensor_float_32_execution(tf32_was_enabled)

    return decorated

  return decorator
2299
2300
2301# The description is just for documentation purposes.
def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Execute all tests in a class with TensorFloat-32 disabled.

  Class-level counterpart of `run_without_tensor_float_32`: applies that
  decorator to every `test*` method of the decorated class.
  """
  return for_all_test_methods(run_without_tensor_float_32, description)
2305
2306
def matmul_without_tf32(a, b, *args, **kwargs):
  """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled.

  This effectively runs matmul without TensorFloat-32. It should only be used in
  tests when verifying some other op or functions works correctly, e.g. to test
  `tf.linalg.sqrtm` by matrix multiplying the output of the op by itself. In
  such cases, the matmul itself is not being tested so it's OK to run it with
  higher precision.

  If a matmul itself is being tested, or some other op which uses matmul, use
  `run_without_tensor_float_32` instead.

  This also casts complex64 inputs to complex128, since TensorFloat-32 can also
  be used with complex64

  Args:
    a: First input to tf.linalg.matmul
    b: Second input to tf.linalg.matmul
    args: Other positional arguments to tf.linalg.matmul
    **kwargs: Other keyword arguments to tf.linalg.matmul

  Returns:
    A tensor with the same type as `a`.
  """
  if not config.tensor_float_32_execution_enabled():
    return math_ops.matmul(a, b, *args, **kwargs)

  # Pick a wider dtype that is unaffected by TensorFloat-32, if applicable.
  if a.dtype == "float32":
    wide_dtype = "float64"
  elif a.dtype == "complex64":
    wide_dtype = "complex128"
  else:
    return math_ops.matmul(a, b, *args, **kwargs)

  product = math_ops.matmul(
      math_ops.cast(a, wide_dtype), math_ops.cast(b, wide_dtype),
      *args, **kwargs)
  return math_ops.cast(product, a.dtype)
2343
2344
class EagerSessionWarner:
  """Stand-in session object that errors on any use under eager execution."""

  def __getattr__(self, attr):
    # Any attribute access or method call produces the same guidance error.
    error_message = (
        "Trying to access properties or call methods on the result of "
        "self.session(), self.cached_session(), etc while eager execution "
        "is enabled. If you're porting this test case to TF 2.0, either "
        "adapt the test to work with eager execution or insert a call to "
        "tf.disable_eager_execution() in the main() function of this test "
        "file.")
    raise AttributeError(error_message)
2355
2356
2357@tf_export("test.TestCase")
2358class TensorFlowTestCase(googletest.TestCase):
2359  """Base class for tests that need to test TensorFlow."""
2360
  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    """Initializes the test case and configures global XLA/MLIR test state.

    Args:
      methodName: Name of the test method to run, forwarded to
        `googletest.TestCase.__init__`.
    """
    super().__init__(methodName)
    # Make sure we get unfiltered stack traces during the test
    traceback_utils.disable_traceback_filtering()
    if is_xla_enabled():
      # NOTE(review): mode "2" and min cluster size 1 appear to force XLA
      # compilation for all ops; confirm against pywrap_tf_session.
      pywrap_tf_session.TF_SetXlaAutoJitMode("2")
      pywrap_tf_session.TF_SetXlaMinClusterSize(1)
      pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False)
      pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True)
      # Constant folding secretly runs code on TF:Classic CPU, so we also
      # disable it here.
      pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)

    # Check if the mlir bridge has been explicitly enabled or disabled. If
    # is_mlir_bridge_enabled() returns None, the user did not explicitly enable
    # or disable the bridge so do not update enable_mlir_bridge.
    if is_mlir_bridge_enabled():
      context.context().enable_mlir_bridge = True
    elif is_mlir_bridge_enabled() is not None:
      context.context().enable_mlir_bridge = False

    self._threads = []  # Threads whose termination tearDown() verifies.
    self._tempdir = None  # Lazily created by get_temp_dir().
    self._cached_session = None  # Closed and reset between tests.
    self._test_start_time = None  # Set in setUp(); used for duration logging.
    # This flag provides the ability to control whether the graph mode gets
    # initialized for TF1 or not. Initializing for TF1, which is what was
    # happening earlier, was preventing enablement of 'eager mode' in the test.
    self._set_default_seed = True
2390
  def setUp(self):
    """Resets global state (graph, RNG seeds, summary writer) per test."""
    super().setUp()
    self._ClearCachedSession()
    # Seed Python's and NumPy's RNGs for reproducibility across runs.
    random.seed(random_seed.DEFAULT_GRAPH_SEED)
    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
    # Note: The following line is necessary because some test methods may error
    # out from within nested graph contexts (e.g., via assertRaises and
    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
    # under certain versions of Python. That would cause
    # ops.reset_default_graph() to throw an exception if the stack were not
    # cleared first.
    ops._default_graph_stack.reset()  # pylint: disable=protected-access
    ops.reset_default_graph()
    if self._set_default_seed:
      random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
    # Reset summary writer in case another test used set_as_default() with their
    # summary writer.
    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
    summary_state.writer = None

    # Avoiding calling setUp() for the poorly named test_session method.
    if self.id().endswith(".test_session"):
      self.skipTest("Not a test.")

    # Recorded so tearDown() can log the test's wall-clock duration.
    self._test_start_time = time.time()
2416
2417  def tearDown(self):
2418    # If a subclass overrides setUp and doesn't call the parent class's setUp,
2419    # then we may not have set the start time.
2420    if self._test_start_time is not None:
2421      logging.info("time(%s): %ss", self.id(),
2422                   round(time.time() - self._test_start_time, 2))
2423
2424    for thread in self._threads:
2425      thread.check_termination()
2426
2427    self._ClearCachedSession()
2428    super().tearDown()
2429
2430  def _ClearCachedSession(self):
2431    if self._cached_session is not None:
2432      self._cached_session.close()
2433      self._cached_session = None
2434
2435  def get_temp_dir(self):
2436    """Returns a unique temporary directory for the test to use.
2437
2438    If you call this method multiple times during in a test, it will return the
2439    same folder. However, across different runs the directories will be
2440    different. This will ensure that across different runs tests will not be
2441    able to pollute each others environment.
2442    If you need multiple unique directories within a single test, you should
2443    use tempfile.mkdtemp as follows:
2444      tempfile.mkdtemp(dir=self.get_temp_dir()):
2445
2446    Returns:
2447      string, the path to the unique temporary directory created for this test.
2448    """
2449    if not self._tempdir:
2450      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
2451    return self._tempdir
2452
2453  @contextlib.contextmanager
2454  def captureWritesToStream(self, stream):
2455    """A context manager that captures the writes to a given stream.
2456
2457    This context manager captures all writes to a given stream inside of a
2458    `CapturedWrites` object. When this context manager is created, it yields
2459    the `CapturedWrites` object. The captured contents can be accessed  by
2460    calling `.contents()` on the `CapturedWrites`.
2461
2462    For this function to work, the stream must have a file descriptor that
2463    can be modified using `os.dup` and `os.dup2`, and the stream must support
2464    a `.flush()` method. The default python sys.stdout and sys.stderr are
2465    examples of this. Note that this does not work in Colab or Jupyter
2466    notebooks, because those use alternate stdout streams.
2467
2468    Example:
2469    ```python
2470    class MyOperatorTest(test_util.TensorFlowTestCase):
2471      def testMyOperator(self):
2472        input = [1.0, 2.0, 3.0, 4.0, 5.0]
2473        with self.captureWritesToStream(sys.stdout) as captured:
2474          result = MyOperator(input).eval()
2475        self.assertStartsWith(captured.contents(), "This was printed.")
2476    ```
2477
2478    Args:
2479      stream: The stream whose writes should be captured. This stream must have
2480        a file descriptor, support writing via using that file descriptor, and
2481        must have a `.flush()` method.
2482
2483    Yields:
2484      A `CapturedWrites` object that contains all writes to the specified stream
2485      made during this context.
2486    """
2487    stream.flush()
2488    fd = stream.fileno()
2489    tmp_file, tmp_file_path = tempfile.mkstemp(dir=self.get_temp_dir())
2490    orig_fd = os.dup(fd)
2491    os.dup2(tmp_file, fd)
2492    try:
2493      yield CapturedWrites(tmp_file_path)
2494    finally:
2495      os.close(tmp_file)
2496      os.dup2(orig_fd, fd)
2497
2498  def _AssertProtoEquals(self, a, b, msg=None):
2499    """Asserts that a and b are the same proto.
2500
2501    Uses ProtoEq() first, as it returns correct results
2502    for floating point attributes, and then use assertProtoEqual()
2503    in case of failure as it provides good error messages.
2504
2505    Args:
2506      a: a proto.
2507      b: another proto.
2508      msg: Optional message to report on failure.
2509    """
2510    if not compare.ProtoEq(a, b):
2511      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)
2512
2513  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
2514    """Asserts that message is same as parsed expected_message_ascii.
2515
2516    Creates another prototype of message, reads the ascii message into it and
2517    then compares them using self._AssertProtoEqual().
2518
2519    Args:
2520      expected_message_maybe_ascii: proto message in original or ascii form.
2521      message: the message to validate.
2522      msg: Optional message to report on failure.
2523    """
2524    if isinstance(expected_message_maybe_ascii, type(message)):
2525      expected_message = expected_message_maybe_ascii
2526      self._AssertProtoEquals(expected_message, message, msg=msg)
2527    elif isinstance(expected_message_maybe_ascii, (str, bytes)):
2528      expected_message = type(message)()
2529      text_format.Merge(
2530          expected_message_maybe_ascii,
2531          expected_message,
2532          descriptor_pool=descriptor_pool.Default())
2533      self._AssertProtoEquals(expected_message, message, msg=msg)
2534    else:
2535      assert False, ("Can't compare protos of type %s and %s." %
2536                     (type(expected_message_maybe_ascii), type(message)))
2537
2538  def assertProtoEqualsVersion(
2539      self,
2540      expected,
2541      actual,
2542      producer=versions.GRAPH_DEF_VERSION,
2543      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
2544      msg=None):
2545    expected = "versions { producer: %d min_consumer: %d };\n%s" % (
2546        producer, min_consumer, expected)
2547    self.assertProtoEquals(expected, actual, msg=msg)
2548
2549  def assertStartsWith(self, actual, expected_start, msg=None):
2550    """Assert that actual.startswith(expected_start) is True.
2551
2552    Args:
2553      actual: str
2554      expected_start: str
2555      msg: Optional message to report on failure.
2556    """
2557    if not actual.startswith(expected_start):
2558      fail_msg = "%r does not start with %r" % (actual, expected_start)
2559      fail_msg += " : %r" % (msg) if msg else ""
2560      self.fail(fail_msg)
2561
  def _eval_tensor(self, tensor):
    """Converts a single tensor-like object to its numpy value.

    Handles `None`, callables (whose result is evaluated recursively), sparse,
    ragged and indexed-slices tensors, and anything exposing a `.numpy()`
    method; falls back to mapping `.numpy()` over composite components.

    Args:
      tensor: The tensor-like object to evaluate.

    Returns:
      The numpy value(s) of `tensor`, or `None` if `tensor` is None.

    Raises:
      ValueError: If the object cannot be converted to numpy values.
    """
    if tensor is None:
      return None
    elif callable(tensor):
      # The callable may itself return a nested structure of tensors.
      return self._eval_helper(tensor())
    else:
      try:
        # for compatibility with TF1 test cases
        if sparse_tensor.is_sparse(tensor):
          return sparse_tensor.SparseTensorValue(tensor.indices.numpy(),
                                                 tensor.values.numpy(),
                                                 tensor.dense_shape.numpy())
        elif ragged_tensor.is_ragged(tensor):
          # Evaluate components recursively: values may themselves be ragged.
          return ragged_tensor_value.RaggedTensorValue(
              self._eval_tensor(tensor.values),
              self._eval_tensor(tensor.row_splits))
        elif isinstance(tensor, indexed_slices.IndexedSlices):
          return indexed_slices.IndexedSlicesValue(
              values=tensor.values.numpy(),
              indices=tensor.indices.numpy(),
              dense_shape=None
              if tensor.dense_shape is None else tensor.dense_shape.numpy())
        else:
          if hasattr(tensor, "numpy") and callable(tensor.numpy):
            return tensor.numpy()
          else:
            # Try our best to convert CompositeTensor components to NumPy
            # arrays. Officially, we don't support NumPy arrays as
            # CompositeTensor components. So don't be surprised if this doesn't
            # work.
            return nest.map_structure(lambda t: t.numpy(), tensor,
                                      expand_composites=True)
      except AttributeError as e:
        raise ValueError(f"Unsupported type {type(tensor).__name__!r}.") from e
2596
2597  def _eval_helper(self, tensors):
2598    if tensors is None:
2599      return None
2600    return nest.map_structure(self._eval_tensor, tensors)
2601
2602  def evaluate(self, tensors):
2603    """Evaluates tensors and returns numpy values.
2604
2605    Args:
2606      tensors: A Tensor or a nested list/tuple of Tensors.
2607
2608    Returns:
2609      tensors numpy values.
2610    """
2611    if context.executing_eagerly():
2612      return self._eval_helper(tensors)
2613    else:
2614      sess = ops.get_default_session()
2615      if sess is None:
2616        with self.test_session() as sess:
2617          return sess.run(tensors)
2618      else:
2619        return sess.run(tensors)
2620
  # pylint: disable=g-doc-return-or-yield
  @contextlib.contextmanager
  def session(self, graph=None, config=None, use_gpu=True, force_gpu=False):
    """A context manager for a TensorFlow Session for use in executing tests.

    Note that this will set this session and the graph as global defaults.

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
    possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned
    to the CPU.

    Example:

    ``` python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        with self.session():
          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
          result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
          invalid_input = [-1.0, 2.0, 7.0]
          with self.assertRaisesOpError("negative input not supported"):
            MyOperator(invalid_input).eval()
    ```

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/device:GPU:0`.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if context.executing_eagerly():
      # Eager mode has no real sessions; yield the warner stand-in instead.
      yield EagerSessionWarner()
    else:
      # A fresh (non-cached) session, device-constrained per use_gpu/force_gpu.
      with self._create_session(graph, config, force_gpu) as sess:
        with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
          yield sess
2665
  @contextlib.contextmanager
  def cached_session(self,
                     graph=None,
                     config=None,
                     use_gpu=True,
                     force_gpu=False):
    """Returns a TensorFlow Session for use in executing tests.

    This method behaves differently than self.session(): for performance reasons
    `cached_session` will by default reuse the same session within the same
    test. The session returned by this function will only be closed at the end
    of the test (in the TearDown function).

    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
    possible. If both `force_gpu` and `use_gpu` are False, all ops are pinned
    to the CPU.

    Example:
    ```python
    class MyOperatorTest(test_util.TensorFlowTestCase):
      def testMyOperator(self):
        with self.cached_session() as sess:
          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
          result = MyOperator(valid_input).eval()
          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0])
          invalid_input = [-1.0, 2.0, 7.0]
          with self.assertRaisesOpError("negative input not supported"):
            MyOperator(invalid_input).eval()
    ```

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      use_gpu: If True, attempt to run as many ops as possible on GPU.
      force_gpu: If True, pin all ops to `/device:GPU:0`.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    if context.executing_eagerly():
      yield FakeEagerSession(self)
    else:
      # Reuse (or lazily create) the per-test cached session; crash if a later
      # call passes inconsistent arguments for the same cached session.
      sess = self._get_cached_session(
          graph, config, force_gpu, crash_if_inconsistent_args=True)
      with self._constrain_devices_and_set_default(sess, use_gpu,
                                                   force_gpu) as cached:
        yield cached
2717
2718  @contextlib.contextmanager
2719  @deprecation.deprecated(None, "Use `self.session()` or "
2720                          "`self.cached_session()` instead.")
2721  def test_session(self,
2722                   graph=None,
2723                   config=None,
2724                   use_gpu=True,
2725                   force_gpu=False):
2726    """Use cached_session instead."""
2727    if self.id().endswith(".test_session"):
2728      self.skipTest(
2729          "Tests that have the name \"test_session\" are automatically skipped "
2730          "by TensorFlow test fixture, as the name is reserved for creating "
2731          "sessions within tests. Please rename your test if you have a test "
2732          "with this name.")
2733    if context.executing_eagerly():
2734      yield None
2735    else:
2736      if graph is None:
2737        sess = self._get_cached_session(
2738            graph, config, force_gpu, crash_if_inconsistent_args=False)
2739        with self._constrain_devices_and_set_default(sess, use_gpu,
2740                                                     force_gpu) as cached:
2741          yield cached
2742      else:
2743        with self.session(graph, config, use_gpu, force_gpu) as sess:
2744          yield sess
2745
2746  # pylint: enable=g-doc-return-or-yield
2747
2748  class _CheckedThread(object):
2749    """A wrapper class for Thread that asserts successful completion.
2750
2751    This class should be created using the TensorFlowTestCase.checkedThread()
2752    method.
2753    """
2754
2755    def __init__(self, testcase, target, args=None, kwargs=None):
2756      """Constructs a new instance of _CheckedThread.
2757
2758      Args:
2759        testcase: The TensorFlowTestCase for which this thread is being created.
2760        target: A callable object representing the code to be executed in the
2761          thread.
2762        args: A tuple of positional arguments that will be passed to target.
2763        kwargs: A dictionary of keyword arguments that will be passed to target.
2764      """
2765      self._testcase = testcase
2766      self._target = target
2767      self._args = () if args is None else args
2768      self._kwargs = {} if kwargs is None else kwargs
2769      self._thread = threading.Thread(target=self._protected_run)
2770      self._exception = None
2771
2772      self._is_thread_joined = False
2773
2774    def _protected_run(self):
2775      """Target for the wrapper thread. Sets self._exception on failure."""
2776      try:
2777        self._target(*self._args, **self._kwargs)
2778      except Exception as e:  # pylint: disable=broad-except
2779        self._exception = e
2780
2781    def start(self):
2782      """Starts the thread's activity.
2783
2784      This must be called at most once per _CheckedThread object. It arranges
2785      for the object's target to be invoked in a separate thread of control.
2786      """
2787      self._thread.start()
2788
2789    def join(self):
2790      """Blocks until the thread terminates.
2791
2792      Raises:
2793        self._testcase.failureException: If the thread terminates with due to
2794          an exception.
2795      """
2796      self._is_thread_joined = True
2797      self._thread.join()
2798      if self._exception is not None:
2799        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))
2800
2801    def is_alive(self):
2802      """Returns whether the thread is alive.
2803
2804      This method returns True just before the run() method starts
2805      until just after the run() method terminates.
2806
2807      Returns:
2808        True if the thread is alive, otherwise False.
2809      """
2810      return self._thread.is_alive()
2811
2812    def check_termination(self):
2813      """Returns whether the checked thread was properly used and did terminate.
2814
2815      Every checked thread should be "join"ed after starting, and before the
2816      test tears down. If it is not joined, it is possible the thread will hang
2817      and cause flaky failures in tests.
2818
2819      Raises:
2820        self._testcase.failureException: If check_termination was called before
2821        thread was joined.
2822
2823        RuntimeError: If the thread is not terminated. This means thread was not
2824        joined with the main thread.
2825      """
2826      if self._is_thread_joined:
2827        if self.is_alive():
2828          raise RuntimeError(
2829              "Thread was not joined with main thread, and is still running "
2830              "when the test finished.")
2831      else:
2832        self._testcase.fail("A checked thread was not joined.")
2833
2834  def checkedThread(self, target, args=None, kwargs=None):
2835    """Returns a Thread wrapper that asserts 'target' completes successfully.
2836
2837    This method should be used to create all threads in test cases, as
2838    otherwise there is a risk that a thread will silently fail, and/or
2839    assertions made in the thread will not be respected.
2840
2841    Args:
2842      target: A callable object to be executed in the thread.
2843      args: The argument tuple for the target invocation. Defaults to ().
2844      kwargs: A dictionary of keyword arguments for the target invocation.
2845        Defaults to {}.
2846
2847    Returns:
2848      A wrapper for threading.Thread that supports start() and join() methods.
2849    """
2850    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
2851    self._threads.append(ret)
2852    return ret
2853
2854  # pylint: enable=invalid-name
2855  @py_func_if_in_function
2856  def assertNear(self, f1, f2, err, msg=None):
2857    """Asserts that two floats are near each other.
2858
2859    Checks that |f1 - f2| < err and asserts a test failure
2860    if not.
2861
2862    Args:
2863      f1: A float value.
2864      f2: A float value.
2865      err: A float value.
2866      msg: An optional string message to append to the failure message.
2867    """
2868    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
2869    self.assertTrue(
2870        f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
2871        (f1, f2, err, " (%s)" % msg if msg is not None else ""))
2872
2873  @py_func_if_in_function
2874  def assertArrayNear(self, farray1, farray2, err, msg=None):
2875    """Asserts that two float arrays are near each other.
2876
2877    Checks that for all elements of farray1 and farray2
2878    |f1 - f2| < err.  Asserts a test failure if not.
2879
2880    Args:
2881      farray1: a list of float values.
2882      farray2: a list of float values.
2883      err: a float value.
2884      msg: Optional message to report on failure.
2885    """
2886    self.assertEqual(len(farray1), len(farray2), msg=msg)
2887    for f1, f2 in zip(farray1, farray2):
2888      self.assertNear(float(f1), float(f2), err, msg=msg)
2889
2890  def _NDArrayNear(self, ndarray1, ndarray2, err):
2891    return np.linalg.norm(ndarray1 - ndarray2) < err
2892
2893  @py_func_if_in_function
2894  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
2895    """Asserts that two numpy arrays have near values.
2896
2897    Args:
2898      ndarray1: a numpy ndarray.
2899      ndarray2: a numpy ndarray.
2900      err: a float. The maximum absolute difference allowed.
2901      msg: Optional message to report on failure.
2902    """
2903    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)
2904
2905  def _GetNdArray(self, a):
2906    # If a is tensor-like then convert it to ndarray
2907    if tensor_util.is_tf_type(a):
2908      if isinstance(a, ops._EagerTensorBase):
2909        a = a.numpy()
2910      else:
2911        a = self.evaluate(a)
2912    if not isinstance(a, np.ndarray):
2913      return np.array(a)
2914    return a
2915
2916  def evaluate_if_both_tensors(self, a, b):
2917    if (tensor_util.is_tf_type(a) and tensor_util.is_tf_type(b) and
2918        not isinstance(a, ops._EagerTensorBase) and
2919        not isinstance(b, ops._EagerTensorBase)):
2920      return self.evaluate((a, b))
2921    else:
2922      return (a, b)
2923
  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
    """Checks that two array-like values are elementwise close.

    Converts `a` and `b` to ndarrays, asserts equal shapes, then compares
    them with np.allclose / np.testing.assert_allclose, augmenting the
    failure message with the offending indices and values.

    Args:
      a: The expected value, convertible to an ndarray.
      b: The actual value, convertible to an ndarray.
      rtol: relative tolerance.
      atol: absolute tolerance.
      msg: Optional message prepended to the failure details.
    """
    (a, b) = self.evaluate_if_both_tensors(a, b)
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # When the array rank is small, print its contents. Numpy array printing is
    # implemented using inefficient recursion so prints can cause tests to
    # time out.
    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
                            "%s.") % (a.shape, b.shape, b)
    else:
      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
                                                                     b.shape)
    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)

    msgs = [msg]
    # np.allclose does not always work for our custom bfloat16 extension type
    # when type promotions are involved, so we first cast any bfloat16 arrays
    # to float32.
    a_dtype = a.dtype
    a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a
    b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b
    if not np.allclose(a, b, rtol=rtol, atol=atol):
      # Adds more details to np.testing.assert_allclose.
      #
      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
      # checks whether two arrays are element-wise equal within a
      # tolerance. The relative difference (rtol * abs(b)) and the
      # absolute difference atol are added together to compare against
      # the absolute difference between a and b.  Here, we want to
      # tell user which elements violate such conditions.
      cond = np.logical_or(
          np.abs(a - b) > atol + rtol * np.abs(b),
          np.isnan(a) != np.isnan(b))
      if a.ndim:
        x = a[np.where(cond)]
        y = b[np.where(cond)]
        msgs.append("not close where = {}".format(np.where(cond)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not close lhs = {}".format(x))
      msgs.append("not close rhs = {}".format(y))
      msgs.append("not close dif = {}".format(np.abs(x - y)))
      msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
      msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
      # TODO(xpan): There seems to be a bug:
      # tensorflow/compiler/tests:binary_ops_test pass with float32
      # nan even though the equal_nan is False by default internally.
      np.testing.assert_allclose(
          a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)
2975
2976  def _assertAllCloseRecursive(self,
2977                               a,
2978                               b,
2979                               rtol=1e-6,
2980                               atol=1e-6,
2981                               path=None,
2982                               msg=None):
2983    if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
2984      return self._assertRaggedClose(a, b, rtol, atol, msg)
2985    path = path or []
2986    path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "")
2987    msg = msg if msg else ""
2988
2989    # Check if a and/or b are namedtuples.
2990    if hasattr(a, "_asdict"):
2991      a = a._asdict()
2992    if hasattr(b, "_asdict"):
2993      b = b._asdict()
2994    a_is_dict = isinstance(a, collections_abc.Mapping)
2995    if a_is_dict != isinstance(b, collections_abc.Mapping):
2996      raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
2997                       (path_str, path_str, msg))
2998    if a_is_dict:
2999      self.assertItemsEqual(
3000          a.keys(),
3001          b.keys(),
3002          msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
3003          (path_str, a.keys(), path_str, b.keys(), msg))
3004      for k in a:
3005        path.append(k)
3006        self._assertAllCloseRecursive(
3007            a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
3008        del path[-1]
3009    elif isinstance(a, (list, tuple)):
3010      # Try to directly compare a, b as ndarrays; if not work, then traverse
3011      # through the sequence, which is more expensive.
3012      try:
3013        (a, b) = self.evaluate_if_both_tensors(a, b)
3014        a_as_ndarray = self._GetNdArray(a)
3015        b_as_ndarray = self._GetNdArray(b)
3016        self._assertArrayLikeAllClose(
3017            a_as_ndarray,
3018            b_as_ndarray,
3019            rtol=rtol,
3020            atol=atol,
3021            msg="Mismatched value: a%s is different from b%s. %s" %
3022            (path_str, path_str, msg))
3023      except (ValueError, TypeError, NotImplementedError) as e:
3024        if len(a) != len(b):
3025          raise ValueError(
3026              "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
3027              (path_str, len(a), path_str, len(b), msg))
3028        for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
3029          path.append(str(idx))
3030          self._assertAllCloseRecursive(
3031              a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
3032          del path[-1]
3033    # a and b are ndarray like objects
3034    else:
3035      try:
3036        self._assertArrayLikeAllClose(
3037            a,
3038            b,
3039            rtol=rtol,
3040            atol=atol,
3041            msg=("Mismatched value: a%s is different from b%s. %s" %
3042                 (path_str, path_str, msg)))
3043      except TypeError as e:
3044        msg = ("Error: a%s has %s, but b%s has %s. %s" %
3045               (path_str, type(a), path_str, type(b), msg))
3046        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
3047        raise
3048
3049  @py_func_if_in_function
3050  def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
3051    """Asserts that two structures of numpy arrays or Tensors, have near values.
3052
3053    `a` and `b` can be arbitrarily nested structures. A layer of a nested
3054    structure can be a `dict`, `namedtuple`, `tuple` or `list`.
3055
3056    Note: the implementation follows
3057    [`numpy.allclose`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html)
3058    (and numpy.testing.assert_allclose). It checks whether two arrays are
3059    element-wise equal within a tolerance. The relative difference
3060    (`rtol * abs(b)`) and the absolute difference `atol` are added together
3061    to compare against the absolute difference between `a` and `b`.
3062
3063    Args:
3064      a: The expected numpy `ndarray`, or anything that can be converted into a
3065        numpy `ndarray` (including Tensor), or any arbitrarily nested of
3066        structure of these.
3067      b: The actual numpy `ndarray`, or anything that can be converted into a
3068        numpy `ndarray` (including Tensor), or any arbitrarily nested of
3069        structure of these.
3070      rtol: relative tolerance.
3071      atol: absolute tolerance.
3072      msg: Optional message to report on failure.
3073
3074    Raises:
3075      ValueError: if only one of `a[p]` and `b[p]` is a dict or
3076          `a[p]` and `b[p]` have different length, where `[p]` denotes a path
3077          to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
3078          `[p] = [1]['d']`, then `a[p] = (6, 7)`.
3079    """
3080    self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
3081
3082  @py_func_if_in_function
3083  def assertAllCloseAccordingToType(self,
3084                                    a,
3085                                    b,
3086                                    rtol=1e-6,
3087                                    atol=1e-6,
3088                                    float_rtol=1e-6,
3089                                    float_atol=1e-6,
3090                                    half_rtol=1e-3,
3091                                    half_atol=1e-3,
3092                                    bfloat16_rtol=1e-2,
3093                                    bfloat16_atol=1e-2,
3094                                    msg=None):
3095    """Like assertAllClose, but also suitable for comparing fp16 arrays.
3096
3097    In particular, the tolerance is reduced to 1e-3 if at least
3098    one of the arguments is of type float16.
3099
3100    Args:
3101      a: the expected numpy ndarray or anything can be converted to one.
3102      b: the actual numpy ndarray or anything can be converted to one.
3103      rtol: relative tolerance.
3104      atol: absolute tolerance.
3105      float_rtol: relative tolerance for float32.
3106      float_atol: absolute tolerance for float32.
3107      half_rtol: relative tolerance for float16.
3108      half_atol: absolute tolerance for float16.
3109      bfloat16_rtol: relative tolerance for bfloat16.
3110      bfloat16_atol: absolute tolerance for bfloat16.
3111      msg: Optional message to report on failure.
3112    """
3113    (a, b) = self.evaluate_if_both_tensors(a, b)
3114    a = self._GetNdArray(a)
3115    b = self._GetNdArray(b)
3116    # types with lower tol are put later to overwrite previous ones.
3117    if (a.dtype == np.float32 or b.dtype == np.float32 or
3118        a.dtype == np.complex64 or b.dtype == np.complex64):
3119      rtol = max(rtol, float_rtol)
3120      atol = max(atol, float_atol)
3121    if a.dtype == np.float16 or b.dtype == np.float16:
3122      rtol = max(rtol, half_rtol)
3123      atol = max(atol, half_atol)
3124    if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
3125        b.dtype == dtypes.bfloat16.as_numpy_dtype):
3126      rtol = max(rtol, bfloat16_rtol)
3127      atol = max(atol, bfloat16_atol)
3128
3129    self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
3130
3131  @py_func_if_in_function
3132  def assertNotAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
3133    """Assert that two numpy arrays, or Tensors, do not have near values.
3134
3135    Args:
3136      a: The expected numpy `ndarray`, or anything that can be converted into a
3137        numpy `ndarray` (including Tensor), or any arbitrarily nested of
3138        structure of these.
3139      b: The actual numpy `ndarray`, or anything that can be converted into a
3140        numpy `ndarray` (including Tensor), or any arbitrarily nested of
3141        structure of these.
3142      rtol: relative tolerance.
3143      atol: absolute tolerance.
3144      msg: Optional message to report on failure.
3145
3146    Raises:
3147      AssertionError: If `a` and `b` are unexpectedly close at all elements.
3148    """
3149    try:
3150      self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
3151    except AssertionError:
3152      return
3153    msg = msg or ""
3154    raise AssertionError("The two values are close at all elements. %s" % msg)
3155
  @py_func_if_in_function
  def assertAllEqual(self, a, b, msg=None):
    """Asserts that two numpy arrays or Tensors have the same values.

    NaN elements compare equal to each other when the dtype is a
    floating-point type.

    Args:
      a: the expected numpy ndarray or anything can be converted to one.
      b: the actual numpy ndarray or anything can be converted to one.
      msg: Optional message to report on failure.
    """
    # Ragged tensors have a dedicated comparison path.
    if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)):
      return self._assertRaggedEqual(a, b, msg)
    msg = msg if msg else ""
    (a, b) = self.evaluate_if_both_tensors(a, b)
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # Arbitrary bounds so that we don't print giant tensors.
    if (b.ndim <= 3 or b.size < 500):
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " Contents: %r. \n%s." % (a.shape, b.shape, b, msg))
    else:
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " %s" % (a.shape, b.shape, msg))

    same = (a == b)

    # Treat NaN == NaN as equal for floating-point dtypes; the elementwise
    # comparison above marks those positions as unequal.
    if (a.dtype in [
        np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
    ]):
      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
    msgs = [msg]
    if not np.all(same):
      # Adds more details to np.testing.assert_array_equal.
      diff = np.logical_not(same)
      if a.ndim:
        x = a[np.where(diff)]
        y = b[np.where(diff)]
        msgs.append("not equal where = {}".format(np.where(diff)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not equal lhs = %r" % x)
      msgs.append("not equal rhs = %r" % y)

      # When the dtype kinds differ but both are string-like (unicode "U",
      # bytes "S", or object "O"), encode str elements to bytes so the final
      # comparison sees values of a comparable kind.
      if (a.dtype.kind != b.dtype.kind and
          {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})):
        a_list = []
        b_list = []
        # OK to flatten `a` and `b` because they are guaranteed to have the
        # same shape.
        for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]:
          for item in flat_arr:
            if isinstance(item, str):
              out_list.append(item.encode("utf-8"))
            else:
              out_list.append(item)
        a = np.array(a_list)
        b = np.array(b_list)

      np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
3217
3218  @py_func_if_in_function
3219  def assertNotAllEqual(self, a, b, msg=None):
3220    """Asserts that two numpy arrays or Tensors do not have the same values.
3221
3222    Args:
3223      a: the expected numpy ndarray or anything can be converted to one.
3224      b: the actual numpy ndarray or anything can be converted to one.
3225      msg: Optional message to report on failure.
3226    """
3227    try:
3228      self.assertAllEqual(a, b)
3229    except AssertionError:
3230      return
3231    msg = msg or ""
3232    raise AssertionError("The two values are equal at all elements. %s" % msg)
3233
3234  @py_func_if_in_function
3235  def assertAllGreater(self, a, comparison_target):
3236    """Assert element values are all greater than a target value.
3237
3238    Args:
3239      a: The numpy `ndarray`, or anything that can be converted into a numpy
3240        `ndarray` (including Tensor).
3241      comparison_target: The target value of comparison.
3242    """
3243    (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3244    a = self._GetNdArray(a)
3245    self.assertGreater(np.min(a), comparison_target)
3246
3247  @py_func_if_in_function
3248  def assertAllLess(self, a, comparison_target):
3249    """Assert element values are all less than a target value.
3250
3251    Args:
3252      a: The numpy `ndarray`, or anything that can be converted into a numpy
3253        `ndarray` (including Tensor).
3254      comparison_target: The target value of comparison.
3255    """
3256    (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3257    a = self._GetNdArray(a)
3258    self.assertLess(np.max(a), comparison_target)
3259
3260  @py_func_if_in_function
3261  def assertAllGreaterEqual(self, a, comparison_target):
3262    """Assert element values are all greater than or equal to a target value.
3263
3264    Args:
3265      a: The numpy `ndarray`, or anything that can be converted into a numpy
3266        `ndarray` (including Tensor).
3267      comparison_target: The target value of comparison.
3268    """
3269    (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3270    a = self._GetNdArray(a)
3271    self.assertGreaterEqual(np.min(a), comparison_target)
3272
3273  @py_func_if_in_function
3274  def assertAllLessEqual(self, a, comparison_target):
3275    """Assert element values are all less than or equal to a target value.
3276
3277    Args:
3278      a: The numpy `ndarray`, or anything that can be converted into a numpy
3279        `ndarray` (including Tensor).
3280      comparison_target: The target value of comparison.
3281    """
3282    (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3283    a = self._GetNdArray(a)
3284    self.assertLessEqual(np.max(a), comparison_target)
3285
3286  def _format_subscripts(self, subscripts, value, limit=10, indent=2):
3287    """Generate a summary of ndarray subscripts as a list of str.
3288
3289    If limit == N, this method will print up to the first N subscripts on
3290    separate
3291    lines. A line of ellipses (...) will be appended at the end if the number of
3292    subscripts exceeds N.
3293
3294    Args:
3295      subscripts: The tensor (np.ndarray) subscripts, of the same format as
3296        np.where()'s return value, i.e., a tuple of arrays with each array
3297        corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])).
3298      value: (np.ndarray) value of the tensor.
3299      limit: (int) The maximum number of indices to print.
3300      indent: (int) Number of characters to indent at the beginning of each
3301        line.
3302
3303    Returns:
3304      (list of str) the multi-line representation of the subscripts and values,
3305        potentially with omission at the end.
3306    """
3307    lines = []
3308    subscripts = np.transpose(subscripts)
3309    prefix = " " * indent
3310    if np.ndim(value) == 0:
3311      return [prefix + "[0] : " + str(value)]
3312    for subscript in itertools.islice(subscripts, limit):
3313      lines.append(prefix + str(subscript) + " : " +
3314                   str(value[tuple(subscript)]))
3315    if len(subscripts) > limit:
3316      lines.append(prefix + "...")
3317    return lines
3318
  @py_func_if_in_function
  def assertAllInRange(self,
                       target,
                       lower_bound,
                       upper_bound,
                       open_lower_bound=False,
                       open_upper_bound=False):
    """Assert that elements in a Tensor are all in a given range.

    Args:
      target: The numpy `ndarray`, or anything that can be converted into a
        numpy `ndarray` (including Tensor).
      lower_bound: lower bound of the range
      upper_bound: upper bound of the range
      open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
        than the default >=)
      open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather
        than the default <=)

    Raises:
      AssertionError:
        if the value tensor does not have an ordered numeric type (float* or
          int*), or
        if there are nan values, or
        if any of the elements do not fall in the specified range.
    """
    target = self._GetNdArray(target)
    # Only float and integer dtypes have a meaningful ordering for a range
    # check; reject everything else (bool, complex, strings, ...).
    if not (np.issubdtype(target.dtype, np.floating) or
            np.issubdtype(target.dtype, np.integer)):
      raise AssertionError(
          "The value of %s does not have an ordered numeric type, instead it "
          "has type: %s" % (target, target.dtype))

    # NaN compares false to every bound, so it must be rejected explicitly
    # before the range comparisons below.
    nan_subscripts = np.where(np.isnan(target))
    if np.size(nan_subscripts):
      raise AssertionError(
          "%d of the %d element(s) are NaN. "
          "Subscripts(s) and value(s) of the NaN element(s):\n" %
          (len(nan_subscripts[0]), np.size(target)) +
          "\n".join(self._format_subscripts(nan_subscripts, target)))

    # Human-readable interval notation for the failure message, e.g. "(0, 1]".
    range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " +
                 str(upper_bound) + (")" if open_upper_bound else "]"))

    # An element violates an open bound when it equals the bound, hence the
    # <= / >= variants for open bounds.
    violations = (
        np.less_equal(target, lower_bound) if open_lower_bound else np.less(
            target, lower_bound))
    violations = np.logical_or(
        violations,
        np.greater_equal(target, upper_bound)
        if open_upper_bound else np.greater(target, upper_bound))
    violation_subscripts = np.where(violations)
    if np.size(violation_subscripts):
      raise AssertionError(
          "%d of the %d element(s) are outside the range %s. " %
          (len(violation_subscripts[0]), np.size(target), range_str) +
          "Subscript(s) and value(s) of the offending elements:\n" +
          "\n".join(self._format_subscripts(violation_subscripts, target)))
3377
3378  @py_func_if_in_function
3379  def assertAllInSet(self, target, expected_set):
3380    """Assert that elements of a Tensor are all in a given closed set.
3381
3382    Args:
3383      target: The numpy `ndarray`, or anything that can be converted into a
3384        numpy `ndarray` (including Tensor).
3385      expected_set: (`list`, `tuple` or `set`) The closed set that the elements
3386        of the value of `target` are expected to fall into.
3387
3388    Raises:
3389      AssertionError:
3390        if any of the elements do not fall into `expected_set`.
3391    """
3392    target = self._GetNdArray(target)
3393
3394    # Elements in target that are not in expected_set.
3395    diff = np.setdiff1d(target.flatten(), list(expected_set))
3396    if np.size(diff):
3397      raise AssertionError("%d unique element(s) are not in the set %s: %s" %
3398                           (np.size(diff), expected_set, diff))
3399
3400  @py_func_if_in_function
3401  def assertDTypeEqual(self, target, expected_dtype):
3402    """Assert ndarray data type is equal to expected.
3403
3404    Args:
3405      target: The numpy `ndarray`, or anything that can be converted into a
3406        numpy `ndarray` (including Tensor).
3407      expected_dtype: Expected data type.
3408    """
3409    target = self._GetNdArray(target)
3410    if not isinstance(target, list):
3411      arrays = [target]
3412    for arr in arrays:
3413      self.assertEqual(arr.dtype, expected_dtype)
3414
3415  # pylint: disable=g-doc-return-or-yield
3416  @contextlib.contextmanager
3417  def assertRaisesWithPredicateMatch(self, exception_type,
3418                                     expected_err_re_or_predicate):
3419    """Returns a context manager to enclose code expected to raise an exception.
3420
3421    If the exception is an OpError, the op stack is also included in the message
3422    predicate search.
3423
3424    Args:
3425      exception_type: The expected type of exception that should be raised.
3426      expected_err_re_or_predicate: If this is callable, it should be a function
3427        of one argument that inspects the passed-in exception and returns True
3428        (success) or False (please fail the test). Otherwise, the error message
3429        is expected to match this regular expression partially.
3430
3431    Returns:
3432      A context manager to surround code that is expected to raise an
3433      exception.
3434    """
3435    if callable(expected_err_re_or_predicate):
3436      predicate = expected_err_re_or_predicate
3437    else:
3438
3439      def predicate(e):
3440        err_str = e.message if isinstance(e, errors.OpError) else str(e)
3441        op = e.op if isinstance(e, errors.OpError) else None
3442        while op is not None:
3443          err_str += "\nCaused by: " + op.name
3444          op = op._original_op  # pylint: disable=protected-access
3445        logging.info("Searching within error strings: '%s' within '%s'",
3446                     expected_err_re_or_predicate, err_str)
3447        return re.search(expected_err_re_or_predicate, err_str)
3448
3449    try:
3450      yield
3451      self.fail(exception_type.__name__ + " not raised")
3452    except Exception as e:  # pylint: disable=broad-except
3453      if not isinstance(e, exception_type) or not predicate(e):
3454        raise AssertionError("Exception of type %s: %s" %
3455                             (str(type(e)), str(e)))
3456
3457  # pylint: enable=g-doc-return-or-yield
3458
3459  def assertRaisesOpError(self, expected_err_re_or_predicate):
3460    return self.assertRaisesWithPredicateMatch(errors.OpError,
3461                                               expected_err_re_or_predicate)
3462
3463  def assertRaisesIncompatibleShapesError(
3464      self, exception_type=errors.InvalidArgumentError):
3465    return self.assertRaisesWithPredicateMatch(
3466        exception_type, r"Incompatible shapes|Dimensions must be equal|"
3467        r"required broadcastable shapes")
3468
3469  def assertShapeEqual(self, input_a, input_b, msg=None):
3470    """Asserts that two Numpy or TensorFlow objects have the same shape.
3471
3472    For Tensors, this compares statically known shapes at compile time, not
3473    dynamic shapes at runtime.
3474
3475    Args:
3476      input_a: A Numpy ndarray, Numpy scalar, or a Tensor.
3477      input_b: A Numpy ndarray, Numpy scalar, or a Tensor.
3478      msg: Optional message to report on failure.
3479
3480    Raises:
3481      TypeError: If the arguments have the wrong type.
3482    """
3483    if not isinstance(input_a, (np.ndarray, np.generic, ops.Tensor)):
3484      raise TypeError(
3485          "input_a must be a Numpy ndarray, Numpy scalar, or a Tensor."
3486          f"Instead received {type(input_a)}")
3487    if not isinstance(input_b, (np.ndarray, np.generic, ops.Tensor)):
3488      raise TypeError(
3489          "input_b must be a Numpy ndarray, Numpy scalar, or a Tensor."
3490          f"Instead received {type(input_b)}")
3491    shape_a = input_a.get_shape().as_list() if isinstance(
3492        input_a, ops.Tensor) else input_a.shape
3493    shape_b = input_b.get_shape().as_list() if isinstance(
3494        input_b, ops.Tensor) else input_b.shape
3495    self.assertAllEqual(shape_a, shape_b, msg=msg)
3496
3497  def assertDeviceEqual(self, device1, device2, msg=None):
3498    """Asserts that the two given devices are the same.
3499
3500    Args:
3501      device1: A string device name or TensorFlow `DeviceSpec` object.
3502      device2: A string device name or TensorFlow `DeviceSpec` object.
3503      msg: Optional message to report on failure.
3504    """
3505    device1 = pydev.canonical_name(device1)
3506    device2 = pydev.canonical_name(device2)
3507    self.assertEqual(
3508        device1, device2,
3509        "Devices %s and %s are not equal. %s" % (device1, device2, msg))
3510
3511  @py_func_if_in_function
3512  def assertDictEqual(self, a, b, msg=None):
3513    """Assert that two given dictionary of tensors are the same.
3514
3515    Args:
3516      a: Expected dictionary with numpy ndarray or anything else that can be
3517        converted to one as values.
3518      b: Actual dictionary with numpy ndarray or anything else that can be
3519        converted to one as values.
3520      msg: Optional message to report on failure.
3521    """
3522    # To keep backwards compatibility, we first try the base class
3523    # assertDictEqual. If that fails we try the tensorflow one.
3524    try:
3525      super().assertDictEqual(a, b, msg)
3526    except Exception:  # pylint: disable=broad-except
3527      self.assertSameElements(a.keys(), b.keys())  # pylint: disable=g-assert-in-except
3528      for k, v in a.items():
3529        (a_k, b_k) = self.evaluate_if_both_tensors(v, b[k])
3530        a_k = self._GetNdArray(a_k)
3531        b_k = self._GetNdArray(b_k)
3532        if np.issubdtype(a_k.dtype, np.floating):
3533          self.assertAllClose(v, b[k], msg=k)
3534        else:
3535          self.assertAllEqual(v, b[k], msg=k)
3536
3537  def _GetPyList(self, a):
3538    """Converts `a` to a nested python list."""
3539    if isinstance(a, ragged_tensor.RaggedTensor):
3540      return self.evaluate(a).to_list()
3541    elif isinstance(a, ops.Tensor):
3542      a = self.evaluate(a)
3543      return a.tolist() if isinstance(a, np.ndarray) else a
3544    elif isinstance(a, np.ndarray):
3545      return a.tolist()
3546    elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
3547      return a.to_list()
3548    else:
3549      return np.array(a).tolist()
3550
3551  def _assertRaggedEqual(self, a, b, msg):
3552    """Asserts that two ragged tensors are equal."""
3553    a_list = self._GetPyList(a)
3554    b_list = self._GetPyList(b)
3555    self.assertEqual(a_list, b_list, msg)
3556
3557    if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
3558      a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
3559      b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
3560      self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
3561
3562  def _assertRaggedClose(self, a, b, rtol, atol, msg=None):
3563    a_list = self._GetPyList(a)
3564    b_list = self._GetPyList(b)
3565    self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg)
3566
3567    if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
3568      a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
3569      b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
3570      self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
3571
3572  def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"):
3573    self.assertEqual(type(a), type(b))
3574    if isinstance(a, (list, tuple)):
3575      self.assertLen(a, len(b), "Length differs for %s" % path)
3576      for i in range(len(a)):
3577        self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg,
3578                                       "%s[%s]" % (path, i))
3579    else:
3580      self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)
3581
  # Fix Python 3+ compatibility issues
  # pylint: disable=invalid-name

  # Silence a deprecation warning: Python 3 renamed assertRaisesRegexp to
  # assertRaisesRegex; keep the old name as an alias for existing tests.
  assertRaisesRegexp = googletest.TestCase.assertRaisesRegex

  # assertItemsEqual is assertCountEqual as of 3.2.
  # Alias kept so tests written against the old API keep working.
  assertItemsEqual = googletest.TestCase.assertCountEqual

  # pylint: enable=invalid-name
3592
3593  @contextlib.contextmanager
3594  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
3595    """Set the session and its graph to global default and constrain devices."""
3596    if context.executing_eagerly():
3597      yield None
3598    else:
3599      with sess.graph.as_default(), sess.as_default():
3600        if force_gpu:
3601          # Use the name of an actual device if one is detected, or
3602          # '/device:GPU:0' otherwise
3603          gpu_name = gpu_device_name()
3604          if not gpu_name:
3605            gpu_name = "/device:GPU:0"
3606          with sess.graph.device(gpu_name):
3607            yield sess
3608        elif use_gpu:
3609          yield sess
3610        else:
3611          with sess.graph.device("/device:CPU:0"):
3612            yield sess
3613
  def _create_session(self, graph, config, force_gpu):
    """See session() for details."""

    def prepare_config(config):
      """Returns a config for sessions.

      Args:
        config: An optional config_pb2.ConfigProto to use to configure the
          session.

      Returns:
        A config_pb2.ConfigProto object.
      """
      # TODO(b/114333779): Enforce allow_soft_placement=False when
      # use_gpu=False. Currently many tests rely on the fact that any device
      # will be used even when a specific device is supposed to be used.
      allow_soft_placement = not force_gpu
      if config is None:
        # No caller-supplied config: start from the eager context's config.
        config = context.context().config
        config.allow_soft_placement = allow_soft_placement
      elif not allow_soft_placement and config.allow_soft_placement:
        # NOTE(review): this branch replaces the caller-supplied `config`
        # wholesale with the context's config rather than copying the
        # caller's settings into a fresh proto — confirm that dropping the
        # caller's other options here is intended.
        config_copy = context.context().config
        config = config_copy
        config.allow_soft_placement = False
      # Don't perform optimizations for tests so we don't inadvertently run
      # gpu ops on cpu
      config.graph_options.optimizer_options.opt_level = -1
      # Disable Grappler constant folding since some tests & benchmarks
      # use constant input and become meaningless after constant folding.
      # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
      # GRAPPLER TEAM.
      config.graph_options.rewrite_options.constant_folding = (
          rewriter_config_pb2.RewriterConfig.OFF)
      config.graph_options.rewrite_options.pin_to_host_optimization = (
          rewriter_config_pb2.RewriterConfig.OFF)
      return config

    # ErrorLoggingSession logs failures; defined elsewhere in this file.
    return ErrorLoggingSession(graph=graph, config=prepare_config(config))
3652
3653  def _get_cached_session(self,
3654                          graph=None,
3655                          config=None,
3656                          force_gpu=False,
3657                          crash_if_inconsistent_args=True):
3658    """See cached_session() for documentation."""
3659    if self._cached_session is None:
3660      sess = self._create_session(
3661          graph=graph, config=config, force_gpu=force_gpu)
3662      self._cached_session = sess
3663      self._cached_graph = graph
3664      self._cached_config = config
3665      self._cached_force_gpu = force_gpu
3666      return sess
3667    else:
3668      if crash_if_inconsistent_args and self._cached_graph is not graph:
3669        raise ValueError("The graph used to get the cached session is "
3670                         "different than the one that was used to create the "
3671                         "session. Maybe create a new session with "
3672                         "self.session()")
3673      if crash_if_inconsistent_args and self._cached_config is not config:
3674        raise ValueError("The config used to get the cached session is "
3675                         "different than the one that was used to create the "
3676                         "session. Maybe create a new session with "
3677                         "self.session()")
3678      if crash_if_inconsistent_args and (self._cached_force_gpu is
3679                                         not force_gpu):
3680        raise ValueError(
3681            "The force_gpu value used to get the cached session is "
3682            "different than the one that was used to create the "
3683            "session. Maybe create a new session with "
3684            "self.session()")
3685      return self._cached_session
3686
3687
# Ports already handed out by pick_unused_port() in this process.
ASSIGNED_PORTS = set()
# Guards ASSIGNED_PORTS so concurrent pick_unused_port() calls don't race.
lock = threading.Lock()
3690
3691
def pick_unused_port():
  """Returns an unused and unassigned local port."""
  import portpicker  # pylint: disable=g-import-not-at-top

  global ASSIGNED_PORTS
  with lock:
    while True:
      try:
        candidate = portpicker.pick_unused_port()
      except portpicker.NoFreePortFoundError as porterror:
        raise unittest.SkipTest("Flakes in portpicker library do not represent"
                                " TensorFlow errors.") from porterror
      # Reject low-numbered ports and any port this process already handed
      # out; keep polling until a fresh high port is found.
      if candidate <= 10000 or candidate in ASSIGNED_PORTS:
        continue
      ASSIGNED_PORTS.add(candidate)
      logging.info("Using local port %r", candidate)
      return candidate
3708
3709
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  "PS" stands for "parameter server": a task responsible for storing and
  updating the model's parameters. Other tasks send updates to these parameters
  as they work on optimizing the parameters. This particular division of labor
  between tasks is not required, but is common for distributed training.

  Read more at https://www.tensorflow.org/guide/extend/architecture

  ![components](https://www.tensorflow.org/images/diag1.svg "components")


  Figure illustrates the interaction of these components.
  "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services.


  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.compat.v1.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol. Allowed values are documented in the
      documentation of `tf.distribute.Server`.
    worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be
      used to instantiate multiple devices etc.
    ps_config: (optional) `tf.ConfigProto` to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`.  `worker_servers` is a list
    of `num_workers` objects of type `tf.distribute.Server` (all running
    locally);
    and `ps_servers` is a list of `num_ps` objects of similar type.

  Raises:
    ImportError: if portpicker module was not found at load time
  """
  # Reserve distinct localhost ports for every task in both jobs (workers
  # first, matching the cluster spec ordering below).
  worker_addresses = [
      "localhost:%s" % pick_unused_port() for _ in range(num_workers)
  ]
  ps_addresses = ["localhost:%s" % pick_unused_port() for _ in range(num_ps)]
  cs = server_lib.ClusterSpec({
      "worker": worker_addresses,
      "ps": ps_addresses,
  })

  def _launch_job(job_name, task_count, job_config):
    # Starts one in-process server per task index for the given job.
    return [
        server_lib.Server(
            cs,
            job_name=job_name,
            protocol=protocol,
            task_index=task_index,
            config=job_config,
            start=True) for task_index in range(task_count)
    ]

  workers = _launch_job("worker", num_workers, worker_config)
  ps_servers = _launch_job("ps", num_ps, ps_config)

  return workers, ps_servers
3796
3797
def get_node_def_from_graph(node_name, graph_def):
  """Returns the `NodeDef` instance for given node name in the graph def.

  This method explores only the NodeDefs in `graph_def.node`.

  Args:
    node_name: Name of the NodeDef to search for.
    graph_def: An instance of `GraphDef` proto.

  Returns:
    the `NodeDef` instance whose name field matches the given node_name or None.
  """
  # First match wins; None when no node carries the requested name.
  return next(
      (node_def for node_def in graph_def.node if node_def.name == node_name),
      None)
3814
3815
def set_producer_version(graph, producer_version):
  """Sets graph.graph_def_versions.producer to `producer_version`."""
  # The C API doesn't expose altering GraphDefVersions. We can indirectly set
  # it via import_graph_def though.
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.producer = producer_version
  with graph.as_default():
    importer.import_graph_def(graph_def)
  # BUGFIX: the original read `assert x, producer_version`, which only checks
  # that the producer is truthy and uses `producer_version` as the failure
  # message. Compare for equality, which is the evident intent.
  assert graph.graph_def_versions.producer == producer_version
3825
3826
3827@contextlib.contextmanager
3828def _fake_gradient_tape_context_manager():
3829  """tf.gradients(...) implemented as tf.GradientTape context manager interface.
3830
3831  This is useful to test tf.gradients() in tests that uses tf.GradientTape().
3832
3833  Yields:
3834    gradient tape instance that's implemented by tf.gradients() underneath.
3835  """
3836  try:
3837    class FakeGradientTape:
3838
3839      def watch(self, x):
3840        pass
3841
3842      def gradient(self, y, x, grad_ys=None):
3843        result = gradients_impl.gradients(y, x, grad_ys)
3844
3845        # Unlike `tape.gradient()`, `tf.gradients()` returns a list for a single
3846        # element. So unpack if needed to match `tape.gradient()` behavior.
3847        if not isinstance(x, (list, tuple)):
3848          assert len(result) == 1
3849          return result[0]
3850
3851        return result
3852
3853    yield FakeGradientTape()
3854  finally:
3855    pass
3856
3857
class AbstractGradientTape:
  """Abstract GradientTape context manager that has multiple implementations.

  This is useful to test both tf.GradientTape() and tf.gradients() without
  duplicating tests.
  """

  def __init__(self, use_tape, persistent=False):
    # Remember which backend to construct lazily in __enter__.
    self._use_tape = use_tape
    self._persistent = persistent

  def __enter__(self):
    if self._use_tape:
      impl = backprop.GradientTape(persistent=self._persistent)
    else:
      impl = _fake_gradient_tape_context_manager()
    self._tape_impl = impl
    return impl.__enter__()

  def __exit__(self, exc_type, exc_val, exc_tb):
    self._tape_impl.__exit__(exc_type, exc_val, exc_tb)
3878
3879
@contextlib.contextmanager
def run_functions_eagerly(run_eagerly):
  """Runs functions eagerly if `run_eagerly` is true.

  WARNING: Setting `run_eagerly` to True in tests running in V1 graph mode
  *WILL NOT* make the tf.function to run eagerly because eager is disabled by
  default in V1. Instead, tf.function will run as a traced graph function.

  Ensures that the state (for running functions eagerly) is back to the initial
  `def_function.RUN_FUNCTIONS_EAGERLY` state.

  Args:
    run_eagerly: Boolean determining whether to run the function eagerly or not.

  Raises:
    ValueError if `run_eagerly` is not a boolean.

  Yields:
    Nothing.
  """
  if not isinstance(run_eagerly, bool):
    raise ValueError(
        "Expected bool for `run_eagerly` but got {}".format(run_eagerly))

  # Eager-mode requests are meaningless in V1 graph mode; warn but proceed.
  if run_eagerly and not context.executing_eagerly():
    logging.warning(
        "Running tf.function eagerly in V1 graph mode is not supported. "
        "tf.function will be run as a traced graph function.")

  previous_state = def_function.functions_run_eagerly()
  def_function.run_functions_eagerly(run_eagerly)
  try:
    yield
  finally:
    # Always restore the global flag, even if the body raised.
    def_function.run_functions_eagerly(previous_state)
3916
3917
class TestDelta:
  """A utility class to track increments to test counters."""

  def __init__(self, name, label):
    # Counter identity, forwarded to _test_metrics_util on every read.
    self.name = name
    self.label = label
    self.Reset()

  def Reset(self):
    """Snapshots the current counter value as the new baseline."""
    self.last_value = _test_metrics_util.test_counter_value(
        self.name, self.label)

  def Get(self):
    """Returns the counter's increase since the last `Reset()`."""
    current = _test_metrics_util.test_counter_value(self.name, self.label)
    return current - self.last_value
3933