• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Test utils for tensorflow."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23from collections import OrderedDict
24import contextlib
25import functools
26import gc
27import itertools
28import math
29import os
30import random
31import re
32import tempfile
33import threading
34import time
35import unittest
36
37from absl.testing import parameterized
38import numpy as np
39import six
40
41from google.protobuf import descriptor_pool
42from google.protobuf import text_format
43
44from tensorflow.core.framework import graph_pb2
45from tensorflow.core.protobuf import rewriter_config_pb2
46from tensorflow.python import pywrap_sanitizers
47from tensorflow.python import tf2
48from tensorflow.python.client import device_lib
49from tensorflow.python.client import pywrap_tf_session
50from tensorflow.python.client import session
51from tensorflow.python.compat.compat import forward_compatibility_horizon
52from tensorflow.python.eager import backprop
53from tensorflow.python.eager import context
54from tensorflow.python.eager import def_function
55from tensorflow.python.eager import tape
56from tensorflow.python.framework import config
57from tensorflow.python.framework import device as pydev
58from tensorflow.python.framework import dtypes
59from tensorflow.python.framework import errors
60from tensorflow.python.framework import errors_impl
61from tensorflow.python.framework import gpu_util
62from tensorflow.python.framework import importer
63from tensorflow.python.framework import ops
64from tensorflow.python.framework import random_seed
65from tensorflow.python.framework import sparse_tensor
66from tensorflow.python.framework import tensor_shape
67from tensorflow.python.framework import tensor_util
68from tensorflow.python.framework import tfrt_utils
69from tensorflow.python.framework import versions
70from tensorflow.python.ops import array_ops
71from tensorflow.python.ops import control_flow_util
72from tensorflow.python.ops import control_flow_util_v2
73from tensorflow.python.ops import gradients_impl
74from tensorflow.python.ops import math_ops
75from tensorflow.python.ops import script_ops
76from tensorflow.python.ops import summary_ops_v2
77from tensorflow.python.ops import variables
78from tensorflow.python.ops.ragged import ragged_ops  # pylint: disable=unused-import
79from tensorflow.python.ops.ragged import ragged_tensor
80from tensorflow.python.ops.ragged import ragged_tensor_value
81from tensorflow.python.platform import _pywrap_stacktrace_handler
82from tensorflow.python.platform import googletest
83from tensorflow.python.platform import tf_logging as logging
84from tensorflow.python.training import server_lib
85from tensorflow.python.util import _pywrap_util_port
86from tensorflow.python.util import compat
87from tensorflow.python.util import deprecation
88from tensorflow.python.util import nest
89from tensorflow.python.util import tf_decorator
90from tensorflow.python.util import tf_inspect
91from tensorflow.python.util import traceback_utils
92from tensorflow.python.util.compat import collections_abc
93from tensorflow.python.util.protobuf import compare
94from tensorflow.python.util.tf_export import tf_export
95
96
# Default answer for "is XLA on?". When the BUILD rule makes
# `tensorflow.python.framework.is_xla_test_true` importable, the guarded import
# that follows replaces this function with one returning True, which makes
# TensorFlow graphs compile with XLA.
def is_xla_enabled():
  """Reports that XLA compilation is disabled (the non-XLA default)."""
  xla_enabled = False
  return xla_enabled
102
103
# Best-effort override of `is_xla_enabled`; any failure to import the XLA
# variant (not just ImportError) keeps the default defined above.
with contextlib.suppress(Exception):
  from tensorflow.python.framework.is_xla_test_true import is_xla_enabled  # pylint: disable=g-import-not-at-top, unused-import
108
109
# Same override mechanism as `is_xla_enabled`, here for MLIR bridge
# compilation: a build-provided module may replace this default.
def is_mlir_bridge_enabled():
  """Default MLIR-bridge setting: None, i.e. no explicit preference."""
  mlir_bridge_setting = None
  return mlir_bridge_setting
114
115
# Prefer the "false" override, then fall back to the "true" one; if neither
# module can be imported, the default defined above stays in place.
try:
  from tensorflow.python.framework.is_mlir_bridge_test_false import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
except ImportError:
  with contextlib.suppress(ImportError):
    from tensorflow.python.framework.is_mlir_bridge_test_true import is_mlir_bridge_enabled  # pylint: disable=g-import-not-at-top, unused-import
123
124
def is_asan_enabled():
  """Returns whether ASAN is enabled, as reported by `pywrap_sanitizers`."""
  asan_on = pywrap_sanitizers.is_asan_enabled()
  return asan_on
128
129
def is_msan_enabled():
  """Returns whether MSAN is enabled, as reported by `pywrap_sanitizers`."""
  msan_on = pywrap_sanitizers.is_msan_enabled()
  return msan_on
133
134
def is_tsan_enabled():
  """Returns whether TSAN is enabled, as reported by `pywrap_sanitizers`."""
  tsan_on = pywrap_sanitizers.is_tsan_enabled()
  return tsan_on
138
139
def is_ubsan_enabled():
  """Returns whether UBSAN is enabled, as reported by `pywrap_sanitizers`."""
  ubsan_on = pywrap_sanitizers.is_ubsan_enabled()
  return ubsan_on
143
144
145def _get_object_count_by_type(exclude=()):
146  return (
147      collections.Counter([type(obj).__name__ for obj in gc.get_objects()]) -
148      collections.Counter([type(obj).__name__ for obj in exclude]))
149
150
@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of a GPU device if available or an empty string.

  This method should only be used in tests written with `tf.test.TestCase`.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_gpu_support():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device(tf.test.gpu_device_name()):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  Returns:
    The name (as `str`) of the first local device whose `device_type` is
    "GPU", or "" if there is no such device.
  """
  for device in device_lib.list_local_devices():
    if device.device_type == "GPU":
      return compat.as_str(device.name)
  return ""
171
172
def assert_ops_in_graph(expected_ops, graph):
  """Checks that every expected node exists in `graph` with the expected op.

  Args:
    expected_ops: `dict<string, string>` mapping node name to op type.
    graph: Graph to check; only its `as_graph_def().node` list is read.

  Returns:
    `dict<string, node>` mapping each expected node name to its `NodeDef`.

  Raises:
    ValueError: If an expected node has a different op type (reported for the
      first such node in graph order), or if some expected node is missing.
  """
  found = {}
  for node in graph.as_graph_def().node:
    if node.name not in expected_ops:
      continue
    wanted_op = expected_ops[node.name]
    if wanted_op != node.op:
      raise ValueError("Expected op for node %s is different. %s vs %s" %
                       (node.name, wanted_op, node.op))
    found[node.name] = node

  if set(found.keys()) != set(expected_ops.keys()):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), found.keys()))
  return found
198
199
@tf_export("test.assert_equal_graph_def", v1=[])
def assert_equal_graph_def_v2(expected, actual):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent. This function
  ignores randomized attribute values that may appear in V2 checkpoints, as
  well as randomized `shared_name`s of HashTableV2 ops; those attributes are
  removed from both protos in place before comparing.

  Args:
    expected: The `GraphDef` we expected.
    actual: The `GraphDef` we have.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(
      actual=actual,
      expected=expected,
      checkpoint_v2=True,
      hash_table_shared_name=True)
219
220
@tf_export(v1=["test.assert_equal_graph_def"])
def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False,
                              hash_table_shared_name=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs.  Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  When either flag below is set, the ignored attributes are deleted from both
  `actual` and `expected` in place before the comparison.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
      values that appear in V2 checkpoints.
    hash_table_shared_name: boolean determining whether to ignore randomized
      shared_names that appear in HashTableV2 op defs.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(
      actual,
      expected,
      checkpoint_v2=checkpoint_v2,
      hash_table_shared_name=hash_table_shared_name)
244
245
def assert_equal_graph_def(actual, expected, checkpoint_v2=False,
                           hash_table_shared_name=False):
  """Raises AssertionError unless two `GraphDef`s compare equal.

  The comparison is delegated to `pywrap_tf_session.EqualGraphDefWrapper`; a
  non-empty diff returned by it becomes the AssertionError message.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: If True, first delete (in place, from both protos) every
      attr whose single string value matches `_SHARDED_SAVE_OP_PATTERN`.
    hash_table_shared_name: If True, first delete (in place, from both protos)
      HashTableV2 `shared_name` attrs matching `_TABLE_SHARED_NAME_PATTERN`.

  Raises:
    TypeError: If either argument is not a `GraphDef`.
    AssertionError: If the graphs differ.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  both_graph_defs = (actual, expected)
  if checkpoint_v2:
    for graph_def in both_graph_defs:
      _strip_checkpoint_v2_randomized(graph_def)
  if hash_table_shared_name:
    for graph_def in both_graph_defs:
      _strip_hash_table_shared_name(graph_def)

  actual_bytes = actual.SerializeToString()
  expected_bytes = expected.SerializeToString()
  diff = pywrap_tf_session.EqualGraphDefWrapper(actual_bytes, expected_bytes)
  if diff:
    raise AssertionError(compat.as_str(diff))
267
268
def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  The comparison is done in three steps:
    1. `collection_def`s: the key sets must match, and each collection is
       compared entry by entry. Entries are parsed into the collection's
       registered proto type when `ops.get_collection_proto_type` returns
       one; otherwise the raw values are compared.
    2. `graph_def`s: compared with `assert_equal_graph_def` (ignoring
       checkpoint-V2 randomized attrs), plus an explicit comparison of their
       `versions`, which `assert_equal_graph_def` ignores.
    3. All remaining fields: compared with `tester.assertProtoEquals`.

  Note: `a` and `b` are modified in place. Their `collection_def` and
  `graph_def` fields are cleared once compared.

  Args:
    tester: Test case providing `assertEqual` and `assertProtoEquals`
      (e.g. a `tf.test.TestCase`).
    a: The first `MetaGraphDef`.
    b: The second `MetaGraphDef`.

  Raises:
    AssertionError: If the protos differ, via `tester`'s assertions or
      `assert_equal_graph_def`.
  """
  # Carefully check the collection_defs.
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  for key in a.collection_def.keys():
    a_value = a.collection_def[key]
    b_value = b.collection_def[key]
    proto_type = ops.get_collection_proto_type(key)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same.
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for a_value_item, b_value_item in zip(a_value.bytes_list.value,
                                            b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # `assertEquals` is a deprecated alias that was removed from
      # unittest.TestCase in Python 3.12; use the canonical name.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)
306
307
# Pattern for attribute values produced via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py.
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"


def _strip_checkpoint_v2_randomized(graph_def):
  """Deletes, in place, attrs holding a randomized sharded-save path.

  An attr is removed when its tensor value has exactly one non-empty string
  element that matches `_SHARDED_SAVE_OP_PATTERN` at its start (`re.match`).

  Args:
    graph_def: The `GraphDef` to modify.
  """
  for node in graph_def.node:
    doomed_keys = []
    for key in node.attr:
      tensor = node.attr[key].tensor
      if not tensor or len(tensor.string_val) != 1:
        continue
      only_value = tensor.string_val[0]
      if only_value and re.match(
          compat.as_bytes(_SHARDED_SAVE_OP_PATTERN), only_value):
        doomed_keys.append(key)
    # Delete after the scan so the attr map is not mutated while iterated.
    for key in doomed_keys:
      del node.attr[key]
326
327
_TABLE_SHARED_NAME_PATTERN = r"hash_table_[0-9a-z\-]+"


def _strip_hash_table_shared_name(graph_def):
  """Deletes randomized `shared_name` attrs from HashTableV2 nodes, in place.

  Only `shared_name` values matching `_TABLE_SHARED_NAME_PATTERN` at their
  start (`re.match`) are removed; other nodes and attrs are left untouched.

  Args:
    graph_def: The `GraphDef` to modify.
  """
  for node in graph_def.node:
    if node.op != "HashTableV2" or "shared_name" not in node.attr:
      continue
    shared_name = node.attr["shared_name"].s
    if re.match(compat.as_bytes(_TABLE_SHARED_NAME_PATTERN), shared_name):
      del node.attr["shared_name"]
340
341
def IsGoogleCudaEnabled():
  """Returns whether Google CUDA is enabled, as reported by `_pywrap_util_port`."""
  cuda_enabled = _pywrap_util_port.IsGoogleCudaEnabled()
  return cuda_enabled
344
345
def IsBuiltWithROCm():
  """Returns whether the build uses ROCm, as reported by `_pywrap_util_port`."""
  built_with_rocm = _pywrap_util_port.IsBuiltWithROCm()
  return built_with_rocm
348
349
def IsBuiltWithXLA():
  """Returns whether the build includes XLA, as reported by `_pywrap_util_port`."""
  built_with_xla = _pywrap_util_port.IsBuiltWithXLA()
  return built_with_xla
352
353
def IsBuiltWithNvcc():
  """Returns whether the build used nvcc, as reported by `_pywrap_util_port`."""
  built_with_nvcc = _pywrap_util_port.IsBuiltWithNvcc()
  return built_with_nvcc
356
357
def GpuSupportsHalfMatMulAndConv():
  """Returns `_pywrap_util_port.GpuSupportsHalfMatMulAndConv()` unchanged."""
  half_supported = _pywrap_util_port.GpuSupportsHalfMatMulAndConv()
  return half_supported
360
361
def IsMklEnabled():
  """Returns whether MKL/oneDNN is enabled.

  The answer is `_pywrap_util_port.IsMklEnabled()` when that is truthy.
  Otherwise it is whether the TF_ENABLE_ONEDNN_OPTS environment variable
  (default "False") is "true" or "1", compared case-insensitively.
  """
  mkl_from_build = _pywrap_util_port.IsMklEnabled()
  if mkl_from_build:
    return mkl_from_build
  onednn_opts = os.getenv("TF_ENABLE_ONEDNN_OPTS", "False")
  return onednn_opts.lower() in ("true", "1")
365
366
def InstallStackTraceHandler():
  """Installs the stack-trace handler from `_pywrap_stacktrace_handler`."""
  install_handler = _pywrap_stacktrace_handler.InstallStacktraceHandler
  install_handler()
369
370
def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 3-, 4-, or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array

  Raises:
    KeyError: if the rank is not 3, 4 or 5.
  """
  # Rank -> axis order moving the trailing channel axis to position 1.
  axis_order_by_rank = {3: [0, 2, 1], 4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if not isinstance(input_tensor, ops.Tensor):
    shape_order = axis_order_by_rank[len(input_tensor)]
    return [input_tensor[axis] for axis in shape_order]
  rank = input_tensor.shape.ndims
  return array_ops.transpose(input_tensor, axis_order_by_rank[rank])
388
389
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or a sequence (list, tuple, ...)
      representing a shape. A shape sequence passed in is not modified.

  Returns:
    tensor or shape list transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
        divisible by 4.
  """
  # Keyed by rank *after* splitting the channel dim C into (C // 4, 4).
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  # Always work on a fresh list. Without the copy, the `//=` and `append`
  # below would mutate a caller-supplied shape list in place (and would fail
  # outright on tuples).
  temp_shape = list(
      input_shape_or_tensor.shape.as_list()
      if is_tensor else input_shape_or_tensor)
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  else:
    return [temp_shape[a] for a in permutation]
423
424
def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    input_shape = input_shape_or_tensor.shape.as_list()
  else:
    input_shape = input_shape_or_tensor
  if input_shape[-1] != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  permutation = permutations[len(input_shape)]
  # Permute all but the trailing vector axis, then fold the vector width into
  # the (now last) channel dimension.
  nhwc_shape = [input_shape[axis] for axis in permutation[:-1]]
  nhwc_shape[-1] *= input_shape[-1]
  if not is_tensor:
    return nhwc_shape
  transposed = array_ops.transpose(input_shape_or_tensor, permutation)
  return array_ops.reshape(transposed, nhwc_shape)
455
456
def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D tensor, or an array representing shape

  Returns:
    converted tensor or shape array

  Raises:
    KeyError: if the rank is not 4 or 5.
  """
  # Rank -> axis order moving the channel axis (position 1) to the end.
  axis_order_by_rank = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if not isinstance(input_tensor, ops.Tensor):
    shape_order = axis_order_by_rank[len(input_tensor)]
    return [input_tensor[axis] for axis in shape_order]
  rank = input_tensor.shape.ndims
  return array_ops.transpose(input_tensor, axis_order_by_rank[rank])
474
475
def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  The condition is evaluated on every call of the decorated function. When it
  is truthy, the function body is not run and the call returns None. On a
  unittest test method this means the test is reported as passing, not as
  skipped.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean.

  Returns:
    A decorator. The function it returns keeps the wrapped function's name and
    docstring (via `functools.wraps`), so test names and docs stay readable.
  """

  def real_skip_if(fn):

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      skip = condition() if callable(condition) else condition
      if not skip:
        return fn(*args, **kwargs)
      return None

    return wrapper

  return real_skip_if
500
501
@contextlib.contextmanager
def skip_if_error(test_obj, error_type, messages=None):
  """Context manager that turns a matching error into a skipped test.

  This does not work in setUpClass/tearDownClass; setUp/tearDown and regular
  test methods are fine.

  Args:
    test_obj: The test object (`self` in a test method), typically a
      `unittest.TestCase` subclass instance; it must have `skipTest`.
    error_type: Exception type to catch. If `messages` is also given, both the
      type and a message must match for the test to be skipped.
    messages: Optional string or list of strings. If falsy, any `error_type`
      error skips the test. Otherwise the test is skipped only if one of the
      strings is a substring of `str(error)`, and the error is re-raised if
      none is.

  Yields:
    Nothing.
  """
  if messages:
    messages = nest.flatten(messages)
  try:
    yield
  except error_type as e:
    if messages and not any(message in str(e) for message in messages):
      raise
    test_obj.skipTest("Skipping error: {}: {}".format(type(e), str(e)))
532
533
def enable_c_shapes(fn):
  """Deprecated no-op decorator: hands `fn` back unchanged.

  TODO(b/74620627): Remove this.
  """
  unchanged = fn
  return unchanged
537
538
def with_c_shapes(cls):
  """Deprecated no-op class decorator: hands `cls` back unchanged.

  TODO(b/74620627): Remove this.
  """
  unchanged = cls
  return unchanged
542
543
def enable_control_flow_v2(fn):
  """Wraps `fn` so that it runs with CondV2 and WhileV2 enabled.

  `control_flow_util.ENABLE_CONTROL_FLOW_V2` is set to True only for the
  duration of the call. Its previous value is restored afterwards, even if
  `fn` raises. Applied to a test method, this means the test class's
  setup/teardown run before/after the flag is switched on.

  Callers must also import the while_v2 module so that the _while_v2 module is
  set in control_flow_ops.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    previous_flag = control_flow_util.ENABLE_CONTROL_FLOW_V2
    try:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
      return fn(*args, **kwargs)
    finally:
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_flag

  return wrapper
569
570
def with_control_flow_v2(cls):
  """Adds a WithControlFlowV2 variant of each test method of `cls`.

  For every callable attribute defined directly on `cls` whose name starts with
  `unittest.TestLoader.testMethodPrefix`, an attribute named
  `<name>WithControlFlowV2` is added. It calls the original method wrapped by
  `enable_control_flow_v2`, i.e. with WhileV2 and CondV2 enabled. Because the
  flag is switched only around the method call, the class's setup runs first.

  Callers must also import the while_v2 module so that the _while_v2 module is
  set in control_flow_ops.

  Methods whose `_disable_control_flow_v2` attribute is True (set by the
  @disable_control_flow_v2 decorator) get no v2 variant. If
  `control_flow_util.ENABLE_CONTROL_FLOW_V2` is already True, `cls` is returned
  untouched.

  Example:

  @test_util.with_control_flow_v2
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    @test_util.disable_control_flow_v2("b/xyzabc")
    def testDisabledForV2(self):
      ...

  Generated class:
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    def testEnabledForV2WithControlFlowV2(self):
      # Enable V2 flags.
      testEnabledForV2(self)
      # Restore V2 flags.

    def testDisabledForV2(self):
      ...

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    return cls

  test_prefix = unittest.TestLoader.testMethodPrefix
  # Snapshot the class namespace first, since setattr below adds entries.
  methods_to_wrap = [
      (name, method)
      for name, method in list(cls.__dict__.items())
      if callable(method) and name.startswith(test_prefix) and
      not getattr(method, "_disable_control_flow_v2", False)
  ]
  for name, method in methods_to_wrap:
    setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(method))
  return cls
624
625
def disable_control_flow_v2(unused_msg):
  """Marks a test method so `with_control_flow_v2` generates no v2 variant.

  Args:
    unused_msg: Reason for disabling (documentation only; not stored).

  Returns:
    A decorator that sets `_disable_control_flow_v2 = True` on the function and
    returns that same function object.
  """

  def wrapper(func):
    setattr(func, "_disable_control_flow_v2", True)
    return func

  return wrapper
643
644
def enable_output_all_intermediates(fn):
  """Force-enables outputting all intermediates from functional control flow ops.

  While the wrapped function runs,
  `control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE` is True.
  Its previous value is restored afterwards, even if `fn` raises.

  Args:
    fn: the function to be wrapped

  Returns:
    The wrapped function
  """

  def wrapper(*args, **kwargs):
    previous_override = (
        control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE)
    try:
      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = True
      return fn(*args, **kwargs)
    finally:
      control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE = (
          previous_override)

  return wrapper
666
667
def _test_outcome_errors_and_skipped(test_case):
  """Returns the `(errors, skipped)` lists unittest records for `test_case`.

  Before Python 3.11 these lists live on `unittest.case._Outcome`. Python 3.11
  removed them from `_Outcome`, so they are read from the `TestResult` held
  in `_Outcome.result` instead.

  Args:
    test_case: A `unittest.TestCase` that is currently running, so that its
      `_outcome` attribute is set.

  Returns:
    A tuple `(errors, skipped)` of the lists unittest appends to.
  """
  outcome = test_case._outcome  # pylint: disable=protected-access
  if hasattr(outcome, "errors"):
    return outcome.errors, outcome.skipped
  return outcome.result.errors, outcome.result.skipped


def assert_no_new_pyobjects_executing_eagerly(func=None, warmup_iters=2):
  """Decorator for asserting that no new Python objects persist after a test.

  Runs the test multiple times executing eagerly, first as a warmup and then to
  let objects accumulate. The warmup helps ignore caches which do not grow as
  the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.

  The garbage collector is disabled while the check runs. It is re-enabled
  afterwards, also when the test or one of the checks fails.

  Usable bare (`@assert_no_new_pyobjects_executing_eagerly`) or with arguments
  (`@assert_no_new_pyobjects_executing_eagerly(warmup_iters=3)`).

  The decorated test raises AssertionError if:
    - a graph collection grew,
    - new functions remain registered in the eager context, or
    - new Python objects persist.

  Args:
    func: The function to test.
    warmup_iters: The number of warmup iterations, excluded from measuring.

  Returns:
    The wrapped function performing the test.
  """

  def wrap_f(f):

    @functools.wraps(f)
    def decorator(self, *args, **kwargs):
      """Warms up, gets object counts, runs the test, checks for new objects."""
      with context.eager_mode():
        gc.disable()
        # Always re-enable the collector. Otherwise a failing test or leak
        # check would leave gc disabled for every later test in the process.
        try:
          # Run the test `warmup_iters` times as warmup, in an attempt to fill
          # up caches, which should not grow as the test is run repeatedly
          # below.
          #
          # TODO(b/117156879): Running warmup twice is black magic; we have
          # seen tests that fail with 1 warmup run, and pass with 2, on various
          # versions of python2.7.x.
          for _ in range(warmup_iters):
            f(self, *args, **kwargs)
          # Since we aren't in the normal test lifecycle, we need to manually
          # run cleanups to clear out their object references.
          self.doCleanups()

          # Some objects are newly created by _get_object_count_by_type().  So
          # create and save as a dummy variable to include it as a baseline.
          obj_count_by_type = _get_object_count_by_type()
          gc.collect()

          # Make sure any registered functions are cleaned up in the C++
          # runtime.
          registered_function_names = context.context().list_function_names()

          # unittest.doCleanups adds to the test's error/skip records with each
          # unwound call. These objects are retained across gc collections so
          # we exclude them from the object count calculation. The lists are
          # fetched once; unittest appends to them in place.
          test_errors, test_skipped = _test_outcome_errors_and_skipped(self)
          obj_count_by_type = _get_object_count_by_type(
              exclude=gc.get_referents(test_errors, test_skipped))

          if ops.has_default_graph():
            collection_sizes_before = {
                collection: len(ops.get_collection(collection))
                for collection in ops.get_default_graph().collections
            }
          for _ in range(3):
            f(self, *args, **kwargs)
          # Since we aren't in the normal test lifecycle, we need to manually
          # run cleanups to clear out their object references.
          self.doCleanups()
          # Note that gc.get_objects misses anything that isn't subject to
          # garbage collection (C types). Collections are a common source of
          # leaks, so we test for collection sizes explicitly.
          if ops.has_default_graph():
            for collection_key in ops.get_default_graph().collections:
              collection = ops.get_collection(collection_key)
              size_before = collection_sizes_before.get(collection_key, 0)
              if len(collection) > size_before:
                raise AssertionError(
                    ("Collection %s increased in size from "
                     "%d to %d (current items %s).") %
                    (collection_key, size_before, len(collection), collection))
              # Make sure our collection checks don't show up as leaked memory
              # by removing references to temporary variables.
              del collection
              del collection_key
              del size_before
            del collection_sizes_before
          gc.collect()

          # There should be no new Python objects hanging around.
          obj_count_by_type = (
              _get_object_count_by_type(
                  exclude=gc.get_referents(test_errors, test_skipped)) -
              obj_count_by_type)

          # There should be no newly registered functions hanging around.
          leftover_functions = (
              context.context().list_function_names() -
              registered_function_names)
          assert not leftover_functions, (
              "The following functions were newly created: %s" %
              leftover_functions)

          # In some cases (specifically on MacOS), new_count is somehow
          # smaller than previous_count.
          # Using plain assert because not all classes using this decorator
          # have assertLessEqual
          assert not obj_count_by_type, (
              "The following objects were newly created: %s" %
              str(obj_count_by_type))
        finally:
          gc.enable()

    return decorator

  if func is None:
    return wrap_f
  else:
    return wrap_f(func)
776
777
def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Clears the caches that it knows about, runs the garbage collector, then checks
  that there are no Tensor or Tensor-like objects still around. This includes
  Tensors to which something still has a reference (e.g. from missing
  Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
  of the objects has __del__ defined).

  "Tensor-like" means instances of `ops.Tensor`, `variables.Variable`,
  `tensor_shape.Dimension` or `tensor_shape.TensorShape`.

  Args:
    f: The test case to run. Positional and keyword arguments given to the
      decorated test (e.g. by absl `parameterized`) are forwarded to `f`.

  Returns:
    The decorated test case. It returns whatever `f` returns and raises
    AssertionError if new Tensor-like objects survive the test.
  """

  @functools.wraps(f)
  def decorator(self, *args, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""

    def _is_tensorflow_object(obj):
      try:
        return isinstance(obj,
                          (ops.Tensor, variables.Variable,
                           tensor_shape.Dimension, tensor_shape.TensorShape))
      except (ReferenceError, AttributeError):
        # If the object no longer exists, we don't care about it.
        return False

    tensors_before = set(
        id(obj) for obj in gc.get_objects() if _is_tensorflow_object(obj))
    outside_executed_eagerly = context.executing_eagerly()
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    outside_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = outside_graph_key  # pylint: disable=protected-access
      if outside_executed_eagerly:
        with context.eager_mode():
          result = f(self, *args, **kwargs)
      else:
        result = f(self, *args, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    tensors_after = [
        obj for obj in gc.get_objects()
        if _is_tensorflow_object(obj) and id(obj) not in tensors_before
    ]
    if tensors_after:
      raise AssertionError(("%d Tensors not deallocated after test: %s" % (
          len(tensors_after),
          str(tensors_after),
      )))
    return result

  return decorator
838
839
def _find_reference_cycle(objects, idx):
  """Logs a sample reference cycle among the referrers of `objects[idx]`.

  Diagnostic helper for `assert_no_garbage_created`. It recursively walks
  `gc.get_referrers` starting from `objects[idx]`, records the resulting
  referrer -> referent graph by object ID, and logs the first cycle it finds
  with `logging.error`.

  Args:
    objects: A list of objects (e.g. `gc.garbage`).
    idx: Index into `objects` of the object to start from.

  Returns:
    True if a cycle was found and logged, False otherwise.
  """

  def get_ignore_reason(obj, denylist):
    """Tests whether an object should be omitted from the dependency graph.

    Args:
      obj: The candidate object.
      denylist: Tuple of objects to ignore. `build_ref_graph` appends one entry
        per level of recursion, so its length also acts as a depth bound.

    Returns:
      A short string explaining why `obj` is ignored, or None to keep it.
    """
    if len(denylist) > 100:
      return "<depth limit>"
    if tf_inspect.isframe(obj):
      # Frames whose file name contains "test_util.py" (presumably this
      # module's own helper frames) are not interesting referrers.
      if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
        return "<test code>"
    for b in denylist:
      if b is obj:
        return "<test code>"
    if obj is denylist:
      return "<test code>"
    return None

  # Note: this function is meant to help with diagnostics. Its output is purely
  # a human-readable representation, so you may freely modify it to suit your
  # needs.
  def describe(obj, denylist, leaves_only=False):
    """Returns a custom human-readable summary of obj.

    Args:
      obj: the value to describe.
      denylist: same as denylist in get_ignore_reason.
      leaves_only: boolean flag used when calling describe recursively. Useful
        for summarizing collections.

    Returns:
      A string describing `obj`.
    """
    # Computed once: get_ignore_reason may read source via getframeinfo.
    ignore_reason = get_ignore_reason(obj, denylist)
    if ignore_reason:
      return "{}{}".format(ignore_reason, type(obj))
    if tf_inspect.isframe(obj):
      return "frame: {}".format(tf_inspect.getframeinfo(obj))
    elif tf_inspect.ismodule(obj):
      return "module: {}".format(obj.__name__)
    else:
      if leaves_only:
        return "{}, {}".format(type(obj), id(obj))
      elif isinstance(obj, list):
        return "list({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, tuple):
        return "tuple({}): {}".format(
            id(obj), [describe(e, denylist, leaves_only=True) for e in obj])
      elif isinstance(obj, dict):
        return "dict({}): {} keys".format(id(obj), len(obj.keys()))
      elif tf_inspect.isfunction(obj):
        return "function({}) {}; globals ID: {}".format(
            id(obj), obj.__name__, id(obj.__globals__))
      else:
        return "{}, {}".format(type(obj), id(obj))

  def build_ref_graph(obj, graph, reprs, denylist):
    """Builds a reference graph as <referrer> -> <list of referents>.

    Args:
      obj: The object to start from. The graph will be built by recursively
        adding its referrers.
      graph: Dict holding the graph to be built. To avoid creating extra
        references, the graph holds object IDs rather than actual objects.
      reprs: Auxiliary structure that maps object IDs to their human-readable
        description.
      denylist: List of objects to ignore.
    """
    referrers = gc.get_referrers(obj)
    # The `referrers` list itself holds references to every `r` below, so it
    # would otherwise show up as a referrer at the next recursion level.
    denylist = denylist + (referrers,)

    obj_id = id(obj)
    for r in referrers:
      if get_ignore_reason(r, denylist) is None:
        r_id = id(r)
        if r_id not in graph:
          graph[r_id] = []
        if obj_id not in graph[r_id]:
          graph[r_id].append(obj_id)
          build_ref_graph(r, graph, reprs, denylist)
          reprs[r_id] = describe(r, denylist)

  def find_cycle(el, graph, reprs, path):
    """Finds and logs a single cycle in the dependency graph.

    Depth-first search from `el`; `path` is the tuple of IDs visited so far.

    Returns:
      True if a cycle was found (and logged), False otherwise.
    """
    if el not in graph:
      return False
    for r in graph[el]:
      if r in path:
        logging.error("Reference cycle sample:")
        for p in path + (r,):
          logging.error(reprs.get(p, "unknown object " + str(p)))
        return True
      else:
        if find_cycle(r, graph, reprs, path + (r,)):
          return True
    return False

  obj = objects[idx]
  graph = {}  # referrer ID -> object ID
  reprs = {}  # object ID -> description
  # Seed the denylist with this helper's own bookkeeping objects so they are
  # not reported as referrers.
  build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason,
                                      describe, build_ref_graph, find_cycle))
  for k in graph:
    if find_cycle(k, graph, reprs, ()):
      return True
  return False
941
942
def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  Args:
    f: The function to decorate. It is called as `f(self, **kwargs)`; only
      keyword arguments are forwarded.

  Returns:
    The decorated function.
  """

  def _safe_object_str(obj):
    # Formats with the class name and id only, so the object's own
    # __str__/__repr__ is never invoked here.
    return "<%s %d>" % (obj.__class__.__name__, id(obj))

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    # Force-load `distribution_strategy_context` to prevent GC at
    # test time when using eager. Remove once b/117329403 is resolved.
    tape.distribution_strategy_context.get_strategy()

    gc.disable()
    previous_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    # Restore the collector state even if the test raises or the garbage
    # assertion fails; otherwise collection stays disabled (with DEBUG_SAVEALL)
    # for every later test in the process.
    try:
      gc.collect()
      previous_garbage = len(gc.garbage)
      result = f(self, **kwargs)
      gc.collect()
      new_garbage = len(gc.garbage)
      if new_garbage > previous_garbage:
        for obj in gc.garbage[previous_garbage:]:
          # Known false positive for ast.fix_missing_locations.
          # NOTE(review): the 3 presumably counts the garbage objects each such
          # false positive produces -- confirm.
          if getattr(obj, "__module__", "") == "ast":
            new_garbage -= 3

      if new_garbage > previous_garbage:
        logging.error(
            "The decorated test created work for Python's garbage collector, "
            "likely due to a reference cycle. New objects in cycle(s):")
        for i, obj in enumerate(gc.garbage[previous_garbage:]):
          try:
            logging.error("Object %d of %d", i,
                          len(gc.garbage) - previous_garbage)
            logging.error("  Object type: %s", _safe_object_str(obj))
            logging.error(
                "  Referrer types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
            logging.error(
                "  Referent types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
            logging.error("  Object attribute names: %s", dir(obj))
            logging.error("  Object __str__:")
            logging.error(obj)
            logging.error("  Object __repr__:")
            logging.error(repr(obj))
          except Exception:  # pylint: disable=broad-except
            logging.error("(Exception while printing object)")

      # When garbage is created, this call can help identify reference cycles,
      # which are typically the cause of such garbage.
      if new_garbage > previous_garbage:
        for i in range(previous_garbage, new_garbage):
          if _find_reference_cycle(gc.garbage, i):
            break

      # This will fail if any garbage has been created, typically because of a
      # reference cycle.
      self.assertEqual(previous_garbage, new_garbage)
    finally:
      # TODO(allenl): Figure out why this debug flag reset doesn't work. It would
      # be nice to be able to decorate arbitrary tests in a large test suite and
      # not hold on to every object in other tests.
      gc.set_debug(previous_debug_flags)
      gc.enable()
    return result

  return decorator
1023
1024
def _combine_named_parameters(**kwargs):
  """Generate combinations based on its keyword arguments.

  Builds the cartesian product of the options given for each keyword. A value
  that is a `list` supplies the candidate options for its keyword; any other
  value (a tuple included) is treated as the one and only option. Keywords are
  visited in sorted order, so every returned dict lists its keys sorted.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of OrderedDicts, one per combination. Keys are the keyword argument
    names, each mapped to one of that keyword's options. Since the result is a
    plain list, two results can be concatenated with `+`.
  """
  options_per_key = []
  for name in sorted(kwargs):
    options = kwargs[name]
    if not isinstance(options, list):
      options = [options]
    options_per_key.append([(name, option) for option in options])

  return [OrderedDict(pairs) for pairs in itertools.product(*options_per_key)]
1048
1049
def generate_combinations_with_testcase_name(**kwargs):
  """Generate named test-parameter combinations from keyword arguments.

  Calls `_combine_named_parameters` and appends a "testcase_name" entry to each
  resulting dictionary. The "testcase_name" key is required for named
  parameterized tests.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of OrderedDicts, one per combination. Each holds the keyword
    argument names (each mapped to one of its values) followed by
    "testcase_name". That name is "_test" followed by "_<key>_<value>" for every
    pair, keeping only the alphanumeric characters of the key and of
    `str(value)`.
  """
  named_combinations = []
  for combination in _combine_named_parameters(**kwargs):
    assert isinstance(combination, OrderedDict)
    name_parts = []
    for key, value in combination.items():
      clean_key = "".join(ch for ch in key if ch.isalnum())
      clean_value = "".join(ch for ch in str(value) if ch.isalnum())
      name_parts.append("_{}_{}".format(clean_key, clean_value))
    entries = list(combination.items())
    entries.append(("testcase_name", "_test" + "".join(name_parts)))
    named_combinations.append(OrderedDict(entries))

  return named_combinations
1081
1082
1083def run_all_in_graph_and_eager_modes(cls):
1084  """Execute all test methods in the given class with and without eager."""
1085  base_decorator = run_in_graph_and_eager_modes
1086  for name in dir(cls):
1087    if (not name.startswith(unittest.TestLoader.testMethodPrefix) or
1088        name.startswith("testSkipEager") or
1089        name.startswith("test_skip_eager") or
1090        name == "test_session"):
1091      continue
1092    value = getattr(cls, name, None)
1093    if callable(value):
1094      setattr(cls, name, base_decorator(value))
1095  return cls
1096
1097
def build_as_function_and_v1_graph(func=None):
  """Run a test case in v1 graph mode and inside tf.function in eager mode.

  WARNING: This decorator can only be used in test cases that statically check
  the generated graph. Attempting to evaluate graph or function results via
  session.run() or self.evaluate() will fail.

  WARNING: This decorator can only be used for test cases that inherit from
  absl.testing.parameterized.TestCase.

  Args:
    func: Test case function to be decorated.

  Returns:
    Decorated test case function.

  Raises:
    ValueError: If applied to a class instead of a test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`build_as_function_and_v1_graph` only supports test methods.")

    @parameterized.named_parameters(("_v1_graph", "v1_graph"),
                                    ("_function", "function"))
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **kwargs):
      if run_mode == "v1_graph":
        with ops.Graph().as_default():
          f(self, *args, **kwargs)
      elif run_mode == "function":

        @def_function.function
        def function_in_eager():
          f(self, *args, **kwargs)

        # Create a new graph for the eagerly executed version of this test for
        # better isolation.
        graph_for_eager_test = ops.Graph()
        with graph_for_eager_test.as_default(), context.eager_mode():
          function_in_eager()
        ops.dismantle_graph(graph_for_eager_test)
      else:
        # Raise (not return) so an unexpected mode fails the test instead of
        # silently passing.
        raise ValueError("Unknown run mode %s" % run_mode)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1148
1149
def run_in_async_and_sync_mode(f):
  """Execute the test in async mode and sync mode.

  The returned test is parameterized over `async_mode`. The "Async" variant
  runs `f` under `context.execution_mode(context.ASYNC)`; the unnamed variant
  runs it under `context.execution_mode(context.SYNC)`.

  Args:
    f: The test method to run in both execution modes.

  Returns:
    The parameterized test method.
  """

  @parameterized.named_parameters([("Async", True), ("", False)])
  @functools.wraps(f)
  def decorator(self, async_mode, *args, **kwargs):
    mode = context.ASYNC if async_mode else context.SYNC
    with context.execution_mode(mode):
      f(self, *args, **kwargs)
  return decorator
1163
1164
def run_in_graph_and_eager_modes(func=None,
                                 config=None,
                                 use_gpu=True,
                                 assert_no_eager_garbage=False):
  """Execute the decorated test with and without enabling eager execution.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will cause the contents of the test
  method to be executed twice - once normally, and once with eager execution
  enabled. This allows unittests to confirm the equivalence between eager
  and graph execution (see `tf.compat.v1.enable_eager_execution`).

  For example, consider the following unittest:

  ```python
  class MyTests(tf.test.TestCase):

    @run_in_graph_and_eager_modes
    def test_foo(self):
      x = tf.constant([1, 2])
      y = tf.constant([3, 4])
      z = tf.add(x, y)
      self.assertAllEqual([4, 6], self.evaluate(z))

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test validates that `tf.add()` has the same behavior when computed with
  eager execution enabled as it does when constructing a TensorFlow graph and
  executing the `z` tensor in a session.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.
    config: An optional config_pb2.ConfigProto to use to configure the session
      when executing graphs.
    use_gpu: If True, attempt to run as many operations as possible on GPU.
      If False, the eager run is placed on "/device:CPU:0".
    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
      collector and asserts that no extra garbage has been created when running
      the test with eager execution enabled. This will fail if there are
      reference cycles (e.g. a = []; a.append(a)). Off by default because some
      tests may create garbage for legitimate reasons (e.g. they define a class
      which inherits from `object`), and because DEBUG_SAVEALL is sticky in some
      Python interpreters (meaning that tests which rely on objects being
      collected elsewhere in the unit test file will not work). Additionally,
      checks that nothing still has a reference to Tensors that the test
      allocated.

  Returns:
    Returns a decorator that will run the decorated test method twice:
    once by constructing and executing a graph in a session and once with
    eager execution enabled.

  Raises:
    ValueError: (at decoration time) if applied to a class instead of a test
      method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`run_in_graph_and_eager_modes` only supports test methods. "
          "Did you mean to use `run_all_in_graph_and_eager_modes`?")

    def decorated(self, *args, **kwargs):
      # First run: graph mode inside a test session. A SkipTest raised here is
      # swallowed so that the eager run below still executes.
      try:
        with context.graph_mode():
          with self.test_session(use_gpu=use_gpu, config=config):
            f(self, *args, **kwargs)
      except unittest.case.SkipTest:
        pass

      # Eager-run body. Its own **kwargs shadows the outer ones, while the
      # positional `args` are taken from the enclosing closure, because the
      # garbage/tensor-checking wrappers below call their wrapped function as
      # f(self, **kwargs) (see assert_no_garbage_created; presumably the same
      # for assert_no_new_tensors).
      def run_eagerly(self, **kwargs):
        if not use_gpu:
          with ops.device("/device:CPU:0"):
            f(self, *args, **kwargs)
        else:
          f(self, *args, **kwargs)

      if assert_no_eager_garbage:
        ops.reset_default_graph()
        run_eagerly = assert_no_new_tensors(
            assert_no_garbage_created(run_eagerly))

      # This decorator runs the wrapped test twice.
      # Reset the test environment between runs.
      self.tearDown()
      self._tempdir = None
      # Create a new graph for the eagerly executed version of this test for
      # better isolation.
      graph_for_eager_test = ops.Graph()
      with graph_for_eager_test.as_default(), context.eager_mode():
        self.setUp()
        run_eagerly(self, **kwargs)
      ops.dismantle_graph(graph_for_eager_test)

    return tf_decorator.make_decorator(f, decorated)

  if func is not None:
    return decorator(func)

  return decorator
1270
1271
def py_func_if_in_function(f):
  """Wraps `f` so it runs through `script_ops.py_func` inside a function.

  Outside a function (`ops.inside_function()` is False) the wrapper just calls
  `f(*args, **kwds)` and returns its result. Inside a function, the positional
  arguments that are `ops.Tensor` or `variables.Variable` instances become the
  inputs of a `script_ops.py_func` op. When that op runs, the received values
  are substituted back into their original positions before `f` is called. All
  other positional arguments and all keyword arguments are captured as-is. In
  that case the wrapper returns the result of the `script_ops.py_func` call
  (declared with an empty output type list).

  Args:
    f: The function to wrap.

  Returns:
    The wrapped function.
  """

  def decorated(*args, **kwds):
    if not ops.inside_function():
      return f(*args, **kwds)

    tensor_positions = [
        pos for pos, value in enumerate(args)
        if isinstance(value, (ops.Tensor, variables.Variable))
    ]
    tensor_values = [args[pos] for pos in tensor_positions]

    def call_with_tensors(*converted):
      rebuilt = list(args)
      for pos, value in zip(tensor_positions, converted):
        rebuilt[pos] = value
      return f(*rebuilt, **kwds)

    return script_ops.py_func(call_with_tensors, tensor_values, [])

  return tf_decorator.make_decorator(f, decorated)
1294
1295
def also_run_as_tf_function(f):
  """Runs the decorated test twice--once as is, once inside a tf.function.

  This allows you to run a test both in eager execution and inside a
  tf.function, exercising the two execution modes supported in tf 2.0. The test
  assertions are automatically done inside tf.py_funcs, and tf.function ensures
  that they run in the proper order and with the proper side effects.

  Currently variable creation is not supported in tests annotated with this
  decorator since it's tricky to ensure the variable doesn't get repeatedly
  created when retracing the tf.function.

  Args:
    f: the test method to be decorated

  Returns:
    The decorated test method, which will run both in eager and inside a
    tf.function. It carries `f`'s name and docstring.
  """

  def decorated(*args, **kwds):

    def bound_f():
      f(*args, **kwds)

    with context.eager_mode():
      # Running in eager mode
      bound_f()
      # Running as TF function
      # TODO(b/121143941): Remove the autograph override.
      def_function.function(bound_f, autograph=False)()

  # Copy f's metadata (name, docstring) onto the wrapper, as
  # `py_func_if_in_function` does; unittest reads the test method's docstring
  # for its description.
  return tf_decorator.make_decorator(f, decorated)
1329
1330
def deprecated_graph_mode_only(func=None):
  """Execute the decorated test in graph mode.

  Intended for tests that are not compatible with eager mode. Whenever eager
  execution is active, the test body runs under `context.graph_mode()`, so API
  calls construct graphs instead of executing eagerly.

  When applied to a test class, the class's own `setUp` (if it defines one) and
  every callable class attribute whose name starts with the unittest
  test-method prefix are wrapped the same way, and the class is returned.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will run the decorated test method in graph mode.
  """

  def decorator(f):
    if not tf_inspect.isclass(f):
      def decorated(self, *args, **kwargs):
        if not context.executing_eagerly():
          return f(self, *args, **kwargs)
        with context.graph_mode():
          return f(self, *args, **kwargs)

      return decorated

    own_setup = f.__dict__.get("setUp")
    if own_setup is not None:
      setattr(f, "setUp", decorator(own_setup))

    # Iterate over a snapshot because setattr mutates the class namespace.
    test_prefix = unittest.TestLoader.testMethodPrefix
    for attr_name, attr in list(f.__dict__.items()):
      if callable(attr) and attr_name.startswith(test_prefix):
        setattr(f, attr_name, decorator(attr))

    return f

  if func is None:
    return decorator
  return decorator(func)
1378
1379
# Alias: `run_deprecated_v1` is the very same decorator object as
# `deprecated_graph_mode_only` (runs the decorated test in graph mode).
run_deprecated_v1 = deprecated_graph_mode_only
1381
1382
1383def run_all_in_deprecated_graph_mode_only(cls):
1384  """Execute all tests in a class in graph mode."""
1385  base_decorator = deprecated_graph_mode_only
1386  for name in dir(cls):
1387    if (not name.startswith(unittest.TestLoader.testMethodPrefix) or
1388        name == "test_session"):
1389      continue
1390    value = getattr(cls, name, None)
1391    if callable(value):
1392      setattr(cls, name, base_decorator(value))
1393  return cls
1394
1395
def run_v1_only(reason, func=None):
  """Execute the decorated test only if running in v1 mode.

  Intended for tests of v1-only functionality. When TF2 behavior is enabled
  (`tf2.enabled()`), the test is skipped with `reason` as the skip message.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    reason: string giving a reason for limiting the test to v1 only.
    func: function (or test class) to be annotated. If `func` is None, this
      method returns a decorator that can be applied to a function. If `func`
      is not None this returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: If `reason` is not a string.
  """
  if not isinstance(reason, str):
    raise ValueError("'reason' should be string, got {}".format(type(reason)))

  def decorator(f):
    if not tf_inspect.isclass(f):
      # A single test method: check the TF2 flag each time the test runs.
      def decorated(self, *args, **kwargs):
        if tf2.enabled():
          self.skipTest(reason)

        return f(self, *args, **kwargs)

      return decorated

    # A whole test class: wrapping setUp is enough to skip every test in it.
    # setUp may be inherited rather than defined on `f`, so wrap the first one
    # found along the method resolution order.
    for klass in type.mro(f):
      inherited_setup = klass.__dict__.get("setUp")
      if inherited_setup is not None:
        setattr(f, "setUp", decorator(inherited_setup))
        break
    return f

  if func is None:
    return decorator
  return decorator(func)
1447
1448
def run_v2_only(func=None):
  """Execute the decorated test only if running in v2 mode.

  Intended for tests of v2-only functionality. When TF2 behavior is not enabled
  (`tf2.enabled()` is False), the test is skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: (at decoration time) if applied to a class.
  """

  def decorator(test_method):
    if tf_inspect.isclass(test_method):
      raise ValueError("`run_v2_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      v2_enabled = tf2.enabled()
      if not v2_enabled:
        self.skipTest("Test is only compatible with v2")
      return test_method(self, *args, **kwargs)

    return decorated

  if func is None:
    return decorator
  return decorator(func)
1484
1485
def run_gpu_only(func=None):
  """Execute the decorated test only if a GPU is available.

  Intended for tests that require a GPU. When `is_gpu_available()` reports no
  GPU, the test is skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: (at decoration time) if applied to a class.
  """

  def decorator(test_method):
    if tf_inspect.isclass(test_method):
      raise ValueError("`run_gpu_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      has_gpu = is_gpu_available()
      if not has_gpu:
        self.skipTest("Test requires GPU")
      return test_method(self, *args, **kwargs)

    return decorated

  if func is None:
    return decorator
  return decorator(func)
1517
1518
def run_cuda_only(func=None):
  """Execute the decorated test only if a CUDA GPU is available.

  Intended for tests that require a CUDA GPU. When
  `is_gpu_available(cuda_only=True)` reports no such device, the test is
  skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: (at decoration time) if applied to a class.
  """

  def decorator(test_method):
    if tf_inspect.isclass(test_method):
      raise ValueError("`run_cuda_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      has_cuda_gpu = is_gpu_available(cuda_only=True)
      if not has_cuda_gpu:
        self.skipTest("Test requires CUDA GPU")
      return test_method(self, *args, **kwargs)

    return decorated

  if func is None:
    return decorator
  return decorator(func)
1550
1551
def run_gpu_or_tpu(func=None):
  """Execute the decorated test only if a physical GPU or TPU is available.

  Device types are probed in order via `config.list_physical_devices`:
  - If a GPU is available, the test runs with "GPU" as its first argument
    after `self`.
  - Otherwise, if a TPU is available, it runs with "TPU" instead.
  - If neither is present, the test is skipped.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: (at decoration time) if applied to a class.
  """

  def decorator(test_method):
    if tf_inspect.isclass(test_method):
      raise ValueError("`run_gpu_or_tpu` only supports test methods.")

    def decorated(self, *args, **kwargs):
      for device_type in ("GPU", "TPU"):
        if config.list_physical_devices(device_type):
          return test_method(self, device_type, *args, **kwargs)

      self.skipTest("Test requires GPU or TPU")

    return decorated

  if func is not None:
    return decorator(func)
  return decorator
1586
1587
def with_forward_compatibility_horizons(*horizons):
  """Executes the decorated test with the specified forward-compat horizons.

  Args:
    *horizons: A list of (year, month, day) tuples.  If the list includes
      `None`, then the test will also be run with no forward-compatibility
      horizon set.

  Returns:
    A decorator that will execute the test with the specified horizons.

  Raises:
    ValueError: If no horizons are given, or if a horizon is neither `None` nor
      a length-3 sequence of ints.
  """
  if not horizons:
    raise ValueError("Expected at least one horizon.")

  def _is_valid_horizon(horizon):
    """Returns True for `None` or a length-3 sequence of ints."""
    if horizon is None:
      return True
    try:
      return len(horizon) == 3 and all(isinstance(x, int) for x in horizon)
    except TypeError:
      # e.g. a bare int, which has no len(); report it as a bad horizon.
      return False

  for horizon in horizons:
    if not _is_valid_horizon(horizon):
      # Wrap in a 1-tuple: `"%r" % horizon` would treat a tuple horizon as the
      # whole argument tuple and raise TypeError instead of this ValueError.
      raise ValueError("Bad horizon value: %r" % (horizon,))

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`with_forward_compatibility_horizons` only "
                       "supports test methods.")
    def decorated(self, *args, **kwargs):
      for horizon in horizons:
        if horizon is None:
          f(self, *args, **kwargs)
        else:
          (year, month, day) = horizon
          with forward_compatibility_horizon(year, month, day):
            f(self, *args, **kwargs)
    return decorated

  return decorator
1621
1622
@deprecation.deprecated(None,
                        "Use `tf.config.list_physical_devices('GPU')` instead.")
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Warning: if a non-GPU version of the package is installed, the function would
  also return False. Use `tf.test.is_built_with_cuda` to validate if TensorFlow
  was built with CUDA support.

  For example,
  >>> gpu_available = tf.test.is_gpu_available()
  >>> is_cuda_gpu_available = tf.test.is_gpu_available(cuda_only=True)
  >>> is_cuda_gpu_min_3 = tf.test.is_gpu_available(True, (3,0))

  Args:
    cuda_only: Ignored; any GPU device counts (see the note below). Kept so
      existing callers keep working.
    min_cuda_compute_capability: a (major,minor) pair (tuple or list) that
      indicates the minimum CUDA compute capability required, or None if no
      requirement.

  Note that the keyword arg name "cuda_only" is misleading (since routine will
  return true when a GPU device is available irrespective of whether TF was
  built with CUDA support or ROCm support). However no changes here because

  ++ Changing the name "cuda_only" to something more generic would break
     backward compatibility

  ++ Adding an equivalent "rocm_only" would require the implementation check
     the build type. This in turn would require doing the same for CUDA and thus
     potentially break backward compatibility

  ++ Adding a new "cuda_or_rocm_only" would not break backward compatibility,
     but would require most (if not all) callers to update the call to use
     "cuda_or_rocm_only" instead of "cuda_only"

  Returns:
    True if a local GPU device is found whose compute capability is at least
    `min_cuda_compute_capability` (when given), False otherwise.

  Raises:
    errors_impl.NotFoundError: If listing local devices raises a NotFoundError
      whose message does not contain both "CUDA" and "not find". Such errors
      that do contain both are logged and False is returned.
  """

  # This was needed earlier when we had support for SYCL in TensorFlow.
  del cuda_only

  # Normalize to a tuple so a list such as [3, 0] compares correctly against
  # the device capability; tuple >= list raises TypeError.
  if min_cuda_compute_capability:
    min_cuda_compute_capability = tuple(min_cuda_compute_capability)

  try:
    for local_device in device_lib.list_local_devices():
      if local_device.device_type == "GPU":
        gpu_info = gpu_util.compute_capability_from_device_desc(local_device)
        # NOTE(review): compute_capability is presumably a (major, minor)
        # tuple, matching the (0, 0) fallback -- confirm in gpu_util.
        cc = gpu_info.compute_capability or (0, 0)
        if not min_cuda_compute_capability or cc >= min_cuda_compute_capability:
          return True
    return False
  except errors_impl.NotFoundError as e:
    if not all(x in str(e) for x in ["CUDA", "not find"]):
      raise
    logging.error(str(e))
    return False
1679
1680
@contextlib.contextmanager
def device(use_gpu):
  """Uses gpu when requested and available.

  Ops created inside the context are placed on "/device:GPU:0" if `use_gpu`
  is truthy and `is_gpu_available()` reports a GPU; otherwise on
  "/device:CPU:0".

  Args:
    use_gpu: Whether to try to place ops on the GPU.

  Yields:
    None, inside an `ops.device` scope for the chosen device.
  """
  wants_gpu = use_gpu and is_gpu_available()
  target = "/device:GPU:0" if wants_gpu else "/device:CPU:0"
  with ops.device(target):
    yield
1690
1691
@contextlib.contextmanager
def use_gpu():
  """Uses the GPU if one is available, otherwise the CPU.

  Shorthand for entering `device(use_gpu=True)`.
  """
  gpu_if_available = device(use_gpu=True)
  with gpu_if_available:
    yield
1697
1698
@contextlib.contextmanager
def force_gpu():
  """Force the gpu to be used.

  Places ops created inside the context on "/device:GPU:0" without checking
  whether a GPU is actually available.
  """
  gpu_device_name = "/device:GPU:0"
  with ops.device(gpu_device_name):
    yield
1704
1705
@contextlib.contextmanager
def force_cpu():
  """Force the cpu to be used.

  Places ops created inside the context on "/device:CPU:0".
  """
  cpu_device_name = "/device:CPU:0"
  with ops.device(cpu_device_name):
    yield
1711
1712
class CapturedWrites(object):
  """A utility class to load the captured writes made to a stream."""

  def __init__(self, capture_location):
    """Records where the captured writes are stored.

    Args:
      capture_location: Path of the file the writes were captured to.
    """
    self.capture_location = capture_location

  def contents(self):
    """Get the captured writes as a single string.

    Returns:
      The full text content of the file at `capture_location`.
    """
    # read() yields the same text as "".join(readlines()), without building
    # an intermediate list of lines.
    with open(self.capture_location) as tmp_file:
      return tmp_file.read()
1724
1725
class FakeEagerSession(object):
  """Fake session so tests that conditionally use placeholders can use eager.

  Some tests use placeholders only when the input shape should be unknown
  statically, following this pattern:

  ```python
  with self.cached_session() as sess:
    if static_shape:
      y = math_ops.matmul(x, ...)
      feed_dict = {}
    else:
      x_ph = array_ops.placeholder(...)
      y = math_ops.matmul(x_ph, ...)
      feed_dict = {x_ph: x}
    val = sess.run(y, feed_dict=feed_dict)
  ```

  On the static-shape path the feed_dict is empty, so `sess.run` is equivalent
  to `self.evaluate()`. This class provides that `run` without rewriting the
  test. Treat it as a stop-gap for running such tests under eager with minimal
  changes.
  """

  def __init__(self, test_case):
    # The test case whose `evaluate` method does the work in `run`.
    self._test_case = test_case

  def run(self, fetches, *args, **kwargs):
    """Evaluates `fetches` through the wrapped test case.

    An empty (or falsy) `feed_dict` is accepted and ignored; any other extra
    argument is rejected.

    Args:
      fetches: A Tensor or a nested list/tuple of Tensors.
      *args: Positional arguments; must be empty.
      **kwargs: Keyword arguments; only a falsy `feed_dict` is allowed.

    Raises:
      RuntimeError: If a non-empty feed_dict or any other args/kwargs are
        given.

    Returns:
      Whatever `test_case.evaluate(fetches)` returns (numpy values).
    """
    if kwargs.pop("feed_dict", {}):
      raise RuntimeError(
          "feed_dict is not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy()")

    extra_args_given = bool(args) or bool(kwargs)
    if extra_args_given:
      raise RuntimeError(
          "Optional args are not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy()")

    return self._test_case.evaluate(fetches)
1781
1782
class ErrorLoggingSession(session.Session):
  """A `session.Session` whose `run()` logs errors before re-raising them."""

  def run(self, *args, **kwargs):
    try:
      return super(ErrorLoggingSession, self).run(*args, **kwargs)
    except errors.OutOfRangeError:
      # tf.data signals completion with OutOfRangeError; logging it would only
      # clutter the output of tf.data tests, so re-raise it silently.
      raise
    except Exception as e:  # pylint: disable=broad-except
      logging.error(str(e))
      raise
1796
1797
def disable_cudnn_autotune(func=None):
  """Disable autotuning during the call to this function.

  Some tests want to base assertions on a graph being isomorphic with a copy.
  To ensure this, this decorator sets TF_CUDNN_USE_AUTOTUNE=false and appends
  `--xla_gpu_autotune_level=0` to XLA_FLAGS for the duration of the call.
  Both environment variables are restored afterwards, even if the test raises.

  Usable as `@disable_cudnn_autotune` or `@disable_cudnn_autotune()`.

  Args:
    func: Function to run with CuDNN autotuning turned off. If None, a
      decorator is returned instead.

  Returns:
    Decorated function, or a decorator when `func` is None.
  """

  def _restore_env(name, value):
    """Sets env var `name` back to `value`, removing it if `value` is None."""
    if value is None:
      # pop() rather than del: tolerate the test having removed it itself.
      os.environ.pop(name, None)
    else:
      os.environ[name] = value

  def decorator(f):

    @functools.wraps(f)
    def decorated(self, *args, **kwargs):
      original_tf_cudnn_use_autotune = os.environ.get("TF_CUDNN_USE_AUTOTUNE")
      original_xla_flags = os.environ.get("XLA_FLAGS")
      new_xla_flags = "--xla_gpu_autotune_level=0"
      if original_xla_flags:
        # Keep any user-provided XLA flags; only append the autotune override.
        new_xla_flags = original_xla_flags + " " + new_xla_flags
      os.environ["TF_CUDNN_USE_AUTOTUNE"] = "false"
      os.environ["XLA_FLAGS"] = new_xla_flags
      try:
        return f(self, *args, **kwargs)
      finally:
        # Restore unconditionally so a failing test does not leak the
        # overrides into later tests in the same process.
        _restore_env("TF_CUDNN_USE_AUTOTUNE", original_tf_cudnn_use_autotune)
        _restore_env("XLA_FLAGS", original_xla_flags)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1841
1842
1843# The description is just for documentation purposes.
1844def enable_tf_xla_constant_folding(description):
1845
1846  if not isinstance(description, str):
1847    raise ValueError("'description' should be string, got {}".format(
1848        type(description)))
1849
1850  def enable_tf_xla_constant_folding_impl(func):
1851    """Enable constant folding during the call to this function.
1852
1853    Some tests fail without constant folding.
1854
1855    Args:
1856      func: Function to run with constant folding turned on.
1857
1858    Returns:
1859      Decorated function.
1860    """
1861
1862    def decorator(f):
1863
1864      def decorated(self, *args, **kwargs):
1865        original_var = pywrap_tf_session.TF_GetXlaConstantFoldingDisabled()
1866        pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(False)
1867        result = f(self, *args, **kwargs)
1868        pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(original_var)
1869        return result
1870
1871      return decorated
1872
1873    if func is not None:
1874      return decorator(func)
1875
1876    return decorator
1877
1878  return enable_tf_xla_constant_folding_impl
1879
1880
1881# Updates test function by selectively disabling it.
def _disable_test(execute_func):
  """Builds a decorator that runs a test method only if `execute_func` holds.

  Args:
    execute_func: When falsy, the decorated method returns None without
      running its body.

  Returns:
    A function that, given a test method (or None), returns the wrapped method
    (or a decorator when given None).
  """

  def disable_test_impl(func):

    def decorator(test_fn):

      def decorated(self, *args, **kwargs):
        if not execute_func:
          return None
        return test_fn(self, *args, **kwargs)

      return tf_decorator.make_decorator(test_fn, decorated)

    return decorator if func is None else decorator(func)

  return disable_test_impl
1900
1901
1902# The description is just for documentation purposes.
def disable_xla(description):  # pylint: disable=unused-argument
  """Runs the decorated test method only when XLA is not enabled."""
  return _disable_test(not is_xla_enabled())
1907
1908
1909# The description is just for documentation purposes.
def disable_mlir_bridge(description):  # pylint: disable=unused-argument
  """Runs the decorated test method only when the MLIR bridge is not enabled."""
  return _disable_test(not is_mlir_bridge_enabled())
1914
1915
1916# The description is just for documentation purposes.
def disable_asan(description):  # pylint: disable=unused-argument
  """Runs the decorated test method only when ASAN is not enabled."""
  return _disable_test(not is_asan_enabled())
1921
1922
1923# The description is just for documentation purposes.
def disable_msan(description):  # pylint: disable=unused-argument
  """Runs the decorated test method only when MSAN is not enabled."""
  return _disable_test(not is_msan_enabled())
1928
1929
1930# The description is just for documentation purposes.
def disable_tsan(description):  # pylint: disable=unused-argument
  """Runs the decorated test method only when TSAN is not enabled."""
  return _disable_test(not is_tsan_enabled())
1935
1936
1937# The description is just for documentation purposes.
def disable_ubsan(description):  # pylint: disable=unused-argument
  """Runs the decorated test method only when UBSAN is not enabled."""
  return _disable_test(not is_ubsan_enabled())
1942
1943
1944# The description is just for documentation purposes.
def disable_tfrt(unused_description):

  def disable_tfrt_impl(cls_or_func):
    """Execute the test only if tfrt is not enabled.

    For a class, the decorator evaluates to None when TFRT is enabled (so the
    decorated name is bound to None), and to the class itself otherwise. For a
    test method, the body is skipped (returning None) when TFRT is enabled.
    """
    if tf_inspect.isclass(cls_or_func):
      return None if tfrt_utils.enabled() else cls_or_func

    def decorator(func):

      def decorated(self, *args, **kwargs):
        if not tfrt_utils.enabled():
          return func(self, *args, **kwargs)
        return None

      return decorated

    return decorator if cls_or_func is None else decorator(cls_or_func)

  return disable_tfrt_impl
1972
1973
1974def for_all_test_methods(decorator, *args, **kwargs):
1975  """Generate class-level decorator from given method-level decorator.
1976
1977  It is expected for the given decorator to take some arguments and return
1978  a method that is then called on the test method to produce a decorated
1979  method.
1980
1981  Args:
1982    decorator: The decorator to apply.
1983    *args: Positional arguments
1984    **kwargs: Keyword arguments
1985  Returns: Function that will decorate a given classes test methods with the
1986    decorator.
1987  """
1988
1989  def all_test_methods_impl(cls):
1990    """Apply decorator to all test methods in class."""
1991    for name in dir(cls):
1992      value = getattr(cls, name)
1993      if callable(value) and name.startswith(
1994          "test") and (name != "test_session"):
1995        setattr(cls, name, decorator(*args, **kwargs)(value))
1996    return cls
1997
1998  return all_test_methods_impl
1999
2000
2001# The description is just for documentation purposes.
def no_xla_auto_jit(description):  # pylint: disable=unused-argument
  """Skips the test body when XLA is enabled; it is not meant for auto jit."""
  return _disable_test(not is_xla_enabled())
2006
2007
2008# The description is just for documentation purposes.
2009def xla_allow_fallback(description):  # pylint: disable=unused-argument
2010
2011  def xla_allow_fallback_impl(func):
2012    """Allow fallback to TF even though testing xla."""
2013
2014    def decorator(func):
2015
2016      def decorated(self, *args, **kwargs):
2017        if is_xla_enabled():
2018          # Update the global XLABuildOpsPassFlags to enable lazy compilation,
2019          # which allows the compiler to fall back to TF classic. Remember the
2020          # old value so that we can reset it.
2021          old_value = pywrap_tf_session.TF_SetXlaEnableLazyCompilation(True)
2022          result = func(self, *args, **kwargs)
2023          pywrap_tf_session.TF_SetXlaEnableLazyCompilation(old_value)
2024          return result
2025        else:
2026          return func(self, *args, **kwargs)
2027
2028      return decorated
2029
2030    if func is not None:
2031      return decorator(func)
2032
2033    return decorator
2034
2035  return xla_allow_fallback_impl
2036
2037
2038# The description is just for documentation purposes.
2039def run_without_tensor_float_32(description):  # pylint: disable=unused-argument
2040  """Execute test with TensorFloat-32 disabled.
2041
2042  While almost every real-world deep learning model runs fine with
2043  TensorFloat-32, many tests use assertAllClose or similar methods.
2044  TensorFloat-32 matmuls typically will cause such methods to fail with the
2045  default tolerances.
2046
2047  Args:
2048    description: A description used for documentation purposes, describing why
2049      the test requires TensorFloat-32 to be disabled.
2050
2051  Returns:
2052    Decorator which runs a test with TensorFloat-32 disabled.
2053  """
2054
2055  def decorator(f):
2056
2057    @functools.wraps(f)
2058    def decorated(self, *args, **kwargs):
2059      allowed = config.tensor_float_32_execution_enabled()
2060      try:
2061        config.enable_tensor_float_32_execution(False)
2062        f(self, *args, **kwargs)
2063      finally:
2064        config.enable_tensor_float_32_execution(allowed)
2065
2066    return decorated
2067
2068  return decorator
2069
2070
2071# The description is just for documentation purposes.
def run_all_without_tensor_float_32(description):  # pylint: disable=unused-argument
  """Applies `run_without_tensor_float_32` to every test method of a class."""
  per_method_factory = run_without_tensor_float_32
  return for_all_test_methods(per_method_factory, description)
2075
2076
def matmul_without_tf32(a, b, *args, **kwargs):
  """Run matmul but cast float32 inputs to float64 if TensorFloat-32 is enabled.

  This effectively runs matmul without TensorFloat-32. It should only be used in
  tests when verifying some other op or functions works correctly, e.g. to test
  `tf.linalg.sqrtm` by matrix multiplying the output of the op by itself. In
  such cases, the matmul itself is not being tested so it's OK to run it with
  higher precision.

  If a matmul itself is being tested, or some other op which uses matmul, use
  `run_without_tensor_float_32` instead.

  This also casts complex64 inputs to complex128, since TensorFloat-32 can also
  be used with complex64.

  Args:
    a: First input to tf.linalg.matmul
    b: Second input to tf.linalg.matmul
    *args: Other positional arguments to tf.linalg.matmul
    **kwargs: Other keyword arguments to tf.linalg.matmul

  Returns:
    A tensor with the same type as `a`.
  """
  if config.tensor_float_32_execution_enabled():
    # Capture the caller's dtype BEFORE widening: casting the result back to
    # the widened tensor's dtype would return float64/complex128 instead.
    original_dtype = a.dtype
    if original_dtype == "float32":
      wide_dtype = "float64"
    elif original_dtype == "complex64":
      wide_dtype = "complex128"
    else:
      wide_dtype = None
    if wide_dtype is not None:
      a_wide = math_ops.cast(a, wide_dtype)
      b_wide = math_ops.cast(b, wide_dtype)
      ret = math_ops.matmul(a_wide, b_wide, *args, **kwargs)
      return math_ops.cast(ret, original_dtype)
  return math_ops.matmul(a, b, *args, **kwargs)
2113
2114
class EagerSessionWarner(object):
  """Stand-in for a Session while eager execution is enabled.

  `TensorFlowTestCase.session()` yields an instance of this class under eager
  execution. Any attribute lookup that is not satisfied normally (e.g. `run`)
  raises an AttributeError explaining how to port the test.
  """

  def __getattr__(self, attr):
    del attr  # Every missing attribute gets the same explanation.
    explanation = (
        "Trying to access properties or call methods on the result of "
        "self.session(), self.cached_session(), etc while eager execution "
        "is enabled. If you're porting this test case to TF 2.0, either "
        "adapt the test to work with eager execution or insert a call to "
        "tf.disable_eager_execution() in the main() function of this test "
        "file.")
    raise AttributeError(explanation)
2125
2126
2127@tf_export("test.TestCase")
2128class TensorFlowTestCase(googletest.TestCase):
2129  """Base class for tests that need to test TensorFlow."""
2130
2131  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
2132    super(TensorFlowTestCase, self).__init__(methodName)
2133    # Make sure we get unfiltered stack traces during the test
2134    traceback_utils.disable_traceback_filtering()
2135    if is_xla_enabled():
2136      pywrap_tf_session.TF_SetXlaAutoJitMode("2")
2137      pywrap_tf_session.TF_SetXlaMinClusterSize(1)
2138      pywrap_tf_session.TF_SetXlaEnableLazyCompilation(False)
2139      pywrap_tf_session.TF_SetTfXlaCpuGlobalJit(True)
2140      # Constant folding secretly runs code on TF:Classic CPU, so we also
2141      # disable it here.
2142      pywrap_tf_session.TF_SetXlaConstantFoldingDisabled(True)
2143
2144    # Check if the mlir bridge has been explicitly enabled or disabled. If
2145    # is_mlir_bridge_enabled() returns None, the user did not explictly enable
2146    # or disable the bridge so do not update enable_mlir_bridge.
2147    if is_mlir_bridge_enabled():
2148      context.context().enable_mlir_bridge = True
2149    elif is_mlir_bridge_enabled() is not None:
2150      context.context().enable_mlir_bridge = False
2151
2152    self._threads = []
2153    self._tempdir = None
2154    self._cached_session = None
2155    self._test_start_time = None
2156    # This flag provides the ability to control whether the graph mode gets
2157    # initialized for TF1 or not. Initializing for TF1, which is what was
2158    # happening earlier, was preventing enablement of 'eager mode' in the test.
2159    self._set_default_seed = True
2160
2161  def setUp(self):
2162    super(TensorFlowTestCase, self).setUp()
2163    self._ClearCachedSession()
2164    random.seed(random_seed.DEFAULT_GRAPH_SEED)
2165    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
2166    # Note: The following line is necessary because some test methods may error
2167    # out from within nested graph contexts (e.g., via assertRaises and
2168    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
2169    # under certain versions of Python. That would cause
2170    # ops.reset_default_graph() to throw an exception if the stack were not
2171    # cleared first.
2172    ops._default_graph_stack.reset()  # pylint: disable=protected-access
2173    ops.reset_default_graph()
2174    if self._set_default_seed:
2175      random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
2176    # Reset summary writer in case another test used set_as_default() with their
2177    # summary writer.
2178    summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access
2179    summary_state.writer = None
2180
2181    # Avoiding calling setUp() for the poorly named test_session method.
2182    if self.id().endswith(".test_session"):
2183      self.skipTest("Not a test.")
2184
2185    self._test_start_time = time.time()
2186
2187  def tearDown(self):
2188    # If a subclass overrides setUp and doesn't call the parent class's setUp,
2189    # then we may not have set the start time.
2190    if self._test_start_time is not None:
2191      logging.info("time(%s): %ss", self.id(),
2192                   round(time.time() - self._test_start_time, 2))
2193
2194    for thread in self._threads:
2195      thread.check_termination()
2196
2197    self._ClearCachedSession()
2198    super(TensorFlowTestCase, self).tearDown()
2199
2200  def _ClearCachedSession(self):
2201    if self._cached_session is not None:
2202      self._cached_session.close()
2203      self._cached_session = None
2204
2205  def get_temp_dir(self):
2206    """Returns a unique temporary directory for the test to use.
2207
2208    If you call this method multiple times during in a test, it will return the
2209    same folder. However, across different runs the directories will be
2210    different. This will ensure that across different runs tests will not be
2211    able to pollute each others environment.
2212    If you need multiple unique directories within a single test, you should
2213    use tempfile.mkdtemp as follows:
2214      tempfile.mkdtemp(dir=self.get_temp_dir()):
2215
2216    Returns:
2217      string, the path to the unique temporary directory created for this test.
2218    """
2219    if not self._tempdir:
2220      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
2221    return self._tempdir
2222
2223  @contextlib.contextmanager
2224  def captureWritesToStream(self, stream):
2225    """A context manager that captures the writes to a given stream.
2226
2227    This context manager captures all writes to a given stream inside of a
2228    `CapturedWrites` object. When this context manager is created, it yields
2229    the `CapturedWrites` object. The captured contents can be accessed  by
2230    calling `.contents()` on the `CapturedWrites`.
2231
2232    For this function to work, the stream must have a file descriptor that
2233    can be modified using `os.dup` and `os.dup2`, and the stream must support
2234    a `.flush()` method. The default python sys.stdout and sys.stderr are
2235    examples of this. Note that this does not work in Colab or Jupyter
2236    notebooks, because those use alternate stdout streams.
2237
2238    Example:
2239    ```python
2240    class MyOperatorTest(test_util.TensorFlowTestCase):
2241      def testMyOperator(self):
2242        input = [1.0, 2.0, 3.0, 4.0, 5.0]
2243        with self.captureWritesToStream(sys.stdout) as captured:
2244          result = MyOperator(input).eval()
2245        self.assertStartsWith(captured.contents(), "This was printed.")
2246    ```
2247
2248    Args:
2249      stream: The stream whose writes should be captured. This stream must have
2250        a file descriptor, support writing via using that file descriptor, and
2251        must have a `.flush()` method.
2252
2253    Yields:
2254      A `CapturedWrites` object that contains all writes to the specified stream
2255      made during this context.
2256    """
2257    stream.flush()
2258    fd = stream.fileno()
2259    tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
2260    tmp_file = open(tmp_file_path, "w")
2261    orig_fd = os.dup(fd)
2262    os.dup2(tmp_file.fileno(), fd)
2263    try:
2264      yield CapturedWrites(tmp_file_path)
2265    finally:
2266      tmp_file.close()
2267      os.dup2(orig_fd, fd)
2268
2269  def _AssertProtoEquals(self, a, b, msg=None):
2270    """Asserts that a and b are the same proto.
2271
2272    Uses ProtoEq() first, as it returns correct results
2273    for floating point attributes, and then use assertProtoEqual()
2274    in case of failure as it provides good error messages.
2275
2276    Args:
2277      a: a proto.
2278      b: another proto.
2279      msg: Optional message to report on failure.
2280    """
2281    if not compare.ProtoEq(a, b):
2282      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)
2283
2284  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
2285    """Asserts that message is same as parsed expected_message_ascii.
2286
2287    Creates another prototype of message, reads the ascii message into it and
2288    then compares them using self._AssertProtoEqual().
2289
2290    Args:
2291      expected_message_maybe_ascii: proto message in original or ascii form.
2292      message: the message to validate.
2293      msg: Optional message to report on failure.
2294    """
2295    if isinstance(expected_message_maybe_ascii, type(message)):
2296      expected_message = expected_message_maybe_ascii
2297      self._AssertProtoEquals(expected_message, message, msg=msg)
2298    elif isinstance(expected_message_maybe_ascii, (str, bytes)):
2299      expected_message = type(message)()
2300      text_format.Merge(
2301          expected_message_maybe_ascii,
2302          expected_message,
2303          descriptor_pool=descriptor_pool.Default())
2304      self._AssertProtoEquals(expected_message, message, msg=msg)
2305    else:
2306      assert False, ("Can't compare protos of type %s and %s." %
2307                     (type(expected_message_maybe_ascii), type(message)))
2308
2309  def assertProtoEqualsVersion(
2310      self,
2311      expected,
2312      actual,
2313      producer=versions.GRAPH_DEF_VERSION,
2314      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
2315      msg=None):
2316    expected = "versions { producer: %d min_consumer: %d };\n%s" % (
2317        producer, min_consumer, expected)
2318    self.assertProtoEquals(expected, actual, msg=msg)
2319
2320  def assertStartsWith(self, actual, expected_start, msg=None):
2321    """Assert that actual.startswith(expected_start) is True.
2322
2323    Args:
2324      actual: str
2325      expected_start: str
2326      msg: Optional message to report on failure.
2327    """
2328    if not actual.startswith(expected_start):
2329      fail_msg = "%r does not start with %r" % (actual, expected_start)
2330      fail_msg += " : %r" % (msg) if msg else ""
2331      self.fail(fail_msg)
2332
2333  def _eval_tensor(self, tensor):
2334    if tensor is None:
2335      return None
2336    elif callable(tensor):
2337      return self._eval_helper(tensor())
2338    else:
2339      try:
2340        if sparse_tensor.is_sparse(tensor):
2341          return sparse_tensor.SparseTensorValue(tensor.indices.numpy(),
2342                                                 tensor.values.numpy(),
2343                                                 tensor.dense_shape.numpy())
2344        elif ragged_tensor.is_ragged(tensor):
2345          return ragged_tensor_value.RaggedTensorValue(
2346              self._eval_tensor(tensor.values),
2347              self._eval_tensor(tensor.row_splits))
2348        elif isinstance(tensor, ops.IndexedSlices):
2349          return ops.IndexedSlicesValue(
2350              values=tensor.values.numpy(),
2351              indices=tensor.indices.numpy(),
2352              dense_shape=tensor.dense_shape.numpy())
2353        # Convert tensors and composite tensors to numpy arrays.
2354        return nest.map_structure(lambda t: t.numpy(), tensor,
2355                                  expand_composites=True)
2356      except AttributeError as e:
2357        six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e)
2358
2359  def _eval_helper(self, tensors):
2360    if tensors is None:
2361      return None
2362    return nest.map_structure(self._eval_tensor, tensors)
2363
2364  def evaluate(self, tensors):
2365    """Evaluates tensors and returns numpy values.
2366
2367    Args:
2368      tensors: A Tensor or a nested list/tuple of Tensors.
2369
2370    Returns:
2371      tensors numpy values.
2372    """
2373    if context.executing_eagerly():
2374      return self._eval_helper(tensors)
2375    else:
2376      sess = ops.get_default_session()
2377      if sess is None:
2378        with self.test_session() as sess:
2379          return sess.run(tensors)
2380      else:
2381        return sess.run(tensors)
2382
2383  # pylint: disable=g-doc-return-or-yield
2384  @contextlib.contextmanager
2385  def session(self, graph=None, config=None, use_gpu=True, force_gpu=False):
2386    """A context manager for a TensorFlow Session for use in executing tests.
2387
2388    Note that this will set this session and the graph as global defaults.
2389
2390    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
2391    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
2392    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
2393    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
2394    the CPU.
2395
2396    Example:
2397
2398    ``` python
2399    class MyOperatorTest(test_util.TensorFlowTestCase):
2400      def testMyOperator(self):
2401        with self.session():
2402          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
2403          result = MyOperator(valid_input).eval()
2404          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
2405          invalid_input = [-1.0, 2.0, 7.0]
2406          with self.assertRaisesOpError("negative input not supported"):
2407            MyOperator(invalid_input).eval()
2408    ```
2409
2410    Args:
2411      graph: Optional graph to use during the returned session.
2412      config: An optional config_pb2.ConfigProto to use to configure the
2413        session.
2414      use_gpu: If True, attempt to run as many ops as possible on GPU.
2415      force_gpu: If True, pin all ops to `/device:GPU:0`.
2416
2417    Yields:
2418      A Session object that should be used as a context manager to surround
2419      the graph building and execution code in a test case.
2420    """
2421    if context.executing_eagerly():
2422      yield EagerSessionWarner()
2423    else:
2424      with self._create_session(graph, config, force_gpu) as sess:
2425        with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
2426          yield sess
2427
2428  @contextlib.contextmanager
2429  def cached_session(self,
2430                     graph=None,
2431                     config=None,
2432                     use_gpu=True,
2433                     force_gpu=False):
2434    """Returns a TensorFlow Session for use in executing tests.
2435
2436    This method behaves differently than self.session(): for performance reasons
2437    `cached_session` will by default reuse the same session within the same
2438    test. The session returned by this function will only be closed at the end
2439    of the test (in the TearDown function).
2440
2441    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
2442    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
2443    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
2444    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
2445    the CPU.
2446
2447    Example:
2448    ```python
2449    class MyOperatorTest(test_util.TensorFlowTestCase):
2450      def testMyOperator(self):
2451        with self.cached_session() as sess:
2452          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
2453          result = MyOperator(valid_input).eval()
2454          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
2455          invalid_input = [-1.0, 2.0, 7.0]
2456          with self.assertRaisesOpError("negative input not supported"):
2457            MyOperator(invalid_input).eval()
2458    ```
2459
2460    Args:
2461      graph: Optional graph to use during the returned session.
2462      config: An optional config_pb2.ConfigProto to use to configure the
2463        session.
2464      use_gpu: If True, attempt to run as many ops as possible on GPU.
2465      force_gpu: If True, pin all ops to `/device:GPU:0`.
2466
2467    Yields:
2468      A Session object that should be used as a context manager to surround
2469      the graph building and execution code in a test case.
2470    """
2471    if context.executing_eagerly():
2472      yield FakeEagerSession(self)
2473    else:
2474      sess = self._get_cached_session(
2475          graph, config, force_gpu, crash_if_inconsistent_args=True)
2476      with self._constrain_devices_and_set_default(sess, use_gpu,
2477                                                   force_gpu) as cached:
2478        yield cached
2479
2480  @contextlib.contextmanager
2481  @deprecation.deprecated(None, "Use `self.session()` or "
2482                          "`self.cached_session()` instead.")
2483  def test_session(self,
2484                   graph=None,
2485                   config=None,
2486                   use_gpu=True,
2487                   force_gpu=False):
2488    """Use cached_session instead."""
2489    if self.id().endswith(".test_session"):
2490      self.skipTest(
2491          "Tests that have the name \"test_session\" are automatically skipped "
2492          "by TensorFlow test fixture, as the name is reserved for creating "
2493          "sessions within tests. Please rename your test if you have a test "
2494          "with this name.")
2495    if context.executing_eagerly():
2496      yield None
2497    else:
2498      if graph is None:
2499        sess = self._get_cached_session(
2500            graph, config, force_gpu, crash_if_inconsistent_args=False)
2501        with self._constrain_devices_and_set_default(sess, use_gpu,
2502                                                     force_gpu) as cached:
2503          yield cached
2504      else:
2505        with self.session(graph, config, use_gpu, force_gpu) as sess:
2506          yield sess
2507
2508  # pylint: enable=g-doc-return-or-yield
2509
2510  class _CheckedThread(object):
2511    """A wrapper class for Thread that asserts successful completion.
2512
2513    This class should be created using the TensorFlowTestCase.checkedThread()
2514    method.
2515    """
2516
2517    def __init__(self, testcase, target, args=None, kwargs=None):
2518      """Constructs a new instance of _CheckedThread.
2519
2520      Args:
2521        testcase: The TensorFlowTestCase for which this thread is being created.
2522        target: A callable object representing the code to be executed in the
2523          thread.
2524        args: A tuple of positional arguments that will be passed to target.
2525        kwargs: A dictionary of keyword arguments that will be passed to target.
2526      """
2527      self._testcase = testcase
2528      self._target = target
2529      self._args = () if args is None else args
2530      self._kwargs = {} if kwargs is None else kwargs
2531      self._thread = threading.Thread(target=self._protected_run)
2532      self._exception = None
2533
2534      self._is_thread_joined = False
2535
2536    def _protected_run(self):
2537      """Target for the wrapper thread. Sets self._exception on failure."""
2538      try:
2539        self._target(*self._args, **self._kwargs)
2540      except Exception as e:  # pylint: disable=broad-except
2541        self._exception = e
2542
2543    def start(self):
2544      """Starts the thread's activity.
2545
2546      This must be called at most once per _CheckedThread object. It arranges
2547      for the object's target to be invoked in a separate thread of control.
2548      """
2549      self._thread.start()
2550
2551    def join(self):
2552      """Blocks until the thread terminates.
2553
2554      Raises:
2555        self._testcase.failureException: If the thread terminates with due to
2556          an exception.
2557      """
2558      self._is_thread_joined = True
2559      self._thread.join()
2560      if self._exception is not None:
2561        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))
2562
2563    def is_alive(self):
2564      """Returns whether the thread is alive.
2565
2566      This method returns True just before the run() method starts
2567      until just after the run() method terminates.
2568
2569      Returns:
2570        True if the thread is alive, otherwise False.
2571      """
2572      return self._thread.is_alive()
2573
2574    def check_termination(self):
2575      """Returns whether the checked thread was properly used and did terminate.
2576
2577      Every checked thread should be "join"ed after starting, and before the
2578      test tears down. If it is not joined, it is possible the thread will hang
2579      and cause flaky failures in tests.
2580
2581      Raises:
2582        self._testcase.failureException: If check_termination was called before
2583        thread was joined.
2584
2585        RuntimeError: If the thread is not terminated. This means thread was not
2586        joined with the main thread.
2587      """
2588      if self._is_thread_joined:
2589        if self.is_alive():
2590          raise RuntimeError(
2591              "Thread was not joined with main thread, and is still running "
2592              "when the test finished.")
2593      else:
2594        self._testcase.fail("A checked thread was not joined.")
2595
2596  def checkedThread(self, target, args=None, kwargs=None):
2597    """Returns a Thread wrapper that asserts 'target' completes successfully.
2598
2599    This method should be used to create all threads in test cases, as
2600    otherwise there is a risk that a thread will silently fail, and/or
2601    assertions made in the thread will not be respected.
2602
2603    Args:
2604      target: A callable object to be executed in the thread.
2605      args: The argument tuple for the target invocation. Defaults to ().
2606      kwargs: A dictionary of keyword arguments for the target invocation.
2607        Defaults to {}.
2608
2609    Returns:
2610      A wrapper for threading.Thread that supports start() and join() methods.
2611    """
2612    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
2613    self._threads.append(ret)
2614    return ret
2615
2616  # pylint: enable=invalid-name
2617  @py_func_if_in_function
2618  def assertNear(self, f1, f2, err, msg=None):
2619    """Asserts that two floats are near each other.
2620
2621    Checks that |f1 - f2| < err and asserts a test failure
2622    if not.
2623
2624    Args:
2625      f1: A float value.
2626      f2: A float value.
2627      err: A float value.
2628      msg: An optional string message to append to the failure message.
2629    """
2630    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
2631    self.assertTrue(
2632        f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
2633        (f1, f2, err, " (%s)" % msg if msg is not None else ""))
2634
2635  @py_func_if_in_function
2636  def assertArrayNear(self, farray1, farray2, err, msg=None):
2637    """Asserts that two float arrays are near each other.
2638
2639    Checks that for all elements of farray1 and farray2
2640    |f1 - f2| < err.  Asserts a test failure if not.
2641
2642    Args:
2643      farray1: a list of float values.
2644      farray2: a list of float values.
2645      err: a float value.
2646      msg: Optional message to report on failure.
2647    """
2648    self.assertEqual(len(farray1), len(farray2), msg=msg)
2649    for f1, f2 in zip(farray1, farray2):
2650      self.assertNear(float(f1), float(f2), err, msg=msg)
2651
2652  def _NDArrayNear(self, ndarray1, ndarray2, err):
2653    return np.linalg.norm(ndarray1 - ndarray2) < err
2654
2655  @py_func_if_in_function
2656  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
2657    """Asserts that two numpy arrays have near values.
2658
2659    Args:
2660      ndarray1: a numpy ndarray.
2661      ndarray2: a numpy ndarray.
2662      err: a float. The maximum absolute difference allowed.
2663      msg: Optional message to report on failure.
2664    """
2665    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)
2666
2667  def _GetNdArray(self, a):
2668    # If a is tensor-like then convert it to ndarray
2669    if tensor_util.is_tf_type(a):
2670      if isinstance(a, ops._EagerTensorBase):
2671        a = a.numpy()
2672      else:
2673        a = self.evaluate(a)
2674    if not isinstance(a, np.ndarray):
2675      return np.array(a)
2676    return a
2677
2678  def evaluate_if_both_tensors(self, a, b):
2679    if (tensor_util.is_tf_type(a) and tensor_util.is_tf_type(b) and
2680        not isinstance(a, ops._EagerTensorBase) and
2681        not isinstance(b, ops._EagerTensorBase)):
2682      return self.evaluate((a, b))
2683    else:
2684      return (a, b)
2685
2686  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
2687    (a, b) = self.evaluate_if_both_tensors(a, b)
2688    a = self._GetNdArray(a)
2689    b = self._GetNdArray(b)
2690    # When the array rank is small, print its contents. Numpy array printing is
2691    # implemented using inefficient recursion so prints can cause tests to
2692    # time out.
2693    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
2694      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
2695                            "%s.") % (a.shape, b.shape, b)
2696    else:
2697      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
2698                                                                     b.shape)
2699    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
2700
2701    msgs = [msg]
2702    # np.allclose does not always work for our custom bfloat16 extension type
2703    # when type promotions are involved, so we first cast any bfloat16 arrays
2704    # to float32.
2705    a_dtype = a.dtype
2706    a = a.astype(np.float32) if a.dtype == dtypes.bfloat16.as_numpy_dtype else a
2707    b = b.astype(np.float32) if b.dtype == dtypes.bfloat16.as_numpy_dtype else b
2708    if not np.allclose(a, b, rtol=rtol, atol=atol):
2709      # Adds more details to np.testing.assert_allclose.
2710      #
2711      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
2712      # checks whether two arrays are element-wise equal within a
2713      # tolerance. The relative difference (rtol * abs(b)) and the
2714      # absolute difference atol are added together to compare against
2715      # the absolute difference between a and b.  Here, we want to
2716      # tell user which elements violate such conditions.
2717      cond = np.logical_or(
2718          np.abs(a - b) > atol + rtol * np.abs(b),
2719          np.isnan(a) != np.isnan(b))
2720      if a.ndim:
2721        x = a[np.where(cond)]
2722        y = b[np.where(cond)]
2723        msgs.append("not close where = {}".format(np.where(cond)))
2724      else:
2725        # np.where is broken for scalars
2726        x, y = a, b
2727      msgs.append("not close lhs = {}".format(x))
2728      msgs.append("not close rhs = {}".format(y))
2729      msgs.append("not close dif = {}".format(np.abs(x - y)))
2730      msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
2731      msgs.append("dtype = {}, shape = {}".format(a_dtype, a.shape))
2732      # TODO(xpan): There seems to be a bug:
2733      # tensorflow/compiler/tests:binary_ops_test pass with float32
2734      # nan even though the equal_nan is False by default internally.
2735      np.testing.assert_allclose(
2736          a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)
2737
2738  def _assertAllCloseRecursive(self,
2739                               a,
2740                               b,
2741                               rtol=1e-6,
2742                               atol=1e-6,
2743                               path=None,
2744                               msg=None):
2745    path = path or []
2746    path_str = (("[" + "][".join(str(p) for p in path) + "]") if path else "")
2747    msg = msg if msg else ""
2748
2749    # Check if a and/or b are namedtuples.
2750    if hasattr(a, "_asdict"):
2751      a = a._asdict()
2752    if hasattr(b, "_asdict"):
2753      b = b._asdict()
2754    a_is_dict = isinstance(a, collections_abc.Mapping)
2755    if a_is_dict != isinstance(b, collections_abc.Mapping):
2756      raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
2757                       (path_str, path_str, msg))
2758    if a_is_dict:
2759      self.assertItemsEqual(
2760          a.keys(),
2761          b.keys(),
2762          msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
2763          (path_str, a.keys(), path_str, b.keys(), msg))
2764      for k in a:
2765        path.append(k)
2766        self._assertAllCloseRecursive(
2767            a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
2768        del path[-1]
2769    elif isinstance(a, (list, tuple)):
2770      # Try to directly compare a, b as ndarrays; if not work, then traverse
2771      # through the sequence, which is more expensive.
2772      try:
2773        (a, b) = self.evaluate_if_both_tensors(a, b)
2774        a_as_ndarray = self._GetNdArray(a)
2775        b_as_ndarray = self._GetNdArray(b)
2776        self._assertArrayLikeAllClose(
2777            a_as_ndarray,
2778            b_as_ndarray,
2779            rtol=rtol,
2780            atol=atol,
2781            msg="Mismatched value: a%s is different from b%s. %s" %
2782            (path_str, path_str, msg))
2783      except (ValueError, TypeError) as e:
2784        if len(a) != len(b):
2785          raise ValueError(
2786              "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
2787              (path_str, len(a), path_str, len(b), msg))
2788        for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
2789          path.append(str(idx))
2790          self._assertAllCloseRecursive(
2791              a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
2792          del path[-1]
2793    # a and b are ndarray like objects
2794    else:
2795      try:
2796        self._assertArrayLikeAllClose(
2797            a,
2798            b,
2799            rtol=rtol,
2800            atol=atol,
2801            msg=("Mismatched value: a%s is different from b%s. %s" %
2802                 (path_str, path_str, msg)))
2803      except TypeError as e:
2804        msg = ("Error: a%s has %s, but b%s has %s. %s" %
2805               (path_str, type(a), path_str, type(b), msg))
2806        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
2807        raise
2808
2809  @py_func_if_in_function
2810  def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
2811    """Asserts that two structures of numpy arrays or Tensors, have near values.
2812
2813    `a` and `b` can be arbitrarily nested structures. A layer of a nested
2814    structure can be a `dict`, `namedtuple`, `tuple` or `list`.
2815
2816    Note: the implementation follows
2817    [`numpy.allclose`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html)
2818    (and numpy.testing.assert_allclose). It checks whether two arrays are
2819    element-wise equal within a tolerance. The relative difference
2820    (`rtol * abs(b)`) and the absolute difference `atol` are added together
2821    to compare against the absolute difference between `a` and `b`.
2822
2823    Args:
2824      a: The expected numpy `ndarray`, or anything that can be converted into a
2825        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2826        structure of these.
2827      b: The actual numpy `ndarray`, or anything that can be converted into a
2828        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2829        structure of these.
2830      rtol: relative tolerance.
2831      atol: absolute tolerance.
2832      msg: Optional message to report on failure.
2833
2834    Raises:
2835      ValueError: if only one of `a[p]` and `b[p]` is a dict or
2836          `a[p]` and `b[p]` have different length, where `[p]` denotes a path
2837          to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
2838          `[p] = [1]['d']`, then `a[p] = (6, 7)`.
2839    """
2840    if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
2841      return self._assertRaggedClose(a, b, rtol, atol, msg)
2842    self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
2843
2844  @py_func_if_in_function
2845  def assertAllCloseAccordingToType(self,
2846                                    a,
2847                                    b,
2848                                    rtol=1e-6,
2849                                    atol=1e-6,
2850                                    float_rtol=1e-6,
2851                                    float_atol=1e-6,
2852                                    half_rtol=1e-3,
2853                                    half_atol=1e-3,
2854                                    bfloat16_rtol=1e-2,
2855                                    bfloat16_atol=1e-2,
2856                                    msg=None):
2857    """Like assertAllClose, but also suitable for comparing fp16 arrays.
2858
2859    In particular, the tolerance is reduced to 1e-3 if at least
2860    one of the arguments is of type float16.
2861
2862    Args:
2863      a: the expected numpy ndarray or anything can be converted to one.
2864      b: the actual numpy ndarray or anything can be converted to one.
2865      rtol: relative tolerance.
2866      atol: absolute tolerance.
2867      float_rtol: relative tolerance for float32.
2868      float_atol: absolute tolerance for float32.
2869      half_rtol: relative tolerance for float16.
2870      half_atol: absolute tolerance for float16.
2871      bfloat16_rtol: relative tolerance for bfloat16.
2872      bfloat16_atol: absolute tolerance for bfloat16.
2873      msg: Optional message to report on failure.
2874    """
2875    (a, b) = self.evaluate_if_both_tensors(a, b)
2876    a = self._GetNdArray(a)
2877    b = self._GetNdArray(b)
2878    # types with lower tol are put later to overwrite previous ones.
2879    if (a.dtype == np.float32 or b.dtype == np.float32 or
2880        a.dtype == np.complex64 or b.dtype == np.complex64):
2881      rtol = max(rtol, float_rtol)
2882      atol = max(atol, float_atol)
2883    if a.dtype == np.float16 or b.dtype == np.float16:
2884      rtol = max(rtol, half_rtol)
2885      atol = max(atol, half_atol)
2886    if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
2887        b.dtype == dtypes.bfloat16.as_numpy_dtype):
2888      rtol = max(rtol, bfloat16_rtol)
2889      atol = max(atol, bfloat16_atol)
2890
2891    self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
2892
2893  @py_func_if_in_function
2894  def assertNotAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
2895    """Assert that two numpy arrays, or Tensors, do not have near values.
2896
2897    Args:
2898      a: The expected numpy `ndarray`, or anything that can be converted into a
2899        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2900        structure of these.
2901      b: The actual numpy `ndarray`, or anything that can be converted into a
2902        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2903        structure of these.
2904      rtol: relative tolerance.
2905      atol: absolute tolerance.
2906      msg: Optional message to report on failure.
2907
2908    Raises:
2909      AssertionError: If `a` and `b` are unexpectedly close at all elements.
2910    """
2911    try:
2912      self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
2913    except AssertionError:
2914      return
2915    msg = msg or ""
2916    raise AssertionError("The two values are close at all elements. %s" % msg)
2917
  @py_func_if_in_function
  def assertAllEqual(self, a, b, msg=None):
    """Asserts that two numpy arrays or Tensors have the same values.

    Ragged inputs are delegated to `_assertRaggedEqual`. Otherwise both values
    are converted to numpy arrays, their shapes must match, and every element
    must compare equal with `==`. If `a` has a float16/32/64 or bfloat16
    dtype, NaNs at the same position also count as equal. When the values
    differ only in str vs bytes element types, str elements are UTF-8 encoded
    before the final comparison.

    Args:
      a: the expected numpy ndarray or anything can be converted to one.
      b: the actual numpy ndarray or anything can be converted to one.
      msg: Optional message to report on failure.
    """
    if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)):
      return self._assertRaggedEqual(a, b, msg)
    msg = msg if msg else ""
    (a, b) = self.evaluate_if_both_tensors(a, b)
    a = self._GetNdArray(a)
    b = self._GetNdArray(b)
    # Arbitrary bounds so that we don't print giant tensors.
    # NOTE(review): the message (including %r of `b`) is formatted even when
    # the shapes match.
    if (b.ndim <= 3 or b.size < 500):
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " Contents: %r. \n%s." % (a.shape, b.shape, b, msg))
    else:
      self.assertEqual(
          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
          " %s" % (a.shape, b.shape, msg))

    same = (a == b)

    # NaN != NaN under ==, so for floating dtypes (chosen from `a` only) treat
    # NaNs in matching positions as equal.
    if (a.dtype in [
        np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
    ]):
      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
    msgs = [msg]
    if not np.all(same):
      # Adds more details to np.testing.assert_array_equal.
      diff = np.logical_not(same)
      if a.ndim:
        x = a[np.where(diff)]
        y = b[np.where(diff)]
        msgs.append("not equal where = {}".format(np.where(diff)))
      else:
        # np.where is broken for scalars
        x, y = a, b
      msgs.append("not equal lhs = %r" % x)
      msgs.append("not equal rhs = %r" % y)

      # Handle mixed string types as a result of PY2to3 migration. That is, the
      # mixing between bytes (b-prefix strings, PY2 default) and unicodes
      # (u-prefix strings, PY3 default).
      # After this normalization `a` and `b` are flattened 1-D arrays, so
      # values that differ only in str-vs-bytes representation pass below.
      if six.PY3:
        if (a.dtype.kind != b.dtype.kind and
            {a.dtype.kind, b.dtype.kind}.issubset({"U", "S", "O"})):
          a_list = []
          b_list = []
          # OK to flatten `a` and `b` because they are guaranteed to have the
          # same shape.
          for out_list, flat_arr in [(a_list, a.flat), (b_list, b.flat)]:
            for item in flat_arr:
              if isinstance(item, str):
                out_list.append(item.encode("utf-8"))
              else:
                out_list.append(item)
          a = np.array(a_list)
          b = np.array(b_list)

      # The authoritative check: `same` above only decides whether the extra
      # diagnostics are needed.
      np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
2983
2984  @py_func_if_in_function
2985  def assertNotAllEqual(self, a, b, msg=None):
2986    """Asserts that two numpy arrays or Tensors do not have the same values.
2987
2988    Args:
2989      a: the expected numpy ndarray or anything can be converted to one.
2990      b: the actual numpy ndarray or anything can be converted to one.
2991      msg: Optional message to report on failure.
2992    """
2993    try:
2994      self.assertAllEqual(a, b)
2995    except AssertionError:
2996      return
2997    raise AssertionError("The two values are equal at all elements. %s" % msg)
2998
2999  @py_func_if_in_function
3000  def assertAllGreater(self, a, comparison_target):
3001    """Assert element values are all greater than a target value.
3002
3003    Args:
3004      a: The numpy `ndarray`, or anything that can be converted into a numpy
3005        `ndarray` (including Tensor).
3006      comparison_target: The target value of comparison.
3007    """
3008    (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3009    a = self._GetNdArray(a)
3010    self.assertGreater(np.min(a), comparison_target)
3011
3012  @py_func_if_in_function
3013  def assertAllLess(self, a, comparison_target):
3014    """Assert element values are all less than a target value.
3015
3016    Args:
3017      a: The numpy `ndarray`, or anything that can be converted into a numpy
3018        `ndarray` (including Tensor).
3019      comparison_target: The target value of comparison.
3020    """
3021    (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3022    a = self._GetNdArray(a)
3023    self.assertLess(np.max(a), comparison_target)
3024
3025  @py_func_if_in_function
3026  def assertAllGreaterEqual(self, a, comparison_target):
3027    """Assert element values are all greater than or equal to a target value.
3028
3029    Args:
3030      a: The numpy `ndarray`, or anything that can be converted into a numpy
3031        `ndarray` (including Tensor).
3032      comparison_target: The target value of comparison.
3033    """
3034    (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3035    a = self._GetNdArray(a)
3036    self.assertGreaterEqual(np.min(a), comparison_target)
3037
3038  @py_func_if_in_function
3039  def assertAllLessEqual(self, a, comparison_target):
3040    """Assert element values are all less than or equal to a target value.
3041
3042    Args:
3043      a: The numpy `ndarray`, or anything that can be converted into a numpy
3044        `ndarray` (including Tensor).
3045      comparison_target: The target value of comparison.
3046    """
3047    (a, comparison_target) = self.evaluate_if_both_tensors(a, comparison_target)
3048    a = self._GetNdArray(a)
3049    self.assertLessEqual(np.max(a), comparison_target)
3050
3051  def _format_subscripts(self, subscripts, value, limit=10, indent=2):
3052    """Generate a summary of ndarray subscripts as a list of str.
3053
3054    If limit == N, this method will print up to the first N subscripts on
3055    separate
3056    lines. A line of ellipses (...) will be appended at the end if the number of
3057    subscripts exceeds N.
3058
3059    Args:
3060      subscripts: The tensor (np.ndarray) subscripts, of the same format as
3061        np.where()'s return value, i.e., a tuple of arrays with each array
3062        corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])).
3063      value: (np.ndarray) value of the tensor.
3064      limit: (int) The maximum number of indices to print.
3065      indent: (int) Number of characters to indent at the beginning of each
3066        line.
3067
3068    Returns:
3069      (list of str) the multi-line representation of the subscripts and values,
3070        potentially with omission at the end.
3071    """
3072    lines = []
3073    subscripts = np.transpose(subscripts)
3074    prefix = " " * indent
3075    if np.ndim(value) == 0:
3076      return [prefix + "[0] : " + str(value)]
3077    for subscript in itertools.islice(subscripts, limit):
3078      lines.append(prefix + str(subscript) + " : " +
3079                   str(value[tuple(subscript)]))
3080    if len(subscripts) > limit:
3081      lines.append(prefix + "...")
3082    return lines
3083
3084  @py_func_if_in_function
3085  def assertAllInRange(self,
3086                       target,
3087                       lower_bound,
3088                       upper_bound,
3089                       open_lower_bound=False,
3090                       open_upper_bound=False):
3091    """Assert that elements in a Tensor are all in a given range.
3092
3093    Args:
3094      target: The numpy `ndarray`, or anything that can be converted into a
3095        numpy `ndarray` (including Tensor).
3096      lower_bound: lower bound of the range
3097      upper_bound: upper bound of the range
3098      open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
3099        than the default >=)
3100      open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather
3101        than the default <=)
3102
3103    Raises:
3104      AssertionError:
3105        if the value tensor does not have an ordered numeric type (float* or
3106          int*), or
3107        if there are nan values, or
3108        if any of the elements do not fall in the specified range.
3109    """
3110    target = self._GetNdArray(target)
3111    if not (np.issubdtype(target.dtype, np.floating) or
3112            np.issubdtype(target.dtype, np.integer)):
3113      raise AssertionError(
3114          "The value of %s does not have an ordered numeric type, instead it "
3115          "has type: %s" % (target, target.dtype))
3116
3117    nan_subscripts = np.where(np.isnan(target))
3118    if np.size(nan_subscripts):
3119      raise AssertionError(
3120          "%d of the %d element(s) are NaN. "
3121          "Subscripts(s) and value(s) of the NaN element(s):\n" %
3122          (len(nan_subscripts[0]), np.size(target)) +
3123          "\n".join(self._format_subscripts(nan_subscripts, target)))
3124
3125    range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " +
3126                 str(upper_bound) + (")" if open_upper_bound else "]"))
3127
3128    violations = (
3129        np.less_equal(target, lower_bound) if open_lower_bound else np.less(
3130            target, lower_bound))
3131    violations = np.logical_or(
3132        violations,
3133        np.greater_equal(target, upper_bound)
3134        if open_upper_bound else np.greater(target, upper_bound))
3135    violation_subscripts = np.where(violations)
3136    if np.size(violation_subscripts):
3137      raise AssertionError(
3138          "%d of the %d element(s) are outside the range %s. " %
3139          (len(violation_subscripts[0]), np.size(target), range_str) +
3140          "Subscript(s) and value(s) of the offending elements:\n" +
3141          "\n".join(self._format_subscripts(violation_subscripts, target)))
3142
3143  @py_func_if_in_function
3144  def assertAllInSet(self, target, expected_set):
3145    """Assert that elements of a Tensor are all in a given closed set.
3146
3147    Args:
3148      target: The numpy `ndarray`, or anything that can be converted into a
3149        numpy `ndarray` (including Tensor).
3150      expected_set: (`list`, `tuple` or `set`) The closed set that the elements
3151        of the value of `target` are expected to fall into.
3152
3153    Raises:
3154      AssertionError:
3155        if any of the elements do not fall into `expected_set`.
3156    """
3157    target = self._GetNdArray(target)
3158
3159    # Elements in target that are not in expected_set.
3160    diff = np.setdiff1d(target.flatten(), list(expected_set))
3161    if np.size(diff):
3162      raise AssertionError("%d unique element(s) are not in the set %s: %s" %
3163                           (np.size(diff), expected_set, diff))
3164
3165  @py_func_if_in_function
3166  def assertDTypeEqual(self, target, expected_dtype):
3167    """Assert ndarray data type is equal to expected.
3168
3169    Args:
3170      target: The numpy `ndarray`, or anything that can be converted into a
3171        numpy `ndarray` (including Tensor).
3172      expected_dtype: Expected data type.
3173    """
3174    target = self._GetNdArray(target)
3175    if not isinstance(target, list):
3176      arrays = [target]
3177    for arr in arrays:
3178      self.assertEqual(arr.dtype, expected_dtype)
3179
3180  # pylint: disable=g-doc-return-or-yield
3181  @contextlib.contextmanager
3182  def assertRaisesWithPredicateMatch(self, exception_type,
3183                                     expected_err_re_or_predicate):
3184    """Returns a context manager to enclose code expected to raise an exception.
3185
3186    If the exception is an OpError, the op stack is also included in the message
3187    predicate search.
3188
3189    Args:
3190      exception_type: The expected type of exception that should be raised.
3191      expected_err_re_or_predicate: If this is callable, it should be a function
3192        of one argument that inspects the passed-in exception and returns True
3193        (success) or False (please fail the test). Otherwise, the error message
3194        is expected to match this regular expression partially.
3195
3196    Returns:
3197      A context manager to surround code that is expected to raise an
3198      exception.
3199    """
3200    if callable(expected_err_re_or_predicate):
3201      predicate = expected_err_re_or_predicate
3202    else:
3203
3204      def predicate(e):
3205        err_str = e.message if isinstance(e, errors.OpError) else str(e)
3206        op = e.op if isinstance(e, errors.OpError) else None
3207        while op is not None:
3208          err_str += "\nCaused by: " + op.name
3209          op = op._original_op  # pylint: disable=protected-access
3210        logging.info("Searching within error strings: '%s' within '%s'",
3211                     expected_err_re_or_predicate, err_str)
3212        return re.search(expected_err_re_or_predicate, err_str)
3213
3214    try:
3215      yield
3216      self.fail(exception_type.__name__ + " not raised")
3217    except Exception as e:  # pylint: disable=broad-except
3218      if not isinstance(e, exception_type) or not predicate(e):
3219        raise AssertionError("Exception of type %s: %s" %
3220                             (str(type(e)), str(e)))
3221
3222  # pylint: enable=g-doc-return-or-yield
3223
3224  def assertRaisesOpError(self, expected_err_re_or_predicate):
3225    return self.assertRaisesWithPredicateMatch(errors.OpError,
3226                                               expected_err_re_or_predicate)
3227
3228  def assertRaisesIncompatibleShapesError(
3229      self, exception_type=errors.InvalidArgumentError):
3230    return self.assertRaisesWithPredicateMatch(
3231        exception_type, r"Incompatible shapes|Dimensions must be equal|"
3232        r"required broadcastable shapes")
3233
3234  def assertShapeEqual(self, np_array, tf_tensor, msg=None):
3235    """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape.
3236
3237    Args:
3238      np_array: A Numpy ndarray or Numpy scalar.
3239      tf_tensor: A Tensor.
3240      msg: Optional message to report on failure.
3241
3242    Raises:
3243      TypeError: If the arguments have the wrong type.
3244    """
3245    if not isinstance(np_array, (np.ndarray, np.generic)):
3246      raise TypeError("np_array must be a Numpy ndarray or Numpy scalar")
3247    if not isinstance(tf_tensor, ops.Tensor):
3248      raise TypeError("tf_tensor must be a Tensor")
3249    self.assertAllEqual(
3250        np_array.shape, tf_tensor.get_shape().as_list(), msg=msg)
3251
3252  def assertDeviceEqual(self, device1, device2, msg=None):
3253    """Asserts that the two given devices are the same.
3254
3255    Args:
3256      device1: A string device name or TensorFlow `DeviceSpec` object.
3257      device2: A string device name or TensorFlow `DeviceSpec` object.
3258      msg: Optional message to report on failure.
3259    """
3260    device1 = pydev.canonical_name(device1)
3261    device2 = pydev.canonical_name(device2)
3262    self.assertEqual(
3263        device1, device2,
3264        "Devices %s and %s are not equal. %s" % (device1, device2, msg))
3265
3266  def _GetPyList(self, a):
3267    """Converts `a` to a nested python list."""
3268    if isinstance(a, ragged_tensor.RaggedTensor):
3269      return self.evaluate(a).to_list()
3270    elif isinstance(a, ops.Tensor):
3271      a = self.evaluate(a)
3272      return a.tolist() if isinstance(a, np.ndarray) else a
3273    elif isinstance(a, np.ndarray):
3274      return a.tolist()
3275    elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
3276      return a.to_list()
3277    else:
3278      return np.array(a).tolist()
3279
3280  def _assertRaggedEqual(self, a, b, msg):
3281    """Asserts that two ragged tensors are equal."""
3282    a_list = self._GetPyList(a)
3283    b_list = self._GetPyList(b)
3284    self.assertEqual(a_list, b_list, msg)
3285
3286    if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
3287      a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
3288      b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
3289      self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
3290
3291  def _assertRaggedClose(self, a, b, rtol, atol, msg=None):
3292    a_list = self._GetPyList(a)
3293    b_list = self._GetPyList(b)
3294    self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg)
3295
3296    if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
3297      a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
3298      b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
3299      self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
3300
3301  def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"):
3302    self.assertEqual(type(a), type(b))
3303    if isinstance(a, (list, tuple)):
3304      self.assertLen(a, len(b), "Length differs for %s" % path)
3305      for i in range(len(a)):
3306        self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg,
3307                                       "%s[%s]" % (path, i))
3308    else:
3309      self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)
3310
3311  # Fix Python 3+ compatibility issues
3312  if not six.PY2:
3313    # pylint: disable=invalid-name
3314
3315    # Silence a deprecation warning
3316    assertRaisesRegexp = googletest.TestCase.assertRaisesRegex
3317
3318    # assertItemsEqual is assertCountEqual as of 3.2.
3319    assertItemsEqual = googletest.TestCase.assertCountEqual
3320
3321    # pylint: enable=invalid-name
3322
3323  @contextlib.contextmanager
3324  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
3325    """Set the session and its graph to global default and constrain devices."""
3326    if context.executing_eagerly():
3327      yield None
3328    else:
3329      with sess.graph.as_default(), sess.as_default():
3330        if force_gpu:
3331          # Use the name of an actual device if one is detected, or
3332          # '/device:GPU:0' otherwise
3333          gpu_name = gpu_device_name()
3334          if not gpu_name:
3335            gpu_name = "/device:GPU:0"
3336          with sess.graph.device(gpu_name):
3337            yield sess
3338        elif use_gpu:
3339          yield sess
3340        else:
3341          with sess.graph.device("/device:CPU:0"):
3342            yield sess
3343
3344  def _create_session(self, graph, config, force_gpu):
3345    """See session() for details."""
3346
3347    def prepare_config(config):
3348      """Returns a config for sessions.
3349
3350      Args:
3351        config: An optional config_pb2.ConfigProto to use to configure the
3352          session.
3353
3354      Returns:
3355        A config_pb2.ConfigProto object.
3356      """
3357      # TODO(b/114333779): Enforce allow_soft_placement=False when
3358      # use_gpu=False. Currently many tests rely on the fact that any device
3359      # will be used even when a specific device is supposed to be used.
3360      allow_soft_placement = not force_gpu
3361      if config is None:
3362        config = context.context().config
3363        config.allow_soft_placement = allow_soft_placement
3364      elif not allow_soft_placement and config.allow_soft_placement:
3365        config_copy = context.context().config
3366        config = config_copy
3367        config.allow_soft_placement = False
3368      # Don't perform optimizations for tests so we don't inadvertently run
3369      # gpu ops on cpu
3370      config.graph_options.optimizer_options.opt_level = -1
3371      # Disable Grappler constant folding since some tests & benchmarks
3372      # use constant input and become meaningless after constant folding.
3373      # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
3374      # GRAPPLER TEAM.
3375      config.graph_options.rewrite_options.constant_folding = (
3376          rewriter_config_pb2.RewriterConfig.OFF)
3377      config.graph_options.rewrite_options.pin_to_host_optimization = (
3378          rewriter_config_pb2.RewriterConfig.OFF)
3379      return config
3380
3381    return ErrorLoggingSession(graph=graph, config=prepare_config(config))
3382
3383  def _get_cached_session(self,
3384                          graph=None,
3385                          config=None,
3386                          force_gpu=False,
3387                          crash_if_inconsistent_args=True):
3388    """See cached_session() for documentation."""
3389    if self._cached_session is None:
3390      sess = self._create_session(
3391          graph=graph, config=config, force_gpu=force_gpu)
3392      self._cached_session = sess
3393      self._cached_graph = graph
3394      self._cached_config = config
3395      self._cached_force_gpu = force_gpu
3396      return sess
3397    else:
3398      if crash_if_inconsistent_args and self._cached_graph is not graph:
3399        raise ValueError("The graph used to get the cached session is "
3400                         "different than the one that was used to create the "
3401                         "session. Maybe create a new session with "
3402                         "self.session()")
3403      if crash_if_inconsistent_args and self._cached_config is not config:
3404        raise ValueError("The config used to get the cached session is "
3405                         "different than the one that was used to create the "
3406                         "session. Maybe create a new session with "
3407                         "self.session()")
3408      if crash_if_inconsistent_args and (self._cached_force_gpu is
3409                                         not force_gpu):
3410        raise ValueError(
3411            "The force_gpu value used to get the cached session is "
3412            "different than the one that was used to create the "
3413            "session. Maybe create a new session with "
3414            "self.session()")
3415      return self._cached_session
3416
3417
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  "PS" stands for "parameter server": a task responsible for storing and
  updating the model's parameters. Other tasks send updates to these parameters
  as they work on optimizing the parameters. This particular division of labor
  between tasks is not required, but is common for distributed training.

  Read more at https://www.tensorflow.org/guide/extend/architecture

  ![components](https://www.tensorflow.org/images/diag1.svg "components")


  The figure illustrates the interaction of these components.
  "/job:worker/task:0" and "/job:ps/task:0" are both tasks with worker services.


  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.compat.v1.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol. Allowed values are documented in the
      documentation of `tf.distribute.Server`.
    worker_config: (optional) `tf.ConfigProto` to initialize workers. Can be
      used to instantiate multiple devices etc.
    ps_config: (optional) `tf.ConfigProto` to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`.  `worker_servers` is a list
    of `num_workers` objects of type `tf.distribute.Server` (all running
    locally);
    and `ps_servers` is a list of `num_ps` objects of similar type.

  Raises:
    ImportError: if portpicker module was not found at load time
  """
  import portpicker  # pylint: disable=g-import-not-at-top

  def _pick_local_addresses(count):
    # One freshly picked unused port per task, all on localhost.
    return [
        "localhost:%s" % portpicker.pick_unused_port() for _ in range(count)
    ]

  # Worker ports are picked before PS ports (dict literal evaluates in order).
  cluster_spec = server_lib.ClusterSpec({
      "worker": _pick_local_addresses(num_workers),
      "ps": _pick_local_addresses(num_ps),
  })

  def _start_job(job_name, task_count, server_config):
    # Create and start one server per task of `job_name`.
    return [
        server_lib.Server(
            cluster_spec,
            job_name=job_name,
            protocol=protocol,
            task_index=task_index,
            config=server_config,
            start=True) for task_index in range(task_count)
    ]

  worker_servers = _start_job("worker", num_workers, worker_config)
  ps_servers = _start_job("ps", num_ps, ps_config)
  return worker_servers, ps_servers
3505
3506
def get_node_def_from_graph(node_name, graph_def):
  """Looks up a `NodeDef` by name among the top-level nodes of a graph def.

  Only `graph_def.node` is searched; the first node whose `name` equals
  `node_name` wins.

  Args:
    node_name: Name of the NodeDef to search for.
    graph_def: An instance of `GraphDef` proto.

  Returns:
    The first matching `NodeDef`, or None if no node has that name.
  """
  matches = (candidate for candidate in graph_def.node
             if candidate.name == node_name)
  return next(matches, None)
3523
3524
def set_producer_version(graph, producer_version):
  """Sets graph.graph_def_versions.producer to `producer_version`.

  Args:
    graph: The graph whose producer version is updated.
    producer_version: The producer version to record on `graph`.

  Raises:
    AssertionError: If, after the import, the graph's producer version is not
      `producer_version` (and assertions are enabled).
  """
  # The C API doesn't expose altering GraphDefVersions. We can indirectly set
  # it via import_graph_def though.
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.producer = producer_version
  with graph.as_default():
    importer.import_graph_def(graph_def)
  # Verify the indirect update took effect. `assert x, y` would only check
  # that x is truthy (and fail spuriously for version 0), so compare instead.
  actual_version = graph.graph_def_versions.producer
  assert actual_version == producer_version, (
      "Expected graph producer version %s, got %s" %
      (producer_version, actual_version))
3534
3535
3536@contextlib.contextmanager
3537def _fake_gradient_tape_context_manager():
3538  """tf.gradients(...) implemented as tf.GradientTape context manager interface.
3539
3540  This is useful to test tf.gradients() in tests that uses tf.GradientTape().
3541
3542  Yields:
3543    gradient tape instance that's implemented by tf.gradients() underneath.
3544  """
3545  try:
3546    class FakeGradientTape:
3547
3548      def watch(self, x):
3549        pass
3550
3551      def gradient(self, y, x, grad_ys=None):
3552        result = gradients_impl.gradients(y, x, grad_ys)
3553
3554        # Unlike `tape.gradient()`, `tf.gradients()` returns a list for a single
3555        # element. So unpack if needed to match `tape.gradient()` behavior.
3556        if not isinstance(x, (list, tuple)):
3557          assert len(result) == 1
3558          return result[0]
3559
3560        return result
3561
3562    yield FakeGradientTape()
3563  finally:
3564    pass
3565
3566
class AbstractGradientTape:
  """Abstract GradientTape context manager that has multiple implementations.

  This is useful to test both tf.GradientTape() and tf.gradients() without
  duplicating tests.

  Constructor args:
    use_tape: If True, wraps `backprop.GradientTape`; otherwise wraps the fake
      tape from `_fake_gradient_tape_context_manager`, whose `gradient()` is
      implemented with `tf.gradients()`.
    persistent: Forwarded to `backprop.GradientTape`; unused when `use_tape`
      is False.
  """

  def __init__(self, use_tape, persistent=False):
    self._use_tape = use_tape
    self._persistent = persistent
    # The wrapped context manager; a fresh one is created on every __enter__.
    self._tape_impl = None

  def __enter__(self):
    if self._use_tape:
      self._tape_impl = backprop.GradientTape(persistent=self._persistent)
    else:
      self._tape_impl = _fake_gradient_tape_context_manager()
    return self._tape_impl.__enter__()

  def __exit__(self, exc_type, exc_val, exc_tb):
    # Return the wrapped manager's result so any exception suppression it
    # requests is honored by the enclosing `with` statement.
    return self._tape_impl.__exit__(exc_type, exc_val, exc_tb)
3587
3588
@contextlib.contextmanager
def run_functions_eagerly(run_eagerly):
  """Temporarily controls whether `tf.function`s run eagerly.

  WARNING: Setting `run_eagerly` to True in tests running in V1 graph mode
  *WILL NOT* make tf.function run eagerly because eager is disabled by
  default in V1. Instead, tf.function will run as a traced graph function.

  On exit, normal or via an exception, the setting is restored to the value
  `def_function.functions_run_eagerly()` reported on entry.

  Args:
    run_eagerly: Boolean determining whether to run the function eagerly or not.

  Raises:
    ValueError: If `run_eagerly` is not a boolean.

  Yields:
    Nothing.
  """
  if not isinstance(run_eagerly, bool):
    raise ValueError(
        "Expected bool for `run_eagerly` but got {}".format(run_eagerly))

  currently_eager = context.executing_eagerly()
  if run_eagerly and not currently_eager:
    logging.warning(
        "Running tf.function eagerly in V1 graph mode is not supported. "
        "tf.function will be run as a traced graph function.")

  saved_setting = def_function.functions_run_eagerly()
  def_function.run_functions_eagerly(run_eagerly)
  try:
    yield
  finally:
    # Put back whatever was in effect before this context was entered.
    def_function.run_functions_eagerly(saved_setting)